Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options)

This commit is contained in:
Jedrzej Kosinski 2025-01-04 01:02:21 -06:00
parent 5a2ad032cb
commit 776aa734e1
2 changed files with 18 additions and 8 deletions

View File

@ -44,7 +44,7 @@ class EnumHookType(enum.Enum):
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Wrappers = "wrappers"
TransformerOptions = "transformer_options"
Injections = "add_injections"
class EnumWeightTarget(enum.Enum):
@ -245,30 +245,40 @@ class AddModelsHook(Hook):
if not self.should_register(model, model_options, target_dict, registered):
return False
class WrapperHook(Hook):
class TransformerOptionsHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options.
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
c.transformers_dict = self.transformers_dict
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
add_model_options = {"transformer_options": self.transformers_dict}
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
if self.hook_scope == EnumHookScope.AllConditioning:
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
class WrapperHook(TransformerOptionsHook):
'''
For backwards compatibility, this hook is identical to TransformerOptionsHook.
'''
pass
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
super().__init__(hook_type=EnumHookType.Injections)

View File

@ -945,7 +945,7 @@ class ModelPatcher:
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []