diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py
index efe20a3b..802034c7 100644
--- a/comfy/ldm/models/diffusion/ddpm.py
+++ b/comfy/ldm/models/diffusion/ddpm.py
@@ -522,8 +522,8 @@ class LatentDiffusion(DDPM):
     """main class"""
 
     def __init__(self,
-                 first_stage_config,
-                 cond_stage_config,
+                 first_stage_config={},
+                 cond_stage_config={},
                  num_timesteps_cond=None,
                  cond_stage_key="image",
                  cond_stage_trainable=False,
@@ -562,8 +562,6 @@ class LatentDiffusion(DDPM):
 
         # self.instantiate_first_stage(first_stage_config)
         # self.instantiate_cond_stage(cond_stage_config)
-        self.first_stage_config = first_stage_config
-        self.cond_stage_config = cond_stage_config
 
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 00a20782..59683f64 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -9,6 +9,8 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 
+import model_management
+
 try:
     import xformers
     import xformers.ops
@@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        stats = torch.cuda.memory_stats(query.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+
         chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
 
         kv_chunk_size_min = None
@@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):
 
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
 
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)
 
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 36f925c4..32159b82 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model
 
 
-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total
 
 def maximum_batch_area():
     global vram_state
@@ -162,6 +173,30 @@ def maximum_batch_area():
     memory_free = get_free_memory() / (1024 * 1024)
     area = ((memory_free - 1024) * 0.9) / (0.6)
     return int(max(area, 0))
+
+def cpu_mode():
+    global vram_state
+    return vram_state == CPU
+
+def should_use_fp16():
+    if cpu_mode():
+        return False #TODO ?
+
+    if torch.cuda.is_bf16_supported():
+        return True
+
+    props = torch.cuda.get_device_properties("cuda")
+    if props.major < 7:
+        return False
+
+    #FP32 is faster on those cards?
+    nvidia_16_series = ["1660", "1650", "1630"]
+    for x in nvidia_16_series:
+        if x in props.name:
+            return False
+
+    return True
+
 
 #TODO: might be cleaner to put this somewhere else
 import threading
diff --git a/comfy/sd.py b/comfy/sd.py
index 4ac11eba..50d81f77 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -266,6 +266,7 @@ class CLIP:
         self.cond_stage_model = clip(**(params))
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
+        self.layer_idx = -1
 
     def clone(self):
         n = CLIP(no_init=True)
@@ -273,6 +274,7 @@ class CLIP:
         n.patcher = self.patcher.clone()
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
+        n.layer_idx = self.layer_idx
         return n
 
     def load_from_state_dict(self, sd):
@@ -282,9 +284,10 @@ class CLIP:
         return self.patcher.add_patches(patches, strength)
 
     def clip_layer(self, layer_idx):
-        return self.cond_stage_model.clip_layer(layer_idx)
+        self.layer_idx = layer_idx
 
     def encode(self, text):
+        self.cond_stage_model.clip_layer(self.layer_idx)
         tokens = self.tokenizer.tokenize_with_weights(text)
         try:
             self.patcher.patch_model()
@@ -317,9 +320,7 @@ class VAE:
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
 
-    def decode_tiled(self, samples):
-        tile_x = tile_y = 64
-        overlap = 8
+    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 8):
         model_management.unload_model()
         output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
         self.first_stage_model = self.first_stage_model.to(self.device)
@@ -656,3 +657,103 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
     sd = load_torch_file(ckpt_path)
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
     return (ModelPatcher(model), clip, vae)
+
+
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
+    sd = load_torch_file(ckpt_path)
+    sd_keys = sd.keys()
+    clip = None
+    vae = None
+
+    fp16 = model_management.should_use_fp16()
+
+    class WeightsLoader(torch.nn.Module):
+        pass
+
+    w = WeightsLoader()
+    load_state_dict_to = []
+    if output_vae:
+        vae = VAE()
+        w.first_stage_model = vae.first_stage_model
+        load_state_dict_to = [w]
+
+    if output_clip:
+        clip_config = {}
+        if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
+            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+        else:
+            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
+        w.cond_stage_model = clip.cond_stage_model
+        load_state_dict_to = [w]
+
+    sd_config = {
+        "linear_start": 0.00085,
+        "linear_end": 0.012,
+        "num_timesteps_cond": 1,
+        "log_every_t": 200,
+        "timesteps": 1000,
+        "first_stage_key": "jpg",
+        "cond_stage_key": "txt",
+        "image_size": 64,
+        "channels": 4,
+        "cond_stage_trainable": False,
+        "monitor": "val/loss_simple_ema",
+        "scale_factor": 0.18215,
+        "use_ema": False,
+    }
+
+    unet_config = {
+        "use_checkpoint": True,
+        "image_size": 32,
+        "out_channels": 4,
+        "attention_resolutions": [
+            4,
+            2,
+            1
+        ],
+        "num_res_blocks": 2,
+        "channel_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "use_spatial_transformer": True,
+        "transformer_depth": 1,
+        "legacy": False
+    }
+
+    if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
+        unet_config['use_linear_in_transformer'] = True
+
+    unet_config["use_fp16"] = fp16
+    unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
+    unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
+    unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
+
+    sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
+    model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
+
+    if unet_config["in_channels"] > 4: #inpainting model
+        sd_config["conditioning_key"] = "hybrid"
+        sd_config["finetune_keys"] = None
+        model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+    else:
+        sd_config["conditioning_key"] = "crossattn"
+
+    if unet_config["context_dim"] == 1024:
+        unet_config["num_head_channels"] = 64 #SD2.x
+    else:
+        unet_config["num_heads"] = 8 #SD1.x
+
+    if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
+        k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
+        out = sd[k]
+        if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+            sd_config["parameterization"] = 'v'
+
+    model = instantiate_from_config(model_config)
+    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+
+    return (ModelPatcher(model), clip, vae)
diff --git a/nodes.py b/nodes.py
index dd9f06d4..f4f07bb3 100644
--- a/nodes.py
+++ b/nodes.py
@@ -202,6 +202,40 @@ class CheckpointLoader:
         ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
 
+class CheckpointLoaderSimple:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    ckpt_dir = os.path.join(models_dir, "checkpoints")
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
+                             }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "loaders"
+
+    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
+        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
+        return out
+
+class CLIPSetLastLayer:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip": ("CLIP", ),
+                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
+                              }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "set_last_layer"
+
+    CATEGORY = "conditioning"
+
+    def set_last_layer(self, clip, stop_at_clip_layer):
+        clip = clip.clone()
+        clip.clip_layer(stop_at_clip_layer)
+        return (clip,)
+
 class LoraLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     lora_dir = os.path.join(models_dir, "loras")
@@ -325,17 +359,15 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
-                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
 
     CATEGORY = "loaders"
 
-    def load_clip(self, clip_name, stop_at_clip_layer):
+    def load_clip(self, clip_name):
         clip_path = os.path.join(self.clip_dir, clip_name)
         clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
-        clip.clip_layer(stop_at_clip_layer)
         return (clip,)
 
 class EmptyLatentImage:
@@ -810,7 +842,9 @@ class ImageInvert:
 NODE_CLASS_MAPPINGS = {
     "KSampler": KSampler,
     "CheckpointLoader": CheckpointLoader,
+    "CheckpointLoaderSimple": CheckpointLoaderSimple,
     "CLIPTextEncode": CLIPTextEncode,
+    "CLIPSetLastLayer": CLIPSetLastLayer,
     "VAEDecode": VAEDecode,
     "VAEEncode": VAEEncode,
     "VAEEncodeForInpaint": VAEEncodeForInpaint,
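For reference, a minimal usage sketch of the new model_management helpers introduced by this patch (not part of the diff itself). It assumes it runs from inside the comfy directory, matching the plain `import model_management` style used in attention.py; device handling mirrors what the new get_free_memory() accepts.

# Illustrative only -- not part of the patch above.
import torch

import model_management  # assumes the comfy directory is on sys.path, as in attention.py

# An explicit CPU device object takes the psutil path; any CUDA device (or device index)
# falls through to the torch.cuda memory statistics.
dev = torch.device("cpu") if model_management.cpu_mode() else torch.cuda.current_device()

# New signature: optionally also return the memory torch has reserved but not used,
# which CrossAttentionBirchSan uses for its chunk threshold.
mem_free_total, mem_free_torch = model_management.get_free_memory(dev, torch_free_too=True)

# New precision heuristic that load_checkpoint_guess_config() feeds into the unet config.
print(mem_free_total, mem_free_torch, model_management.should_use_fp16())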