From 776aa734e1ac0a46fefef6abcc5ad29763003a7e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 01:02:21 -0600 Subject: [PATCH] Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) --- comfy/hooks.py | 24 +++++++++++++++++------- comfy/model_patcher.py | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 181c4996..7ca3a8a1 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -44,7 +44,7 @@ class EnumHookType(enum.Enum): Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Wrappers = "wrappers" + TransformerOptions = "transformer_options" Injections = "add_injections" class EnumWeightTarget(enum.Enum): @@ -245,29 +245,39 @@ class AddModelsHook(Hook): if not self.should_register(model, model_options, target_dict, registered): return False -class WrapperHook(Hook): +class TransformerOptionsHook(Hook): ''' - Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict + super().__init__(hook_type=EnumHookType.TransformerOptions) + self.transformers_dict = wrappers_dict def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict + c.transformers_dict = self.transformers_dict return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.wrappers_dict} + add_model_options = {"transformer_options": self.transformers_dict} + # TODO: call .to on patches/anything else in transformer_options that is expected to do something if self.hook_scope == EnumHookScope.AllConditioning: comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True + + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + +class WrapperHook(TransformerOptionsHook): + ''' + For backwards compatibility, this hook is identical to TransformerOptionsHook. + ''' + pass class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 07153552..2db21bdc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -945,7 +945,7 @@ class ModelPatcher: registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = []