Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options)

This commit is contained in:
Jedrzej Kosinski 2025-01-04 01:02:21 -06:00
parent 5a2ad032cb
commit 776aa734e1
2 changed files with 18 additions and 8 deletions

View File

@ -44,7 +44,7 @@ class EnumHookType(enum.Enum):
Patch = "patch" Patch = "patch"
ObjectPatch = "object_patch" ObjectPatch = "object_patch"
AddModels = "add_models" AddModels = "add_models"
Wrappers = "wrappers" TransformerOptions = "transformer_options"
Injections = "add_injections" Injections = "add_injections"
class EnumWeightTarget(enum.Enum): class EnumWeightTarget(enum.Enum):
@ -245,29 +245,39 @@ class AddModelsHook(Hook):
if not self.should_register(model, model_options, target_dict, registered): if not self.should_register(model, model_options, target_dict, registered):
return False return False
class WrapperHook(Hook): class TransformerOptionsHook(Hook):
''' '''
Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
''' '''
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers) super().__init__(hook_type=EnumHookType.TransformerOptions)
self.wrappers_dict = wrappers_dict self.transformers_dict = wrappers_dict
def clone(self, subtype: Callable=None): def clone(self, subtype: Callable=None):
if subtype is None: if subtype is None:
subtype = type(self) subtype = type(self)
c: WrapperHook = super().clone(subtype) c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict c.transformers_dict = self.transformers_dict
return c return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
if not self.should_register(model, model_options, target_dict, registered): if not self.should_register(model, model_options, target_dict, registered):
return False return False
add_model_options = {"transformer_options": self.wrappers_dict} add_model_options = {"transformer_options": self.transformers_dict}
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
if self.hook_scope == EnumHookScope.AllConditioning: if self.hook_scope == EnumHookScope.AllConditioning:
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self) registered.append(self)
return True return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
class WrapperHook(TransformerOptionsHook):
'''
For backwards compatibility, this hook is identical to TransformerOptionsHook.
'''
pass
class SetInjectionsHook(Hook): class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None): def __init__(self, key: str=None, injections: list[PatcherInjection]=None):

View File

@ -945,7 +945,7 @@ class ModelPatcher:
registered_hooks: list[comfy.hooks.Hook] = [] registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided # handle WrapperHooks, if model_options provided
if model_options is not None: if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
hook.add_hook_patches(self, model_options, target_dict, registered_hooks) hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
# handle WeightHooks # handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] weight_hooks_to_register: list[comfy.hooks.WeightHook] = []