diff --git a/comfy/hooks.py b/comfy/hooks.py index 46fc06bd..7c2f6689 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -66,7 +66,7 @@ class _HookRef: def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - '''Example for how should_register function should look like.''' + '''Example for how custom_should_register function can look like.''' return True @@ -83,12 +83,17 @@ class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type + '''Enum identifying the general class of this hook.''' self.hook_ref = hook_ref if hook_ref else _HookRef() + '''Reference shared between hook clones that have the same value. Should NOT be modified.''' self.hook_id = hook_id + '''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.''' self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + '''Keyframe storage that can be referenced to get strength for current sampling step.''' self.hook_scope = hook_scope + '''Scope of where this hook should apply in terms of the conds used in sampling run.''' self.custom_should_register = default_should_register - self.auto_apply_to_nonpositive = False + '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' @property def strength(self): @@ -109,8 +114,6 @@ class Hook: c.hook_keyframe = self.hook_keyframe c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register - # TODO: make this do something - c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,28 +239,34 @@ class TransformerOptionsHook(Hook): super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = transformers_dict self.hook_scope = hook_scope + self._skip_adding = False + '''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.''' def clone(self): c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict + c._skip_adding = self._skip_adding return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks + self._skip_adding = False if self.hook_scope == EnumHookScope.AllConditioning: add_model_options = {"transformer_options": self.transformers_dict, "to_load_options": self.transformers_dict} + # skip_adding if included in AllConditioning to avoid double loading + self._skip_adding = True else: add_model_options = {"to_load_options": self.transformers_dict} - # only register if will not be included in AllConditioning to avoid double loading - registered.add(self) + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): - comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + if not self._skip_adding: + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''