diff --git a/comfy/hooks.py b/comfy/hooks.py index 25d67b86..b62092cc 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -101,10 +101,8 @@ class Hook: def reset(self): self.hook_keyframe.reset() - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: Hook = subtype() + def clone(self): + c: Hook = self.__class__() c.hook_type = self.hook_type c.hook_ref = self.hook_ref c.hook_id = self.hook_id @@ -182,10 +180,8 @@ class WeightHook(Hook): return True # TODO: add logs about any keys that were not applied - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WeightHook = super().clone(subtype) + def clone(self): + c: WeightHook = super().clone() c.weights = self.weights c.weights_clip = self.weights_clip c.need_weight_init = self.need_weight_init @@ -194,17 +190,21 @@ class WeightHook(Hook): return c class ObjectPatchHook(Hook): - def __init__(self): + def __init__(self, object_patches: dict[str]=None): super().__init__(hook_type=EnumHookType.ObjectPatch) - self.object_patches: dict = None + self.object_patches = object_patches - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: ObjectPatchHook = super().clone(subtype) + def clone(self): + c: ObjectPatchHook = super().clone() c.object_patches = self.object_patches return c - # TODO: add functionality + + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True class AddModelsHook(Hook): ''' @@ -219,12 +219,10 @@ class AddModelsHook(Hook): self.append_when_same = True '''Curently does nothing.''' - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: AddModelsHook = super().clone(subtype) - c.key = self.key + def clone(self): + c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models + c.key = self.key c.append_when_same = self.append_when_same return c @@ -242,10 +240,8 @@ class TransformerOptionsHook(Hook): super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = wrappers_dict - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WrapperHook = super().clone(subtype) + def clone(self): + c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict return c @@ -265,11 +261,8 @@ class TransformerOptionsHook(Hook): def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) -class WrapperHook(TransformerOptionsHook): - ''' - For backwards compatibility, this hook is identical to TransformerOptionsHook. - ''' - pass +WrapperHook = TransformerOptionsHook +'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): @@ -277,14 +270,19 @@ class SetInjectionsHook(Hook): self.key = key self.injections = injections - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: SetInjectionsHook = super().clone(subtype) + def clone(self): + c: SetInjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index cb938851..d43280fe 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -157,4 +157,3 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options -