mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options)
This commit is contained in:
parent
5a2ad032cb
commit
776aa734e1
@ -44,7 +44,7 @@ class EnumHookType(enum.Enum):
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Wrappers = "wrappers"
|
||||
TransformerOptions = "transformer_options"
|
||||
Injections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
@ -245,29 +245,39 @@ class AddModelsHook(Hook):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
|
||||
class WrapperHook(Hook):
|
||||
class TransformerOptionsHook(Hook):
|
||||
'''
|
||||
Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options.
|
||||
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||
'''
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||
self.transformers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
c.transformers_dict = self.transformers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
add_model_options = {"transformer_options": self.transformers_dict}
|
||||
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
|
||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
return True
|
||||
|
||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||
|
||||
class WrapperHook(TransformerOptionsHook):
|
||||
'''
|
||||
For backwards compatibility, this hook is identical to TransformerOptionsHook.
|
||||
'''
|
||||
pass
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
|
||||
|
@ -945,7 +945,7 @@ class ModelPatcher:
|
||||
registered_hooks: list[comfy.hooks.Hook] = []
|
||||
# handle WrapperHooks, if model_options provided
|
||||
if model_options is not None:
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
|
||||
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
||||
# handle WeightHooks
|
||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||
|
Loading…
Reference in New Issue
Block a user