Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable

This commit is contained in:
Jedrzej Kosinski 2025-01-07 00:59:18 -06:00
parent 58bf8815c8
commit 216fea15ee

View File

@ -66,7 +66,7 @@ class _HookRef:
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how should_register function should look like.''' '''Example for how custom_should_register function can look like.'''
return True return True
@ -83,12 +83,17 @@ class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type self.hook_type = hook_type
'''Enum identifying the general class of this hook.'''
self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_ref = hook_ref if hook_ref else _HookRef()
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
self.hook_id = hook_id self.hook_id = hook_id
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
self.hook_scope = hook_scope self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property @property
def strength(self): def strength(self):
@ -109,8 +114,6 @@ class Hook:
c.hook_keyframe = self.hook_keyframe c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c return c
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@ -236,28 +239,34 @@ class TransformerOptionsHook(Hook):
super().__init__(hook_type=EnumHookType.TransformerOptions) super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = transformers_dict self.transformers_dict = transformers_dict
self.hook_scope = hook_scope self.hook_scope = hook_scope
self._skip_adding = False
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
def clone(self): def clone(self):
c: TransformerOptionsHook = super().clone() c: TransformerOptionsHook = super().clone()
c.transformers_dict = self.transformers_dict c.transformers_dict = self.transformers_dict
c._skip_adding = self._skip_adding
return c return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered): if not self.should_register(model, model_options, target_dict, registered):
return False return False
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
self._skip_adding = False
if self.hook_scope == EnumHookScope.AllConditioning: if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict, add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict} "to_load_options": self.transformers_dict}
# skip_adding if included in AllConditioning to avoid double loading
self._skip_adding = True
else: else:
add_model_options = {"to_load_options": self.transformers_dict} add_model_options = {"to_load_options": self.transformers_dict}
# only register if will not be included in AllConditioning to avoid double loading registered.add(self)
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
return True return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) if not self._skip_adding:
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
WrapperHook = TransformerOptionsHook WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''