mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable
This commit is contained in:
parent
58bf8815c8
commit
216fea15ee
@ -66,7 +66,7 @@ class _HookRef:
|
|||||||
|
|
||||||
|
|
||||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
'''Example for how should_register function should look like.'''
|
'''Example for how custom_should_register function can look like.'''
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -83,12 +83,17 @@ class Hook:
|
|||||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||||
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
||||||
self.hook_type = hook_type
|
self.hook_type = hook_type
|
||||||
|
'''Enum identifying the general class of this hook.'''
|
||||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||||
|
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
||||||
self.hook_id = hook_id
|
self.hook_id = hook_id
|
||||||
|
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
||||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||||
|
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
||||||
self.hook_scope = hook_scope
|
self.hook_scope = hook_scope
|
||||||
|
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||||
self.custom_should_register = default_should_register
|
self.custom_should_register = default_should_register
|
||||||
self.auto_apply_to_nonpositive = False
|
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength(self):
|
def strength(self):
|
||||||
@ -109,8 +114,6 @@ class Hook:
|
|||||||
c.hook_keyframe = self.hook_keyframe
|
c.hook_keyframe = self.hook_keyframe
|
||||||
c.hook_scope = self.hook_scope
|
c.hook_scope = self.hook_scope
|
||||||
c.custom_should_register = self.custom_should_register
|
c.custom_should_register = self.custom_should_register
|
||||||
# TODO: make this do something
|
|
||||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
@ -236,28 +239,34 @@ class TransformerOptionsHook(Hook):
|
|||||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||||
self.transformers_dict = transformers_dict
|
self.transformers_dict = transformers_dict
|
||||||
self.hook_scope = hook_scope
|
self.hook_scope = hook_scope
|
||||||
|
self._skip_adding = False
|
||||||
|
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c: TransformerOptionsHook = super().clone()
|
c: TransformerOptionsHook = super().clone()
|
||||||
c.transformers_dict = self.transformers_dict
|
c.transformers_dict = self.transformers_dict
|
||||||
|
c._skip_adding = self._skip_adding
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
if not self.should_register(model, model_options, target_dict, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||||
|
self._skip_adding = False
|
||||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||||
add_model_options = {"transformer_options": self.transformers_dict,
|
add_model_options = {"transformer_options": self.transformers_dict,
|
||||||
"to_load_options": self.transformers_dict}
|
"to_load_options": self.transformers_dict}
|
||||||
|
# skip_adding if included in AllConditioning to avoid double loading
|
||||||
|
self._skip_adding = True
|
||||||
else:
|
else:
|
||||||
add_model_options = {"to_load_options": self.transformers_dict}
|
add_model_options = {"to_load_options": self.transformers_dict}
|
||||||
# only register if will not be included in AllConditioning to avoid double loading
|
registered.add(self)
|
||||||
registered.add(self)
|
|
||||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||||
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
if not self._skip_adding:
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||||
|
|
||||||
WrapperHook = TransformerOptionsHook
|
WrapperHook = TransformerOptionsHook
|
||||||
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||||
|
Loading…
Reference in New Issue
Block a user