mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
support TAESD3 (#3738)
This commit is contained in:
parent
bb1969cab7
commit
df7db0e027
@ -129,6 +129,7 @@ class SD3(LatentFormat):
|
|||||||
[-0.0749, -0.0634, -0.0456],
|
[-0.0749, -0.0634, -0.0456],
|
||||||
[-0.1418, -0.1457, -0.1259]
|
[-0.1418, -0.1457, -0.1259]
|
||||||
]
|
]
|
||||||
|
self.taesd_decoder_name = "taesd3_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
return (latent - self.shift_factor) * self.scale_factor
|
return (latent - self.shift_factor) * self.scale_factor
|
||||||
|
@ -166,7 +166,7 @@ class CLIP:
|
|||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None, latent_channels=4):
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
@ -174,7 +174,7 @@ class VAE:
|
|||||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||||
self.downscale_ratio = 8
|
self.downscale_ratio = 8
|
||||||
self.upscale_ratio = 8
|
self.upscale_ratio = 8
|
||||||
self.latent_channels = 4
|
self.latent_channels = latent_channels
|
||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
@ -189,7 +189,7 @@ class VAE:
|
|||||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||||
elif "taesd_decoder.1.weight" in sd:
|
elif "taesd_decoder.1.weight" in sd:
|
||||||
self.first_stage_model = comfy.taesd.taesd.TAESD()
|
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||||
self.first_stage_model = StageA()
|
self.first_stage_model = StageA()
|
||||||
self.downscale_ratio = 4
|
self.downscale_ratio = 4
|
||||||
|
@ -25,18 +25,19 @@ class Block(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.fuse(self.conv(x) + self.skip(x))
|
return self.fuse(self.conv(x) + self.skip(x))
|
||||||
|
|
||||||
def Encoder():
|
def Encoder(latent_channels=4):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
conv(3, 64), Block(64, 64),
|
conv(3, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 4),
|
conv(64, latent_channels),
|
||||||
)
|
)
|
||||||
|
|
||||||
def Decoder():
|
|
||||||
|
def Decoder(latent_channels=4):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
Clamp(), conv(4, 64), nn.ReLU(),
|
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
@ -47,11 +48,11 @@ class TAESD(nn.Module):
|
|||||||
latent_magnitude = 3
|
latent_magnitude = 3
|
||||||
latent_shift = 0.5
|
latent_shift = 0.5
|
||||||
|
|
||||||
def __init__(self, encoder_path=None, decoder_path=None):
|
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
||||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.taesd_encoder = Encoder()
|
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
||||||
self.taesd_decoder = Decoder()
|
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
||||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||||
if encoder_path is not None:
|
if encoder_path is not None:
|
||||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||||
|
@ -64,7 +64,7 @@ def get_previewer(device, latent_format):
|
|||||||
|
|
||||||
if method == LatentPreviewMethod.TAESD:
|
if method == LatentPreviewMethod.TAESD:
|
||||||
if taesd_decoder_path:
|
if taesd_decoder_path:
|
||||||
taesd = TAESD(None, taesd_decoder_path).to(device)
|
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
|
||||||
previewer = TAESDPreviewerImpl(taesd)
|
previewer = TAESDPreviewerImpl(taesd)
|
||||||
else:
|
else:
|
||||||
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
||||||
|
17
nodes.py
17
nodes.py
@ -634,6 +634,8 @@ class VAELoader:
|
|||||||
sdxl_taesd_dec = False
|
sdxl_taesd_dec = False
|
||||||
sd1_taesd_enc = False
|
sd1_taesd_enc = False
|
||||||
sd1_taesd_dec = False
|
sd1_taesd_dec = False
|
||||||
|
sd3_taesd_enc = False
|
||||||
|
sd3_taesd_dec = False
|
||||||
|
|
||||||
for v in approx_vaes:
|
for v in approx_vaes:
|
||||||
if v.startswith("taesd_decoder."):
|
if v.startswith("taesd_decoder."):
|
||||||
@ -644,10 +646,16 @@ class VAELoader:
|
|||||||
sdxl_taesd_dec = True
|
sdxl_taesd_dec = True
|
||||||
elif v.startswith("taesdxl_encoder."):
|
elif v.startswith("taesdxl_encoder."):
|
||||||
sdxl_taesd_enc = True
|
sdxl_taesd_enc = True
|
||||||
|
elif v.startswith("taesd3_decoder."):
|
||||||
|
sd3_taesd_dec = True
|
||||||
|
elif v.startswith("taesd3_encoder."):
|
||||||
|
sd3_taesd_enc = True
|
||||||
if sd1_taesd_dec and sd1_taesd_enc:
|
if sd1_taesd_dec and sd1_taesd_enc:
|
||||||
vaes.append("taesd")
|
vaes.append("taesd")
|
||||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||||
vaes.append("taesdxl")
|
vaes.append("taesdxl")
|
||||||
|
if sd3_taesd_dec and sd3_taesd_enc:
|
||||||
|
vaes.append("taesd3")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -670,6 +678,8 @@ class VAELoader:
|
|||||||
sd["vae_scale"] = torch.tensor(0.18215)
|
sd["vae_scale"] = torch.tensor(0.18215)
|
||||||
elif name == "taesdxl":
|
elif name == "taesdxl":
|
||||||
sd["vae_scale"] = torch.tensor(0.13025)
|
sd["vae_scale"] = torch.tensor(0.13025)
|
||||||
|
elif name == "taesd3":
|
||||||
|
sd["vae_scale"] = torch.tensor(1.5305)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -682,12 +692,15 @@ class VAELoader:
|
|||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
if vae_name in ["taesd", "taesdxl"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
|
||||||
|
latent_channels = 16 if vae_name == 'taesd3' else 4
|
||||||
|
|
||||||
|
vae = comfy.sd.VAE(sd=sd, latent_channels=latent_channels)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class ControlNetLoader:
|
class ControlNetLoader:
|
||||||
|
Loading…
Reference in New Issue
Block a user