diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py
index efe20a3b..d1ecd81f 100644
--- a/comfy/ldm/models/diffusion/ddpm.py
+++ b/comfy/ldm/models/diffusion/ddpm.py
@@ -81,7 +81,7 @@ class DDPM(torch.nn.Module):
         super().__init__()
         assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
         self.parameterization = parameterization
-        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+        # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
         self.cond_stage_model = None
         self.clip_denoised = clip_denoised
         self.log_every_t = log_every_t
@@ -522,8 +522,8 @@ class LatentDiffusion(DDPM):
     """main class"""
 
     def __init__(self,
-                 first_stage_config,
-                 cond_stage_config,
+                 first_stage_config={},
+                 cond_stage_config={},
                  num_timesteps_cond=None,
                  cond_stage_key="image",
                  cond_stage_trainable=False,
@@ -562,8 +562,6 @@ class LatentDiffusion(DDPM):
 
         # self.instantiate_first_stage(first_stage_config)
         # self.instantiate_cond_stage(cond_stage_config)
-        self.first_stage_config = first_stage_config
-        self.cond_stage_config = cond_stage_config
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 4ac11eba..0d7517c4 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -317,9 +317,7 @@ class VAE:
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
 
-    def decode_tiled(self, samples):
-        tile_x = tile_y = 64
-        overlap = 8
+    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 8):
         model_management.unload_model()
         output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
         self.first_stage_model = self.first_stage_model.to(self.device)
@@ -656,3 +654,103 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
     sd = load_torch_file(ckpt_path)
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
     return (ModelPatcher(model), clip, vae)
+
+
+def load_checkpoint_guess_config(ckpt_path, fp16=False, output_vae=True, output_clip=True, embedding_directory=None):
+    sd = load_torch_file(ckpt_path)
+    sd_keys = sd.keys()
+    clip = None
+    vae = None
+
+    class WeightsLoader(torch.nn.Module):
+        pass
+
+    w = WeightsLoader()
+    load_state_dict_to = []
+    if output_vae:
+        vae = VAE()
+        w.first_stage_model = vae.first_stage_model
+        load_state_dict_to = [w]
+
+    if output_clip:
+        clip_config = {}
+        if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
+            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+        else:
+            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
+        w.cond_stage_model = clip.cond_stage_model
+        load_state_dict_to = [w]
+
+    sd_config = {
+        "linear_start": 0.00085,
+        "linear_end": 0.012,
+        "num_timesteps_cond": 1,
+        "log_every_t": 200,
+        "timesteps": 1000,
+        "first_stage_key": "jpg",
+        "cond_stage_key": "txt",
+        "image_size": 64,
+        "channels": 4,
+        "cond_stage_trainable": False,
+        "monitor": "val/loss_simple_ema",
+        "scale_factor": 0.18215,
+        "use_ema": False,
+    }
+
+    unet_config = {
+        "use_checkpoint": True,
+        "image_size": 32,
+        "out_channels": 4,
+        "attention_resolutions": [
+            4,
+            2,
+            1
+        ],
+        "num_res_blocks": 2,
+        "channel_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "use_spatial_transformer": True,
+        "transformer_depth": 1,
+        "legacy": False
+    }
+
+    if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
+        unet_config['use_linear_in_transformer'] = True
+
+    unet_config["use_fp16"] = fp16
+    unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
+    unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
+    unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
+
+    sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
+    model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
+
+    if unet_config["in_channels"] > 4: #inpainting model
+        sd_config["conditioning_key"] = "hybrid"
+        sd_config["finetune_keys"] = None
+        model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+    else:
+        sd_config["conditioning_key"] = "crossattn"
+
+    if unet_config["context_dim"] == 1024:
+        unet_config["num_head_channels"] = 64 #SD2.x
+    else:
+        unet_config["num_heads"] = 8 #SD1.x
+
+
+    model = instantiate_from_config(model_config)
+    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+
+    if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
+        cond = torch.zeros((1, 2, unet_config["context_dim"]), device="cpu")
+        x_in = torch.rand((1, unet_config["in_channels"], 8, 8), device="cpu", generator=torch.manual_seed(1))
+        out = model.apply_model(x_in, torch.tensor([999], device="cpu"), cond)
+        if out.mean() < -0.6: #mean of eps should be ~0 and mean of v prediction should be ~-1
+            model.parameterization = 'v'
+
+    return (ModelPatcher(model), clip, vae)
diff --git a/nodes.py b/nodes.py
index 02cb7e8e..7e63b72c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -202,6 +202,28 @@ class CheckpointLoader:
         ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
 
+class CheckpointLoaderSimple:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    ckpt_dir = os.path.join(models_dir, "checkpoints")
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
+                              "type": (["fp16", "fp32"],),
+                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
+                              }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "_for_testing"
+
+    def load_checkpoint(self, ckpt_name, type, stop_at_clip_layer, output_vae=True, output_clip=True):
+        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, type=="fp16", output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
+        if out[1] is not None:
+            out[1].clip_layer(stop_at_clip_layer)
+        return out
+
 class LoraLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     lora_dir = os.path.join(models_dir, "loras")
@@ -837,6 +859,7 @@ NODE_CLASS_MAPPINGS = {
     "DiffControlNetLoader": DiffControlNetLoader,
     "T2IAdapterLoader": T2IAdapterLoader,
     "VAEDecodeTiled": VAEDecodeTiled,
+    "CheckpointLoaderSimple": CheckpointLoaderSimple,
 }
 
 CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
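
Note (not part of the patch): a minimal sketch of how the new config-free loader added to comfy/sd.py might be exercised directly from Python, mirroring what the CheckpointLoaderSimple node does. The checkpoint path below is a placeholder.

    import comfy.sd

    # Hypothetical checkpoint path; any SD1.x/SD2.x checkpoint under models/checkpoints.
    ckpt_path = "models/checkpoints/example.ckpt"

    # Returns (ModelPatcher, CLIP or None, VAE or None), guessing SD1.x vs SD2.x,
    # inpainting vs. regular, and eps vs. v prediction from the state dict itself.
    model, clip, vae = comfy.sd.load_checkpoint_guess_config(ckpt_path, fp16=True)

    if clip is not None:
        clip.clip_layer(-1)  # same default as the node's stop_at_clip_layer input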