Document get_attr and get_model_object (#6357)

* Document get_attr and get_model_object

* Update model_patcher.py

* Update model_patcher.py

* Update model_patcher.py
This commit is contained in:
Chenlei Hu 2025-01-06 20:12:22 -05:00 committed by GitHub
parent eeab420c70
commit d055325783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 2 deletions

View File

@ -402,7 +402,20 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
Args:
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
patcher = ModelPatcher()
weight = patcher.get_model_object("layer1.conv.weight")
"""
if name in self.object_patches:
return self.object_patches[name]
else:

View File

@ -693,7 +693,25 @@ def copy_to_param(obj, attr, value):
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
def get_attr(obj, attr: str):
"""Retrieves a nested attribute from an object using dot notation.
Args:
obj: The object to get the attribute from
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
model = MyModel()
weight = get_attr(model, "layer1.conv.weight")
# Equivalent to: model.layer1.conv.weight
Important:
Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
accessing nested model objects under `ModelPatcher.model`.
"""
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)