From f92074b84f44af6127e8f4ed1245d6c04dc0b750 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 28 Aug 2023 14:49:18 -0400
Subject: [PATCH] Move ModelPatcher to model_patcher.py

---
 comfy/controlnet.py    |   4 +-
 comfy/model_patcher.py | 270 +++++++++++++++++++++++++++++++++++++++
 comfy/sd.py            | 278 +----------------------------------------
 3 files changed, 278 insertions(+), 274 deletions(-)
 create mode 100644 comfy/model_patcher.py

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 4e20100c..b18ab967 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -1,9 +1,9 @@
 import torch
 import math
 import comfy.utils
-import comfy.sd
 import comfy.model_management
 import comfy.model_detection
+import comfy.model_patcher
 
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
@@ -129,7 +129,7 @@ class ControlNet(ControlBase):
     def __init__(self, control_model, global_average_pooling=False, device=None):
         super().__init__(device)
         self.control_model = control_model
-        self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling
 
     def get_control(self, x_noisy, t, cond, batched_number):
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
new file mode 100644
index 00000000..2f087a60
--- /dev/null
+++ b/comfy/model_patcher.py
@@ -0,0 +1,270 @@
+import torch
+import copy
+import inspect
+
+import comfy.utils
+
+class ModelPatcher:
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
+        self.size = size
+        self.model = model
+        self.patches = {}
+        self.backup = {}
+        self.model_options = {"transformer_options":{}}
+        self.model_size()
+        self.load_device = load_device
+        self.offload_device = offload_device
+        if current_device is None:
+            self.current_device = self.offload_device
+        else:
+            self.current_device = current_device
+
+    def model_size(self):
+        if self.size > 0:
+            return self.size
+        model_sd = self.model.state_dict()
+        size = 0
+        for k in model_sd:
+            t = model_sd[k]
+            size += t.nelement() * t.element_size()
+        self.size = size
+        self.model_keys = set(model_sd.keys())
+        return size
+
+    def clone(self):
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
+        n.patches = {}
+        for k in self.patches:
+            n.patches[k] = self.patches[k][:]
+
+        n.model_options = copy.deepcopy(self.model_options)
+        n.model_keys = self.model_keys
+        return n
+
+    def is_clone(self, other):
+        if hasattr(other, 'model') and self.model is other.model:
+            return True
+        return False
+
+    def set_model_sampler_cfg_function(self, sampler_cfg_function):
+        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
+            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
+        else:
+            self.model_options["sampler_cfg_function"] = sampler_cfg_function
+
+    def set_model_unet_function_wrapper(self, unet_wrapper_function):
+        self.model_options["model_function_wrapper"] = unet_wrapper_function
+
+    def set_model_patch(self, patch, name):
+        to = self.model_options["transformer_options"]
+        if "patches" not in to:
+            to["patches"] = {}
+        to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+    def set_model_patch_replace(self, patch, name, block_name, number):
+        to = self.model_options["transformer_options"]
+        if "patches_replace" not in to:
+            to["patches_replace"] = {}
+        if name not in to["patches_replace"]:
+            to["patches_replace"][name] = {}
+        to["patches_replace"][name][(block_name, number)] = patch
+
+    def set_model_attn1_patch(self, patch):
+        self.set_model_patch(patch, "attn1_patch")
+
+    def set_model_attn2_patch(self, patch):
+        self.set_model_patch(patch, "attn2_patch")
+
+    def set_model_attn1_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn1", block_name, number)
+
+    def set_model_attn2_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn2", block_name, number)
+
+    def set_model_attn1_output_patch(self, patch):
+        self.set_model_patch(patch, "attn1_output_patch")
+
+    def set_model_attn2_output_patch(self, patch):
+        self.set_model_patch(patch, "attn2_output_patch")
+
+    def model_patches_to(self, device):
+        to = self.model_options["transformer_options"]
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "to"):
+                        patch_list[i] = patch_list[i].to(device)
+        if "patches_replace" in to:
+            patches = to["patches_replace"]
+            for name in patches:
+                patch_list = patches[name]
+                for k in patch_list:
+                    if hasattr(patch_list[k], "to"):
+                        patch_list[k] = patch_list[k].to(device)
+
+    def model_dtype(self):
+        if hasattr(self.model, "get_dtype"):
+            return self.model.get_dtype()
+
+    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
+        p = set()
+        for k in patches:
+            if k in self.model_keys:
+                p.add(k)
+                current_patches = self.patches.get(k, [])
+                current_patches.append((strength_patch, patches[k], strength_model))
+                self.patches[k] = current_patches
+
+        return list(p)
+
+    def get_key_patches(self, filter_prefix=None):
+        model_sd = self.model_state_dict()
+        p = {}
+        for k in model_sd:
+            if filter_prefix is not None:
+                if not k.startswith(filter_prefix):
+                    continue
+            if k in self.patches:
+                p[k] = [model_sd[k]] + self.patches[k]
+            else:
+                p[k] = (model_sd[k],)
+        return p
+
+    def model_state_dict(self, filter_prefix=None):
+        sd = self.model.state_dict()
+        keys = list(sd.keys())
+        if filter_prefix is not None:
+            for k in keys:
+                if not k.startswith(filter_prefix):
+                    sd.pop(k)
+        return sd
+
+    def patch_model(self, device_to=None):
+        model_sd = self.model_state_dict()
+        for key in self.patches:
+            if key not in model_sd:
+                print("could not patch. key doesn't exist in model:", key)
+                continue
+
+            weight = model_sd[key]
+
+            if key not in self.backup:
+                self.backup[key] = weight.to(self.offload_device)
+
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            comfy.utils.set_attr(self.model, key, out_weight)
+            del temp_weight
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
+
+        return self.model
+
+    def calculate_weight(self, patches, weight, key):
+        for p in patches:
+            alpha = p[0]
+            v = p[1]
+            strength_model = p[2]
+
+            if strength_model != 1.0:
+                weight *= strength_model
+
+            if isinstance(v, list):
+                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+
+            if len(v) == 1:
+                w1 = v[0]
+                if alpha != 0.0:
+                    if w1.shape != weight.shape:
+                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                    else:
+                        weight += alpha * w1.type(weight.dtype).to(weight.device)
+            elif len(v) == 4: #lora/locon
+                mat1 = v[0].float().to(weight.device)
+                mat2 = v[1].float().to(weight.device)
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                if v[3] is not None:
+                    #locon mid weights, hopefully the math is fine because I didn't properly test it
+                    mat3 = v[3].float().to(weight.device)
+                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+                try:
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+            elif len(v) == 8: #lokr
+                w1 = v[0]
+                w2 = v[1]
+                w1_a = v[3]
+                w1_b = v[4]
+                w2_a = v[5]
+                w2_b = v[6]
+                t2 = v[7]
+                dim = None
+
+                if w1 is None:
+                    dim = w1_b.shape[0]
+                    w1 = torch.mm(w1_a.float(), w1_b.float())
+                else:
+                    w1 = w1.float().to(weight.device)
+
+                if w2 is None:
+                    dim = w2_b.shape[0]
+                    if t2 is None:
+                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
+                    else:
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+                else:
+                    w2 = w2.float().to(weight.device)
+
+                if len(w2.shape) == 4:
+                    w1 = w1.unsqueeze(2).unsqueeze(2)
+                if v[2] is not None and dim is not None:
+                    alpha *= v[2] / dim
+
+                try:
+                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+            else: #loha
+                w1a = v[0]
+                w1b = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / w1b.shape[0]
+                w2a = v[3]
+                w2b = v[4]
+                if v[5] is not None: #cp decomposition
+                    t1 = v[5]
+                    t2 = v[6]
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
+                else:
+                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
+
+                try:
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+
+        return weight
+
+    def unpatch_model(self, device_to=None):
+        keys = list(self.backup.keys())
+
+        for k in keys:
+            comfy.utils.set_attr(self.model, k, self.backup[k])
+
+        self.backup = {}
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
diff --git a/comfy/sd.py b/comfy/sd.py
index 7462c79e..1edc5fd2 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,7 +1,5 @@
 import torch
 import contextlib
-import copy
-import inspect
 import math
 
 from comfy import model_management
@@ -21,6 +19,7 @@
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
 
+import comfy.model_patcher
 import comfy.lora
 import comfy.t2i_adapter.adapter
@@ -53,271 +52,6 @@ def load_clip_weights(model, sd):
     sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
     return load_model_weights(model, sd)
 
-class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
-        self.size = size
-        self.model = model
-        self.patches = {}
-        self.backup = {}
-        self.model_options = {"transformer_options":{}}
-        self.model_size()
-        self.load_device = load_device
-        self.offload_device = offload_device
-        if current_device is None:
-            self.current_device = self.offload_device
-        else:
-            self.current_device = current_device
-
-    def model_size(self):
-        if self.size > 0:
-            return self.size
-        model_sd = self.model.state_dict()
-        size = 0
-        for k in model_sd:
-            t = model_sd[k]
-            size += t.nelement() * t.element_size()
-        self.size = size
-        self.model_keys = set(model_sd.keys())
-        return size
-
-    def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
-        n.patches = {}
-        for k in self.patches:
-            n.patches[k] = self.patches[k][:]
-
-        n.model_options = copy.deepcopy(self.model_options)
-        n.model_keys = self.model_keys
-        return n
-
-    def is_clone(self, other):
-        if hasattr(other, 'model') and self.model is other.model:
-            return True
-        return False
-
-    def set_model_sampler_cfg_function(self, sampler_cfg_function):
-        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
-            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
-        else:
-            self.model_options["sampler_cfg_function"] = sampler_cfg_function
-
-    def set_model_unet_function_wrapper(self, unet_wrapper_function):
-        self.model_options["model_function_wrapper"] = unet_wrapper_function
-
-    def set_model_patch(self, patch, name):
-        to = self.model_options["transformer_options"]
-        if "patches" not in to:
-            to["patches"] = {}
-        to["patches"][name] = to["patches"].get(name, []) + [patch]
-
-    def set_model_patch_replace(self, patch, name, block_name, number):
-        to = self.model_options["transformer_options"]
-        if "patches_replace" not in to:
-            to["patches_replace"] = {}
-        if name not in to["patches_replace"]:
-            to["patches_replace"][name] = {}
-        to["patches_replace"][name][(block_name, number)] = patch
-
-    def set_model_attn1_patch(self, patch):
-        self.set_model_patch(patch, "attn1_patch")
-
-    def set_model_attn2_patch(self, patch):
-        self.set_model_patch(patch, "attn2_patch")
-
-    def set_model_attn1_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn1", block_name, number)
-
-    def set_model_attn2_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn2", block_name, number)
-
-    def set_model_attn1_output_patch(self, patch):
-        self.set_model_patch(patch, "attn1_output_patch")
-
-    def set_model_attn2_output_patch(self, patch):
-        self.set_model_patch(patch, "attn2_output_patch")
-
-    def model_patches_to(self, device):
-        to = self.model_options["transformer_options"]
-        if "patches" in to:
-            patches = to["patches"]
-            for name in patches:
-                patch_list = patches[name]
-                for i in range(len(patch_list)):
-                    if hasattr(patch_list[i], "to"):
-                        patch_list[i] = patch_list[i].to(device)
-        if "patches_replace" in to:
-            patches = to["patches_replace"]
-            for name in patches:
-                patch_list = patches[name]
-                for k in patch_list:
-                    if hasattr(patch_list[k], "to"):
-                        patch_list[k] = patch_list[k].to(device)
-
-    def model_dtype(self):
-        if hasattr(self.model, "get_dtype"):
-            return self.model.get_dtype()
-
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
-        for k in patches:
-            if k in self.model_keys:
-                p.add(k)
-                current_patches = self.patches.get(k, [])
-                current_patches.append((strength_patch, patches[k], strength_model))
-                self.patches[k] = current_patches
-
-        return list(p)
-
-    def get_key_patches(self, filter_prefix=None):
-        model_sd = self.model_state_dict()
-        p = {}
-        for k in model_sd:
-            if filter_prefix is not None:
-                if not k.startswith(filter_prefix):
-                    continue
-            if k in self.patches:
-                p[k] = [model_sd[k]] + self.patches[k]
-            else:
-                p[k] = (model_sd[k],)
-        return p
-
-    def model_state_dict(self, filter_prefix=None):
-        sd = self.model.state_dict()
-        keys = list(sd.keys())
-        if filter_prefix is not None:
-            for k in keys:
-                if not k.startswith(filter_prefix):
-                    sd.pop(k)
-        return sd
-
-    def patch_model(self, device_to=None):
-        model_sd = self.model_state_dict()
-        for key in self.patches:
-            if key not in model_sd:
-                print("could not patch. key doesn't exist in model:", k)
-                continue
-
-            weight = model_sd[key]
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device)
-
-            if device_to is not None:
-                temp_weight = weight.float().to(device_to, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
-            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-            comfy.utils.set_attr(self.model, key, out_weight)
-            del temp_weight
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
-
-        return self.model
-
-    def calculate_weight(self, patches, weight, key):
-        for p in patches:
-            alpha = p[0]
-            v = p[1]
-            strength_model = p[2]
-
-            if strength_model != 1.0:
-                weight *= strength_model
-
-            if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
-
-            if len(v) == 1:
-                w1 = v[0]
-                if alpha != 0.0:
-                    if w1.shape != weight.shape:
-                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += alpha * w1.type(weight.dtype).to(weight.device)
-            elif len(v) == 4: #lora/locon
-                mat1 = v[0].float().to(weight.device)
-                mat2 = v[1].float().to(weight.device)
-                if v[2] is not None:
-                    alpha *= v[2] / mat2.shape[0]
-                if v[3] is not None:
-                    #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = v[3].float().to(weight.device)
-                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-            elif len(v) == 8: #lokr
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(w1_a.float(), w1_b.float())
-                else:
-                    w1 = w1.float().to(weight.device)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
-                else:
-                    w2 = w2.float().to(weight.device)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha *= v[2] / dim
-
-                try:
-                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-            else: #loha
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha *= v[2] / w1b.shape[0]
-                w2a = v[3]
-                w2b = v[4]
-                if v[5] is not None: #cp decomposition
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
-                else:
-                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
-                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
-
-                try:
-                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
-                except Exception as e:
-                    print("ERROR", key, e)
-
-        return weight
-
-    def unpatch_model(self, device_to=None):
-        keys = list(self.backup.keys())
-
-        for k in keys:
-            comfy.utils.set_attr(self.model, k, self.backup[k])
-
-        self.backup = {}
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
-
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = comfy.lora.model_lora_keys_unet(model.model)
@@ -355,7 +89,7 @@ class CLIP:
         self.cond_stage_model = clip(**(params))
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
 
-        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
 
     def clone(self):
@@ -573,7 +307,7 @@ def load_gligen(ckpt_path):
     model = gligen.load_gligen(data)
     if model_management.should_use_fp16():
         model = model.half()
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
 
 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     #TODO: this function is a mess and should be removed eventually
@@ -653,7 +387,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         w.cond_stage_model = clip.cond_stage_model
         load_clip_weights(w, state_dict)
 
-    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
+    return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = comfy.utils.load_torch_file(ckpt_path)
@@ -705,7 +439,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)
 
-    model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
     if inital_load_device != torch.device("cpu"):
         print("loaded straight to GPU")
         model_management.load_model_gpu(model_patcher)
@@ -735,7 +469,7 @@ def load_unet(unet_path): #load unet in diffusers format
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
-    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     model_management.load_models_gpu([model, clip.load_model()])
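
Usage note: after this patch, ModelPatcher lives in comfy/model_patcher.py, so
external code should import comfy.model_patcher rather than reaching into
comfy.sd. Below is a minimal sketch of the relocated API, assuming a ComfyUI
checkout on the Python path; the torch.nn.Linear stand-in is illustrative
only (ModelPatcher only needs a module with a state dict), not how ComfyUI
itself constructs its patchers.

    import torch
    import comfy.model_management
    import comfy.model_patcher

    # Wrap a model so weight patches can be applied and later reverted
    # without losing the original tensors.
    model = torch.nn.Linear(4, 4)
    patcher = comfy.model_patcher.ModelPatcher(
        model,
        load_device=comfy.model_management.get_torch_device(),
        offload_device=comfy.model_management.unet_offload_device(),
    )

    # A 1-tuple patch value takes the len(v) == 1 branch of
    # calculate_weight: weight += strength_patch * diff.
    diff = torch.full_like(model.weight, 0.1)
    patcher.add_patches({"weight": (diff,)}, strength_patch=1.0)

    patched = patcher.patch_model()   # applies patches, backing up originals
    patcher.unpatch_model()           # restores the backed-up weights
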