diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 92f39d5c..4b4a9eda 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -129,6 +129,7 @@ class SD3(LatentFormat):
             [-0.0749, -0.0634, -0.0456],
             [-0.1418, -0.1457, -0.1259]
         ]
+        self.taesd_decoder_name = "taesd3_decoder"
 
     def process_in(self, latent):
         return (latent - self.shift_factor) * self.scale_factor
diff --git a/comfy/sd.py b/comfy/sd.py
index f1e48713..d2720ec1 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -166,7 +166,7 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, latent_channels=4):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
 
@@ -174,7 +174,7 @@ class VAE:
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
         self.upscale_ratio = 8
-        self.latent_channels = 4
+        self.latent_channels = latent_channels
         self.output_channels = 3
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -189,7 +189,7 @@ class VAE:
                                             encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
                                             decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
             elif "taesd_decoder.1.weight" in sd:
-                self.first_stage_model = comfy.taesd.taesd.TAESD()
+                self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
             elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                 self.first_stage_model = StageA()
                 self.downscale_ratio = 4
diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py
index 8f96c54e..74031c60 100644
--- a/comfy/taesd/taesd.py
+++ b/comfy/taesd/taesd.py
@@ -25,18 +25,19 @@ class Block(nn.Module):
     def forward(self, x):
         return self.fuse(self.conv(x) + self.skip(x))
 
-def Encoder():
+def Encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )
 
-def Decoder():
+
+def Decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@@ -47,11 +48,11 @@ class TAESD(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, encoder_path=None, decoder_path=None):
+    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.taesd_encoder = Encoder()
-        self.taesd_decoder = Decoder()
+        self.taesd_encoder = Encoder(latent_channels=latent_channels)
+        self.taesd_decoder = Decoder(latent_channels=latent_channels)
         self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
         if encoder_path is not None:
             self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
diff --git a/latent_preview.py b/latent_preview.py
index 54aa233f..ae6c106e 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -64,7 +64,7 @@ def get_previewer(device, latent_format):
 
         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:
-                taesd = TAESD(None, taesd_decoder_path).to(device)
+                taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
                 previewer = TAESDPreviewerImpl(taesd)
             else:
                 logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
diff --git a/nodes.py b/nodes.py
index 0b2a96f7..ca10ca32 100644
--- a/nodes.py
+++ b/nodes.py
@@ -634,6 +634,8 @@ class VAELoader:
         sdxl_taesd_dec = False
         sd1_taesd_enc = False
         sd1_taesd_dec = False
+        sd3_taesd_enc = False
+        sd3_taesd_dec = False
 
         for v in approx_vaes:
             if v.startswith("taesd_decoder."):
@@ -644,10 +646,16 @@ class VAELoader:
                 sdxl_taesd_dec = True
             elif v.startswith("taesdxl_encoder."):
                 sdxl_taesd_enc = True
+            elif v.startswith("taesd3_decoder."):
+                sd3_taesd_dec = True
+            elif v.startswith("taesd3_encoder."):
+                sd3_taesd_enc = True
         if sd1_taesd_dec and sd1_taesd_enc:
             vaes.append("taesd")
         if sdxl_taesd_dec and sdxl_taesd_enc:
             vaes.append("taesdxl")
+        if sd3_taesd_dec and sd3_taesd_enc:
+            vaes.append("taesd3")
         return vaes
 
     @staticmethod
@@ -670,6 +678,8 @@ class VAELoader:
             sd["vae_scale"] = torch.tensor(0.18215)
         elif name == "taesdxl":
             sd["vae_scale"] = torch.tensor(0.13025)
+        elif name == "taesd3":
+            sd["vae_scale"] = torch.tensor(1.5305)
         return sd
 
     @classmethod
@@ -682,12 +692,15 @@ class VAELoader:
 
     #TODO: scale factor?
     def load_vae(self, vae_name):
-        if vae_name in ["taesd", "taesdxl"]:
+        if vae_name in ["taesd", "taesdxl", "taesd3"]:
             sd = self.load_taesd(vae_name)
         else:
             vae_path = folder_paths.get_full_path("vae", vae_name)
             sd = comfy.utils.load_torch_file(vae_path)
-        vae = comfy.sd.VAE(sd=sd)
+
+        latent_channels = 16 if vae_name == 'taesd3' else 4
+
+        vae = comfy.sd.VAE(sd=sd, latent_channels=latent_channels)
         return (vae,)
 
 class ControlNetLoader:
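Not part of the patch, but a minimal sketch of how the pieces above fit together for SD3 previews, mirroring the updated latent_preview.py path: the TAESD module has to be built with latent_channels=16 before the decoder weights are loaded, because SD3 latents carry 16 channels instead of the 4 used by SD1.x/SDXL. The decoder file name used here is illustrative; it assumes a "taesd3_decoder" checkpoint sitting in models/vae_approx/ and a ComfyUI checkout on the import path.

    import torch
    from comfy.taesd.taesd import TAESD

    # Illustrative path; ComfyUI discovers the real file via latent_format.taesd_decoder_name.
    decoder_path = "models/vae_approx/taesd3_decoder.pth"

    # Build the decoder with a 16-channel input conv so the SD3 state dict loads cleanly.
    taesd = TAESD(None, decoder_path, latent_channels=16)

    # Dummy SD3 latent: 1 sample, 16 channels, 128x128 grid (a 1024x1024 image at 8x downscale).
    latent = torch.randn(1, 16, 128, 128)

    with torch.no_grad():
        preview = taesd.decode(latent)  # approximate RGB decode, shape (1, 3, 1024, 1024)

    print(preview.shape)

Without the latent_channels argument the old 4-channel Decoder would reject the 16-channel taesd3 weights at load_state_dict time, which is why the parameter is threaded through TAESD, VAE, get_previewer and VAELoader in the diff.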