mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options)
This commit is contained in:
parent
5a2ad032cb
commit
776aa734e1
@ -44,7 +44,7 @@ class EnumHookType(enum.Enum):
|
|||||||
Patch = "patch"
|
Patch = "patch"
|
||||||
ObjectPatch = "object_patch"
|
ObjectPatch = "object_patch"
|
||||||
AddModels = "add_models"
|
AddModels = "add_models"
|
||||||
Wrappers = "wrappers"
|
TransformerOptions = "transformer_options"
|
||||||
Injections = "add_injections"
|
Injections = "add_injections"
|
||||||
|
|
||||||
class EnumWeightTarget(enum.Enum):
|
class EnumWeightTarget(enum.Enum):
|
||||||
@ -245,30 +245,40 @@ class AddModelsHook(Hook):
|
|||||||
if not self.should_register(model, model_options, target_dict, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
class WrapperHook(Hook):
|
class TransformerOptionsHook(Hook):
|
||||||
'''
|
'''
|
||||||
Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options.
|
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||||
'''
|
'''
|
||||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||||
self.wrappers_dict = wrappers_dict
|
self.transformers_dict = wrappers_dict
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self, subtype: Callable=None):
|
||||||
if subtype is None:
|
if subtype is None:
|
||||||
subtype = type(self)
|
subtype = type(self)
|
||||||
c: WrapperHook = super().clone(subtype)
|
c: WrapperHook = super().clone(subtype)
|
||||||
c.wrappers_dict = self.wrappers_dict
|
c.transformers_dict = self.transformers_dict
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||||
if not self.should_register(model, model_options, target_dict, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
add_model_options = {"transformer_options": self.transformers_dict}
|
||||||
|
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
|
||||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||||
registered.append(self)
|
registered.append(self)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||||
|
|
||||||
|
class WrapperHook(TransformerOptionsHook):
|
||||||
|
'''
|
||||||
|
For backwards compatibility, this hook is identical to TransformerOptionsHook.
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
class SetInjectionsHook(Hook):
|
class SetInjectionsHook(Hook):
|
||||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
|
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
|
||||||
super().__init__(hook_type=EnumHookType.Injections)
|
super().__init__(hook_type=EnumHookType.Injections)
|
||||||
|
@ -945,7 +945,7 @@ class ModelPatcher:
|
|||||||
registered_hooks: list[comfy.hooks.Hook] = []
|
registered_hooks: list[comfy.hooks.Hook] = []
|
||||||
# handle WrapperHooks, if model_options provided
|
# handle WrapperHooks, if model_options provided
|
||||||
if model_options is not None:
|
if model_options is not None:
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
|
||||||
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
||||||
# handle WeightHooks
|
# handle WeightHooks
|
||||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||||
|
Loading…
Reference in New Issue
Block a user