From 51581dbfa9ce19f537e4cd110509ac5ab91dd74c Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Tue, 20 Jun 2023 19:37:43 -0400
Subject: [PATCH] Fix last commits causing an issue with the text encoder lora.

---
 comfy/sd.py                         | 11 ++++++-----
 comfy_extras/nodes_model_merging.py |  4 ++--
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index 097fbb200..e016bea07 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -357,12 +357,13 @@ class ModelPatcher:
         self.patches += [(strength_patch, p, strength_model)]
         return p.keys()
 
-    def model_state_dict(self):
+    def model_state_dict(self, filter_prefix=None):
         sd = self.model.state_dict()
         keys = list(sd.keys())
-        for k in keys:
-            if not k.startswith("diffusion_model."):
-                sd.pop(k)
+        if filter_prefix is not None:
+            for k in keys:
+                if not k.startswith(filter_prefix):
+                    sd.pop(k)
         return sd
 
     def patch_model(self):
@@ -443,7 +444,7 @@ class ModelPatcher:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
         return self.model
     def unpatch_model(self):
-        model_sd = self.model.state_dict()
+        model_sd = self.model_state_dict()
         keys = list(self.backup.keys())
         for k in keys:
             model_sd[k][:] = self.backup[k]
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index daf4b09ba..52b73f702 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -14,7 +14,7 @@ class ModelMergeSimple:
 
     def merge(self, model1, model2, ratio):
         m = model1.clone()
-        sd = model2.model_state_dict()
+        sd = model2.model_state_dict("diffusion_model.")
         for k in sd:
             m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
         return (m, )
@@ -35,7 +35,7 @@ class ModelMergeBlocks:
 
     def merge(self, model1, model2, **kwargs):
         m = model1.clone()
-        sd = model2.model_state_dict()
+        sd = model2.model_state_dict("diffusion_model.")
         default_ratio = next(iter(kwargs.values()))
 
         for k in sd: