From 1abf8374ec690972ba512a50eeaa767935536441 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Mar 2024 17:27:23 -0500 Subject: [PATCH] utils.set_attr can now be used to set any attribute. The old set_attr has been renamed to set_attr_param. --- comfy/controlnet.py | 4 ++-- comfy/model_patcher.py | 7 +++---- comfy/utils.py | 7 +++++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 41619758..f859a50d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -287,13 +287,13 @@ class ControlLora(ControlNet): for k in sd: weight = sd[k] try: - comfy.utils.set_attr(self.control_model, k, weight) + comfy.utils.set_attr_param(self.control_model, k, weight) except: pass for k in self.control_weights: if k not in {"lora_controlnet"}: - comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f29781f3..604e3477 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -176,10 +176,9 @@ class ModelPatcher: def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: - old = comfy.utils.get_attr(self.model, k) + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old - comfy.utils.set_attr(self.model, k, self.object_patches[k]) if patch_weights: model_sd = self.model_state_dict() @@ -203,7 +202,7 @@ class ModelPatcher: if inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: - comfy.utils.set_attr(self.model, key, out_weight) + comfy.utils.set_attr_param(self.model, key, out_weight) del temp_weight if device_to is not None: @@ -342,7 +341,7 @@ class ModelPatcher: comfy.utils.copy_to_param(self.model, k, self.backup[k]) else: for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) + comfy.utils.set_attr_param(self.model, k, self.backup[k]) self.backup = {} diff --git a/comfy/utils.py b/comfy/utils.py index 41f730c8..5deb14cd 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -294,8 +294,11 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) - del prev + setattr(obj, attrs[-1], value) + return prev + +def set_attr_param(obj, attr, value): + return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it