diff --git a/comfy/sd.py b/comfy/sd.py index 097fbb20..e016bea0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -357,12 +357,13 @@ class ModelPatcher: self.patches += [(strength_patch, p, strength_model)] return p.keys() - def model_state_dict(self): + def model_state_dict(self, filter_prefix=None): sd = self.model.state_dict() keys = list(sd.keys()) - for k in keys: - if not k.startswith("diffusion_model."): - sd.pop(k) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) return sd def patch_model(self): @@ -443,7 +444,7 @@ class ModelPatcher: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): - model_sd = self.model.state_dict() + model_sd = self.model_state_dict() keys = list(self.backup.keys()) for k in keys: model_sd[k][:] = self.backup[k] diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index daf4b09b..52b73f70 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -14,7 +14,7 @@ class ModelMergeSimple: def merge(self, model1, model2, ratio): m = model1.clone() - sd = model2.model_state_dict() + sd = model2.model_state_dict("diffusion_model.") for k in sd: m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) return (m, ) @@ -35,7 +35,7 @@ class ModelMergeBlocks: def merge(self, model1, model2, **kwargs): m = model1.clone() - sd = model2.model_state_dict() + sd = model2.model_state_dict("diffusion_model.") default_ratio = next(iter(kwargs.values())) for k in sd: