Move ModelPatcher to model_patcher.py

This commit is contained in:
comfyanonymous 2023-08-28 14:49:18 -04:00
parent 4798cf5a62
commit f92074b84f
3 changed files with 278 additions and 274 deletions

View File

@ -1,9 +1,9 @@
import torch import torch
import math import math
import comfy.utils import comfy.utils
import comfy.sd
import comfy.model_management import comfy.model_management
import comfy.model_detection import comfy.model_detection
import comfy.model_patcher
import comfy.cldm.cldm import comfy.cldm.cldm
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
@ -129,7 +129,7 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):

270
comfy/model_patcher.py Normal file
View File

@ -0,0 +1,270 @@
import torch
import copy
import inspect
import comfy.utils
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", k)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to

View File

@ -1,7 +1,5 @@
import torch import torch
import contextlib import contextlib
import copy
import inspect
import math import math
from comfy import model_management from comfy import model_management
@ -21,6 +19,7 @@ from . import sd1_clip
from . import sd2_clip from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
import comfy.model_patcher
import comfy.lora import comfy.lora
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
@ -53,271 +52,6 @@ def load_clip_weights(model, sd):
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return load_model_weights(model, sd) return load_model_weights(model, sd)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", k)
continue
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = comfy.lora.model_lora_keys_unet(model.model) key_map = comfy.lora.model_lora_keys_unet(model.model)
@ -355,7 +89,7 @@ class CLIP:
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None self.layer_idx = None
def clone(self): def clone(self):
@ -573,7 +307,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data) model = gligen.load_gligen(data)
if model_management.should_use_fp16(): if model_management.should_use_fp16():
model = model.half() model = model.half()
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually #TODO: this function is a mess and should be removed eventually
@ -653,7 +387,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
w.cond_stage_model = clip.cond_stage_model w.cond_stage_model = clip.cond_stage_model
load_clip_weights(w, state_dict) load_clip_weights(w, state_dict)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
sd = comfy.utils.load_torch_file(ckpt_path) sd = comfy.utils.load_torch_file(ckpt_path)
@ -705,7 +439,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0: if len(left_over) > 0:
print("left over keys:", left_over) print("left over keys:", left_over)
model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"): if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU") print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher) model_management.load_model_gpu(model_patcher)
@ -735,7 +469,7 @@ def load_unet(unet_path): #load unet in diffusers format
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()]) model_management.load_models_gpu([model, clip.load_model()])