Fix merging not working when model2 of model merge node was a merge.

This commit is contained in:
comfyanonymous 2023-07-08 22:16:40 -04:00
parent febea8c101
commit a9a4ba7574
2 changed files with 114 additions and 87 deletions

View File

@ -206,7 +206,7 @@ class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0): def __init__(self, model, load_device, offload_device, size=0):
self.size = size self.size = size
self.model = model self.model = model
self.patches = [] self.patches = {}
self.backup = {} self.backup = {}
self.model_options = {"transformer_options":{}} self.model_options = {"transformer_options":{}}
self.model_size() self.model_size()
@ -227,7 +227,10 @@ class ModelPatcher:
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
n.patches = self.patches[:] n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.model_options = copy.deepcopy(self.model_options) n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys n.model_keys = self.model_keys
return n return n
@ -295,12 +298,28 @@ class ModelPatcher:
return self.model.get_dtype() return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {} p = set()
for k in patches: for k in patches:
if k in self.model_keys: if k in self.model_keys:
p[k] = patches[k] p.add(k)
self.patches += [(strength_patch, p, strength_model)] current_patches = self.patches.get(k, [])
return p.keys() current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None): def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict() sd = self.model.state_dict()
@ -313,24 +332,31 @@ class ModelPatcher:
def patch_model(self): def patch_model(self):
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
for p in self.patches: for key in self.patches:
for k in p[1]:
v = p[1][k]
key = k
if key not in model_sd: if key not in model_sd:
print("could not patch. key doesn't exist in model:", k) print("could not patch. key doesn't exist in model:", k)
continue continue
weight = model_sd[key] weight = model_sd[key]
if key not in self.backup: if key not in self.backup:
self.backup[key] = weight.clone() self.backup[key] = weight.clone()
weight[:] = self.calculate_weight(self.patches[key], weight.clone(), key)
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0] alpha = p[0]
v = p[1]
strength_model = p[2] strength_model = p[2]
if strength_model != 1.0: if strength_model != 1.0:
weight *= strength_model weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1: if len(v) == 1:
w1 = v[0] w1 = v[0]
if w1.shape != weight.shape: if w1.shape != weight.shape:
@ -391,7 +417,8 @@ class ModelPatcher:
m2 = torch.mm(w2a.float(), w2b.float()) m2 = torch.mm(w2a.float(), w2b.float())
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
return self.model return weight
def unpatch_model(self): def unpatch_model(self):
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
keys = list(self.backup.keys()) keys = list(self.backup.keys())

View File

@ -18,9 +18,9 @@ class ModelMergeSimple:
def merge(self, model1, model2, ratio): def merge(self, model1, model2, ratio):
m = model1.clone() m = model1.clone()
sd = model2.model_state_dict("diffusion_model.") kp = model2.get_key_patches("diffusion_model.")
for k in sd: for k in kp:
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, ) return (m, )
class ModelMergeBlocks: class ModelMergeBlocks:
@ -39,10 +39,10 @@ class ModelMergeBlocks:
def merge(self, model1, model2, **kwargs): def merge(self, model1, model2, **kwargs):
m = model1.clone() m = model1.clone()
sd = model2.model_state_dict("diffusion_model.") kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values())) default_ratio = next(iter(kwargs.values()))
for k in sd: for k in kp:
ratio = default_ratio ratio = default_ratio
k_unet = k[len("diffusion_model."):] k_unet = k[len("diffusion_model."):]
@ -52,7 +52,7 @@ class ModelMergeBlocks:
ratio = kwargs[arg] ratio = kwargs[arg]
last_arg_size = len(arg) last_arg_size = len(arg)
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, ) return (m, )
class CheckpointSave: class CheckpointSave: