add the new flux version
Ideally there would be an 'is_flux' bool to check; using `is not shared.sd_model.is_webui_legacy_model():` instead.
This commit is contained in:
+19
-6
@@ -63,7 +63,12 @@ class TAESDDecoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if latent_channels is None:
|
if latent_channels is None:
|
||||||
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
|
if "taesd3" in str(decoder_path):
|
||||||
|
latent_channels = 16
|
||||||
|
elif "taef1" in str(decoder_path):
|
||||||
|
latent_channels = 16
|
||||||
|
else:
|
||||||
|
latent_channels = 4
|
||||||
|
|
||||||
self.decoder = decoder(latent_channels)
|
self.decoder = decoder(latent_channels)
|
||||||
self.decoder.load_state_dict(
|
self.decoder.load_state_dict(
|
||||||
@@ -79,7 +84,12 @@ class TAESDEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if latent_channels is None:
|
if latent_channels is None:
|
||||||
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
|
if "taesd3" in str(encoder_path):
|
||||||
|
latent_channels = 16
|
||||||
|
elif "taef1" in str(encoder_path):
|
||||||
|
latent_channels = 16
|
||||||
|
else:
|
||||||
|
latent_channels = 4
|
||||||
|
|
||||||
self.encoder = encoder(latent_channels)
|
self.encoder = encoder(latent_channels)
|
||||||
self.encoder.load_state_dict(
|
self.encoder.load_state_dict(
|
||||||
@@ -95,15 +105,16 @@ def download_model(model_path, model_url):
|
|||||||
|
|
||||||
|
|
||||||
def decoder_model():
|
def decoder_model():
|
||||||
if not shared.sd_model.is_webui_legacy_model():
|
|
||||||
return None
|
|
||||||
|
|
||||||
if shared.sd_model.is_sd3:
|
if shared.sd_model.is_sd3:
|
||||||
model_name = "taesd3_decoder.pth"
|
model_name = "taesd3_decoder.pth"
|
||||||
|
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
|
||||||
|
model_name = "taef1_decoder.pth"
|
||||||
elif shared.sd_model.is_sdxl:
|
elif shared.sd_model.is_sdxl:
|
||||||
model_name = "taesdxl_decoder.pth"
|
model_name = "taesdxl_decoder.pth"
|
||||||
else:
|
elif shared.sd_model.is_sd1:
|
||||||
model_name = "taesd_decoder.pth"
|
model_name = "taesd_decoder.pth"
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||||
|
|
||||||
@@ -125,6 +136,8 @@ def decoder_model():
|
|||||||
def encoder_model():
|
def encoder_model():
|
||||||
if shared.sd_model.is_sd3:
|
if shared.sd_model.is_sd3:
|
||||||
model_name = "taesd3_encoder.pth"
|
model_name = "taesd3_encoder.pth"
|
||||||
|
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
|
||||||
|
model_name = "taef1_encoder.pth"
|
||||||
elif shared.sd_model.is_sdxl:
|
elif shared.sd_model.is_sdxl:
|
||||||
model_name = "taesdxl_encoder.pth"
|
model_name = "taesdxl_encoder.pth"
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user