Repository: https://github.com/comfyanonymous/ComfyUI.git
Add taesd and taesdxl to VAELoader node.
They will show up if both the taesd_encoder and taesd_decoder model files (or, for taesdxl, both the taesdxl_encoder and taesdxl_decoder files) are present in the models/vae_approx directory.
Commit cd4fc77d5f (parent 6ff06fa796)
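To make the file-detection rule above concrete, here is a minimal standalone sketch of the same check. The helper name and default directory argument are illustrative, not part of the commit; the actual logic lives in VAELoader.vae_list() in the diff below.

```python
import os

# Hypothetical helper mirroring the rule described above: "taesd" is offered
# only when both SD1.x files exist, "taesdxl" only when both SDXL files exist.
def available_taesd_options(vae_approx_dir="models/vae_approx"):
    files = os.listdir(vae_approx_dir)
    has = lambda prefix: any(f.startswith(prefix) for f in files)
    options = []
    if has("taesd_encoder.") and has("taesd_decoder."):
        options.append("taesd")
    if has("taesdxl_encoder.") and has("taesdxl_decoder."):
        options.append("taesdxl")
    return options
```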
comfy/sd.py (17 changes)
```diff
@@ -23,6 +23,7 @@ import comfy.model_patcher
 import comfy.lora
 import comfy.t2i_adapter.adapter
 import comfy.supported_models_base
+import comfy.taesd.taesd
 
 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
```
```diff
@@ -154,10 +155,16 @@ class VAE:
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
 
+        self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking
+        self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7
+
         if config is None:
-            #default SD1.x/SD2.x VAE parameters
-            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
+            if "taesd_decoder.1.weight" in sd:
+                self.first_stage_model = comfy.taesd.taesd.TAESD()
+            else:
+                #default SD1.x/SD2.x VAE parameters
+                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
```
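The two lambdas introduced here are the VRAM-use heuristics that decode() and encode() below switch over to. Plugging in the latent shape of a 512x512 image shows the scale of the estimate (the constants come from the diff; reading the result as bytes is my assumption):

```python
# Sanity check of the decode-memory heuristic for a 512x512 image, whose
# latents have shape (batch, 4, 64, 64). Constants are taken from the diff.
latent_shape = (1, 4, 64, 64)
memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7
print(memory_used_decode(latent_shape) / (1024 ** 3))  # ~1.06, presumably GiB
```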
```diff
@@ -206,7 +213,7 @@ class VAE:
     def decode(self, samples_in):
         self.first_stage_model = self.first_stage_model.to(self.device)
         try:
-            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
+            memory_used = self.memory_used_decode(samples_in.shape)
             model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
```
```diff
@@ -234,7 +241,7 @@ class VAE:
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
-            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+            memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
             model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
```
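In both hunks the estimate feeds the same pattern: free up that much memory, then size the VAE batches so each chunk is expected to fit. A toy version of that loop, with placeholder numbers standing in for the model_management calls (the real code also handles out-of-memory fallback):

```python
import torch

# Toy version of the batch sizing in VAE.decode()/VAE.encode(): process the
# input in chunks whose estimated footprint fits into currently free memory.
def batched(samples_in, free_memory, memory_used):
    batch_number = max(1, int(free_memory / memory_used))
    out = []
    for x in range(0, samples_in.shape[0], batch_number):
        out.append(samples_in[x:x + batch_number] * 2)  # stand-in for the VAE call
    return torch.cat(out)

print(batched(torch.ones(5, 4, 64, 64), 4e9, 1.1e9).shape)  # torch.Size([5, 4, 64, 64])
```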
comfy/taesd/taesd.py

```diff
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
+    def __init__(self, encoder_path=None, decoder_path=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.encoder = Encoder()
-        self.decoder = Decoder()
+        self.taesd_encoder = Encoder()
+        self.taesd_decoder = Decoder()
+        self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
         if encoder_path is not None:
-            self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
+            self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
         if decoder_path is not None:
-            self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
+            self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
 
     @staticmethod
     def scale_latents(x):
```
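Renaming encoder/decoder to taesd_encoder/taesd_decoder matches the key prefixes that VAELoader.load_taesd() (further down) builds, and registering vae_scale as an nn.Parameter is what lets the scale factor ride along in the same state dict. A minimal demonstration of that PyTorch behavior:

```python
import torch

# nn.Parameter attributes appear in state_dict(), so a "vae_scale" entry in
# the loaded dict lands directly on the parameter registered in __init__.
class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))

m = Tiny()
m.load_state_dict({"vae_scale": torch.tensor(0.18215)})
print(m.vae_scale.item())  # ~0.18215
```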
```diff
@@ -65,3 +66,11 @@ class TAESD(nn.Module):
     def unscale_latents(x):
         """[0, 1] -> raw latents"""
         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+    def decode(self, x):
+        x_sample = self.taesd_decoder(x * self.vae_scale)
+        x_sample = x_sample.sub(0.5).mul(2)
+        return x_sample
+
+    def encode(self, x):
+        return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
```
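The new methods bridge two conventions: ComfyUI images live in [-1, 1] while TAESD works in [0, 1], and multiplying or dividing by vae_scale folds the Stable Diffusion latent scale factor into the module. The range mapping is an exact round trip:

```python
import torch

# encode() maps images [-1, 1] -> [0, 1] via x * 0.5 + 0.5; decode() maps the
# decoder's [0, 1] output back via .sub(0.5).mul(2). The two are inverses.
img = torch.linspace(-1, 1, 5)
back = (img * 0.5 + 0.5).sub(0.5).mul(2)
print(torch.allclose(img, back))  # True
```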
latent_preview.py

```diff
@@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
         self.taesd = taesd
 
     def decode_latent_to_preview(self, x0):
-        x_sample = self.taesd.decoder(x0[:1])[0].detach()
-        # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
-        x_sample = x_sample.sub(0.5).mul(2)
-
+        x_sample = self.taesd.decode(x0[:1])[0].detach()
         x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
         x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
         x_sample = x_sample.astype(np.uint8)
```
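Since TAESD.decode() now returns samples already in [-1, 1], the previewer's manual sub(0.5).mul(2) rescaling (and the commented-out unscale_latents experiment) became redundant and were dropped. What remains is the usual clamp-and-quantize step, shown here on a stand-in tensor:

```python
import numpy as np
import torch

# Remaining preview post-processing: [-1, 1] -> [0, 1], then CHW float to a
# HWC uint8 image. The random tensor stands in for a decoded sample.
x_sample = torch.rand(3, 64, 64) * 2 - 1
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
print(x_sample.astype(np.uint8).shape)  # (64, 64, 3)
```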
nodes.py (55 changes)
```diff
@@ -573,9 +573,55 @@ class LoraLoader:
         return (model_lora, clip_lora)
 
 class VAELoader:
+    @staticmethod
+    def vae_list():
+        vaes = folder_paths.get_filename_list("vae")
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
+        sdxl_taesd_enc = False
+        sdxl_taesd_dec = False
+        sd1_taesd_enc = False
+        sd1_taesd_dec = False
+
+        for v in approx_vaes:
+            if v.startswith("taesd_decoder."):
+                sd1_taesd_dec = True
+            elif v.startswith("taesd_encoder."):
+                sd1_taesd_enc = True
+            elif v.startswith("taesdxl_decoder."):
+                sdxl_taesd_dec = True
+            elif v.startswith("taesdxl_encoder."):
+                sdxl_taesd_enc = True
+        if sd1_taesd_dec and sd1_taesd_enc:
+            vaes.append("taesd")
+        if sdxl_taesd_dec and sdxl_taesd_enc:
+            vaes.append("taesdxl")
+        return vaes
+
+    @staticmethod
+    def load_taesd(name):
+        sd = {}
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
+
+        encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
+        decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
+
+        enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
+        for k in enc:
+            sd["taesd_encoder.{}".format(k)] = enc[k]
+
+        dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
+        for k in dec:
+            sd["taesd_decoder.{}".format(k)] = dec[k]
+
+        if name == "taesd":
+            sd["vae_scale"] = torch.tensor(0.18215)
+        elif name == "taesdxl":
+            sd["vae_scale"] = torch.tensor(0.13025)
+        return sd
+
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
+        return {"required": { "vae_name": (s.vae_list(), )}}
     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"
 
```
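load_taesd() prefixes every checkpoint key with taesd_encoder. or taesd_decoder. so the merged dict lines up with the renamed submodules in TAESD, and it is exactly this shape that comfy/sd.py detects via the "taesd_decoder.1.weight" key. A sketch of the prefixing with hypothetical stand-in dicts:

```python
# Stand-in dicts in place of the real tensor checkpoints; the key "1.weight"
# becomes "taesd_decoder.1.weight", the marker comfy/sd.py checks for.
enc = {"0.weight": "enc-tensor"}
dec = {"1.weight": "dec-tensor"}
sd = {}
for k in enc:
    sd["taesd_encoder.{}".format(k)] = enc[k]
for k in dec:
    sd["taesd_decoder.{}".format(k)] = dec[k]
print(sorted(sd))  # ['taesd_decoder.1.weight', 'taesd_encoder.0.weight']
```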
```diff
@@ -583,8 +629,11 @@ class VAELoader:
 
     #TODO: scale factor?
     def load_vae(self, vae_name):
-        vae_path = folder_paths.get_full_path("vae", vae_name)
-        sd = comfy.utils.load_torch_file(vae_path)
+        if vae_name in ["taesd", "taesdxl"]:
+            sd = self.load_taesd(vae_name)
+        else:
+            vae_path = folder_paths.get_full_path("vae", vae_name)
+            sd = comfy.utils.load_torch_file(vae_path)
         vae = comfy.sd.VAE(sd=sd)
         return (vae,)
 
```
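Putting it together, the new node path from dropdown choice to VAE object would be exercised roughly like this (a sketch, assuming a ComfyUI checkout with the taesd files installed in models/vae_approx; not a standalone script):

```python
import comfy.sd
import nodes

# "taesd" / "taesdxl" only appear once both matching files are present in
# models/vae_approx; any other name falls back to the regular vae folder.
loader = nodes.VAELoader()
print(loader.vae_list())
(vae,) = loader.load_vae("taesd")
assert isinstance(vae, comfy.sd.VAE)
```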