From d055325783ef30956d5af8a2b0add775c37caa6d Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 6 Jan 2025 20:12:22 -0500 Subject: [PATCH] Document get_attr and get_model_object (#6357) * Document get_attr and get_model_object * Update model_patcher.py * Update model_patcher.py * Update model_patcher.py --- comfy/model_patcher.py | 15 ++++++++++++++- comfy/utils.py | 20 +++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11..e886bdbb 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -402,7 +402,20 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj - def get_model_object(self, name): + def get_model_object(self, name: str) -> torch.nn.Module: + """Retrieves a nested attribute from an object using dot notation considering + object patches. + + Args: + name (str): The attribute path using dot notation (e.g. "model.layer.weight") + + Returns: + The value of the requested attribute + + Example: + patcher = ModelPatcher() + weight = patcher.get_model_object("layer1.conv.weight") + """ if name in self.object_patches: return self.object_patches[name] else: diff --git a/comfy/utils.py b/comfy/utils.py index ea666ae5..b486b2de 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -693,7 +693,25 @@ def copy_to_param(obj, attr, value): prev = getattr(obj, attrs[-1]) prev.data.copy_(value) -def get_attr(obj, attr): +def get_attr(obj, attr: str): + """Retrieves a nested attribute from an object using dot notation. + + Args: + obj: The object to get the attribute from + attr (str): The attribute path using dot notation (e.g. "model.layer.weight") + + Returns: + The value of the requested attribute + + Example: + model = MyModel() + weight = get_attr(model, "layer1.conv.weight") + # Equivalent to: model.layer1.conv.weight + + Important: + Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when + accessing nested model objects under `ModelPatcher.model`. + """ attrs = attr.split(".") for name in attrs: obj = getattr(obj, name)