mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Document get_attr and get_model_object (#6357)
* Document get_attr and get_model_object * Update model_patcher.py * Update model_patcher.py * Update model_patcher.py
This commit is contained in:
parent
eeab420c70
commit
d055325783
@ -402,7 +402,20 @@ class ModelPatcher:
|
|||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
def get_model_object(self, name):
|
def get_model_object(self, name: str) -> torch.nn.Module:
|
||||||
|
"""Retrieves a nested attribute from an object using dot notation considering
|
||||||
|
object patches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the requested attribute
|
||||||
|
|
||||||
|
Example:
|
||||||
|
patcher = ModelPatcher()
|
||||||
|
weight = patcher.get_model_object("layer1.conv.weight")
|
||||||
|
"""
|
||||||
if name in self.object_patches:
|
if name in self.object_patches:
|
||||||
return self.object_patches[name]
|
return self.object_patches[name]
|
||||||
else:
|
else:
|
||||||
|
@ -693,7 +693,25 @@ def copy_to_param(obj, attr, value):
|
|||||||
prev = getattr(obj, attrs[-1])
|
prev = getattr(obj, attrs[-1])
|
||||||
prev.data.copy_(value)
|
prev.data.copy_(value)
|
||||||
|
|
||||||
def get_attr(obj, attr):
|
def get_attr(obj, attr: str):
|
||||||
|
"""Retrieves a nested attribute from an object using dot notation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to get the attribute from
|
||||||
|
attr (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the requested attribute
|
||||||
|
|
||||||
|
Example:
|
||||||
|
model = MyModel()
|
||||||
|
weight = get_attr(model, "layer1.conv.weight")
|
||||||
|
# Equivalent to: model.layer1.conv.weight
|
||||||
|
|
||||||
|
Important:
|
||||||
|
Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
|
||||||
|
accessing nested model objects under `ModelPatcher.model`.
|
||||||
|
"""
|
||||||
attrs = attr.split(".")
|
attrs = attr.split(".")
|
||||||
for name in attrs:
|
for name in attrs:
|
||||||
obj = getattr(obj, name)
|
obj = getattr(obj, name)
|
||||||
|
Loading…
Reference in New Issue
Block a user