From 8607c2d42d10b0108de02528e813cc703e58813f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 23 Jun 2023 02:14:12 -0400
Subject: [PATCH] Move latent scale factor from VAE to model.

---
 comfy/latent_formats.py        | 16 ++++++++++++++++
 comfy/model_base.py            | 27 ++++++++++++++++++---------
 comfy/samplers.py              |  5 ++++-
 comfy/sd.py                    | 32 +++++++++++++++++++-------------
 comfy/supported_models.py      | 13 +++++++------
 comfy/supported_models_base.py |  7 ++++---
 nodes.py                       |  6 +++++-
 7 files changed, 73 insertions(+), 33 deletions(-)
 create mode 100644 comfy/latent_formats.py

diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
new file mode 100644
index 00000000..3e193828
--- /dev/null
+++ b/comfy/latent_formats.py
@@ -0,0 +1,16 @@
+
+class LatentFormat:
+    def process_in(self, latent):
+        return latent * self.scale_factor
+
+    def process_out(self, latent):
+        return latent / self.scale_factor
+
+class SD15(LatentFormat):
+    def __init__(self, scale_factor=0.18215):
+        self.scale_factor = scale_factor
+
+class SDXL(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.13025
+
diff --git a/comfy/model_base.py b/comfy/model_base.py
index fa3c01c7..923c4348 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -6,9 +6,11 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, unet_config, v_prediction=False):
+    def __init__(self, model_config, v_prediction=False):
         super().__init__()
 
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -75,9 +77,16 @@ class BaseModel(torch.nn.Module):
         del to_load
         return self
 
+    def process_latent_in(self, latent):
+        return self.latent_format.process_in(latent)
+
+    def process_latent_out(self, latent):
+        return self.latent_format.process_out(latent)
+
+
 class SD21UNCLIP(BaseModel):
-    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, v_prediction=True):
+        super().__init__(model_config, v_prediction)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
 
     def encode_adm(self, **kwargs):
@@ -112,13 +121,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out
 
 class SDInpaint(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.concat_keys = ("mask", "masked_image")
 
 class SDXLRefiner(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
@@ -144,8 +153,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
 class SDXL(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 102bf925..d6a8f609 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -586,6 +586,9 @@ class KSampler:
         positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
         negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
 
+        if latent_image is not None:
+            latent_image = self.model.process_latent_in(latent_image)
+
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
 
         cond_concat = None
@@ -672,4 +675,4 @@ class KSampler:
             else:
                 samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
-        return samples.to(torch.float32)
+        return self.model.process_latent_out(samples.to(torch.float32))
diff --git a/comfy/sd.py b/comfy/sd.py
index 15caf360..ead2c067 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -536,7 +536,7 @@ class CLIP:
 
 
 class VAE:
-    def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
+    def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -550,7 +550,6 @@ class VAE:
                 sd = diffusers_convert.convert_vae_state_dict(sd)
             self.first_stage_model.load_state_dict(sd, strict=False)
 
-        self.scale_factor = scale_factor
         if device is None:
             device = model_management.get_torch_device()
         self.device = device
@@ -561,7 +560,7 @@ class VAE:
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)
 
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
         output = torch.clamp((
             (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
@@ -575,7 +574,7 @@ class VAE:
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)
 
-        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
         samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -593,7 +592,7 @@ class VAE:
             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
@@ -620,7 +619,7 @@ class VAE:
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
 
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -958,6 +957,7 @@ def load_gligen(ckpt_path):
     return model
 
 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    #TODO: this function is a mess and should be removed eventually
    if config is None:
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
@@ -992,12 +992,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if state_dict is None:
         state_dict = utils.load_torch_file(ckpt_path)
 
+    class EmptyClass:
+        pass
+
+    model_config = EmptyClass()
+    model_config.unet_config = unet_config
+    from . import latent_formats
+    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
+
     if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+        model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
     elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
     else:
-        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+        model = model_base.BaseModel(model_config, v_prediction=v_prediction)
 
     if fp16:
         model = model.half()
@@ -1006,14 +1014,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
 
     if output_vae:
         w = WeightsLoader()
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
+        vae = VAE(config=vae_config)
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, state_dict)
 
     if output_clip:
         w = WeightsLoader()
-        class EmptyClass:
-            pass
         clip_target = EmptyClass()
         clip_target.params = clip_config.get("params", {})
         if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
@@ -1055,7 +1061,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
-        vae = VAE(scale_factor=model_config.vae_scale_factor)
+        vae = VAE()
         w = WeightsLoader()
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, sd)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 3120d501..51da9456 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -7,6 +7,7 @@ from . import sd2_clip
 from . import sdxl_clip
 
 from . import supported_models_base
+from . import latent_formats
 
 class SD15(supported_models_base.BASE):
     unet_config = {
@@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE):
         "num_head_channels": -1,
     }
 
-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15
 
     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())
@@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE):
         "adm_in_channels": None,
     }
 
-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15
 
     def v_prediction(self, state_dict):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE):
         "transformer_depth": [0, 4, 4, 0],
     }
 
-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL
 
     def get_model(self, state_dict):
-        return model_base.SDXLRefiner(self.unet_config)
+        return model_base.SDXLRefiner(self)
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
@@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE):
         "adm_in_channels": 2816
     }
 
-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL
 
     def get_model(self, state_dict):
-        return model_base.SDXL(self.unet_config)
+        return model_base.SDXL(self)
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 401e05d3..3312a99d 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -49,16 +49,17 @@ class BASE:
 
     def __init__(self, unet_config):
         self.unet_config = unet_config
+        self.latent_format = self.latent_format()
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]
 
     def get_model(self, state_dict):
         if self.inpaint_model():
-            return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
         else:
-            return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
diff --git a/nodes.py b/nodes.py
index cb057a9f..b7447d58 100644
--- a/nodes.py
+++ b/nodes.py
@@ -284,6 +284,7 @@ class SaveLatent:
 
         output = {}
         output["latent_tensor"] = samples["samples"]
+        output["latent_format_version_0"] = torch.tensor([])
 
         safetensors.torch.save_file(output, file, metadata=metadata)
 
@@ -305,7 +306,10 @@ class LoadLatent:
     def load(self, latent):
         latent_path = folder_paths.get_annotated_filepath(latent)
         latent = safetensors.torch.load_file(latent_path, device="cpu")
-        samples = {"samples": latent["latent_tensor"].float()}
+        multiplier = 1.0
+        if "latent_format_version_0" not in latent:
+            multiplier = 1.0 / 0.18215
+        samples = {"samples": latent["latent_tensor"].float() * multiplier}
        return (samples, )

    @classmethod
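
Appendix (illustration only, not part of the diff): a minimal sketch of the latent_formats module introduced by this patch. It assumes the comfy package at this revision is importable and that torch is installed; the tensor shape is arbitrary, and the 1.0 / 0.18215 multiplier mirrors the LoadLatent compatibility path shown above.

import torch
from comfy.latent_formats import SD15

latent = torch.randn(1, 4, 64, 64)       # raw VAE-space latent (what VAE.encode now returns)
fmt = SD15()                             # scale_factor = 0.18215

scaled = fmt.process_in(latent)          # latent * 0.18215, applied by KSampler before sampling
restored = fmt.process_out(scaled)       # latent / 0.18215, applied to the sampler output
assert torch.allclose(latent, restored)

# Latent files written before this patch already contain the scaled tensor and lack the
# "latent_format_version_0" marker, so LoadLatent multiplies them by 1.0 / 0.18215 to
# convert them into the new unscaled convention.
old_file_tensor = fmt.process_in(latent)             # old on-disk representation
converted = old_file_tensor * (1.0 / 0.18215)        # LoadLatent's compatibility multiplier
assert torch.allclose(converted, latent, atol=1e-6)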