From 111fd0cadfe83cdda7a1a775f89e0dd675a58d66 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 02:04:07 -0600 Subject: [PATCH] Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type --- comfy/hooks.py | 78 ++++++++++++++++++------------------- comfy/model_patcher.py | 10 ++--- comfy/sampler_helpers.py | 20 +++++----- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 53 insertions(+), 57 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7ca3a8a1..9ccfaa6d 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -41,7 +41,6 @@ class EnumHookType(enum.Enum): Hook types, each of which has different expected behavior. ''' Weight = "weight" - Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" TransformerOptions = "transformer_options" @@ -194,19 +193,6 @@ class WeightHook(Hook): c._strength_clip = self._strength_clip return c -class PatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.Patch) - self.patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: PatchHook = super().clone(subtype) - c.patches = self.patches - return c - # TODO: add functionality - class ObjectPatchHook(Hook): def __init__(self): super().__init__(hook_type=EnumHookType.ObjectPatch) @@ -244,6 +230,7 @@ class AddModelsHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False + return True class TransformerOptionsHook(Hook): ''' @@ -298,12 +285,28 @@ class SetInjectionsHook(Hook): pass class HookGroup: + ''' + Stores groups of hooks, and allows them to be queried by type. + + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; + always use the provided functions on HookGroup. + ''' def __init__(self): self.hooks: list[Hook] = [] + self._hook_dict: dict[EnumHookType, list[Hook]] = {} def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) + self._hook_dict.setdefault(hook.hook_type, []).append(hook) + + def remove(self, hook: Hook): + if hook in self.hooks: + self.hooks.remove(hook) + self._hook_dict[hook.hook_type].remove(hook) + + def get_type(self, hook_type: EnumHookType): + return self._hook_dict.get(hook_type, []) def contains(self, hook: Hook): return hook in self.hooks @@ -329,36 +332,29 @@ class HookGroup: for hook in self.hooks: hook.hook_keyframe = hook_kf - def get_dict_repr(self): - d: dict[EnumHookType, dict[Hook, None]] = {} - for hook in self.hooks: - with_type = d.setdefault(hook.hook_type, {}) - with_type[hook] = None - return d - def get_hooks_for_clip_schedule(self): scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} - for hook in self.hooks: - # only care about WeightHooks, for now - if hook.hook_type == EnumHookType.Weight: - hook_schedule = [] - # if no hook keyframes, assign default value - if len(hook.hook_keyframe.keyframes) == 0: - hook_schedule.append(((0.0, 1.0), None)) - scheduled_hooks[hook] = hook_schedule - continue - # find ranges of values - prev_keyframe = hook.hook_keyframe.keyframes[0] - for keyframe in hook.hook_keyframe.keyframes: - if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): - hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) - prev_keyframe = keyframe - elif keyframe.start_percent == prev_keyframe.start_percent: - prev_keyframe = keyframe - # create final range, assuming last start_percent was not 1.0 - if not math.isclose(prev_keyframe.start_percent, 1.0): - hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + # only care about WeightHooks, for now + for hook in self.get_type(EnumHookType.Weight): + hook: WeightHook + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) scheduled_hooks[hook] = hook_schedule + continue + # find ranges of values + prev_keyframe = hook.hook_keyframe.keyframes[0] + for keyframe in hook.hook_keyframe.keyframes: + if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): + hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) + prev_keyframe = keyframe + elif keyframe.start_percent == prev_keyframe.start_percent: + prev_keyframe = keyframe + # create final range, assuming last start_percent was not 1.0 + if not math.isclose(prev_keyframe.start_percent, 1.0): + hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + scheduled_hooks[hook] = hook_schedule # hooks should not have their schedules in a list of tuples all_ranges: list[tuple[float, float]] = [] for range_kfs in scheduled_hooks.values(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2db21bdc..0430430e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,16 +940,16 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] - # handle WrapperHooks, if model_options provided + # handle TransformerOptionsHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) if len(weight_hooks_to_register) > 0: @@ -958,7 +958,7 @@ class ModelPatcher: for hook in weight_hooks_to_register: hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target_dict) + callback(self, hooks, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 6f21ca3c..abd44cf6 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type): models += [c[model_type]] return models -def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): +def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup): # get hooks from conds, and collect cnets so they can be checked for extra_hooks cnets: list[ControlBase] = [] for c in cond: if 'hooks' in c: for hook in c['hooks'].hooks: - hook: comfy.hooks.Hook - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) if 'control' in c: cnets.append(c['control']) @@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) if extra_hooks is not None: for hook in extra_hooks.hooks: - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) - return hooks_dict + return full_hooks def convert_cond(cond): out = [] @@ -73,7 +70,7 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} + hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") @@ -90,7 +87,10 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] + hook_models = [] + for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + x: comfy.hooks.AddModelsHook + hook_models.extend(x.models) models = control_models + gligen + add_models + hook_models return models, inference_memory @@ -124,7 +124,7 @@ def cleanup_models(conds, models): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # check for hooks in conds - if not registered, see if can be applied - hooks = {} + hooks = comfy.hooks.HookGroup() for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 49b90b9d..64223834 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) + clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: