Improvements to the TAESD3 implementation.

This commit is contained in:
comfyanonymous 2024-06-16 02:04:24 -04:00
parent df7db0e027
commit 04e8798c37
3 changed files with 10 additions and 9 deletions

View File

@ -166,7 +166,7 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, latent_channels=4):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@ -174,7 +174,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = latent_channels
self.latent_channels = 4
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@ -189,7 +189,7 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=sd["taesd_decoder.1.weight"].shape[1])
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
self.downscale_ratio = 4

View File

@ -54,6 +54,7 @@ class TAESD(nn.Module):
self.taesd_encoder = Encoder(latent_channels=latent_channels)
self.taesd_decoder = Decoder(latent_channels=latent_channels)
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
@ -70,9 +71,9 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift

View File

@ -676,10 +676,13 @@ class VAELoader:
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
sd["vae_shift"] = torch.tensor(0.0)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
sd["vae_shift"] = torch.tensor(0.0)
elif name == "taesd3":
sd["vae_scale"] = torch.tensor(1.5305)
sd["vae_shift"] = torch.tensor(0.0609)
return sd
@classmethod
@ -697,10 +700,7 @@ class VAELoader:
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
latent_channels = 16 if vae_name == 'taesd3' else 4
vae = comfy.sd.VAE(sd=sd, latent_channels=latent_channels)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
class ControlNetLoader: