Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type

This commit is contained in:
Jedrzej Kosinski 2025-01-04 02:04:07 -06:00
parent 776aa734e1
commit 111fd0cadf
4 changed files with 53 additions and 57 deletions

View File

@ -41,7 +41,6 @@ class EnumHookType(enum.Enum):
Hook types, each of which has different expected behavior. Hook types, each of which has different expected behavior.
''' '''
Weight = "weight" Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch" ObjectPatch = "object_patch"
AddModels = "add_models" AddModels = "add_models"
TransformerOptions = "transformer_options" TransformerOptions = "transformer_options"
@ -194,19 +193,6 @@ class WeightHook(Hook):
c._strength_clip = self._strength_clip c._strength_clip = self._strength_clip
return c return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook): class ObjectPatchHook(Hook):
def __init__(self): def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch) super().__init__(hook_type=EnumHookType.ObjectPatch)
@ -244,6 +230,7 @@ class AddModelsHook(Hook):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
if not self.should_register(model, model_options, target_dict, registered): if not self.should_register(model, model_options, target_dict, registered):
return False return False
return True
class TransformerOptionsHook(Hook): class TransformerOptionsHook(Hook):
''' '''
@ -298,12 +285,28 @@ class SetInjectionsHook(Hook):
pass pass
class HookGroup: class HookGroup:
'''
Stores groups of hooks, and allows them to be queried by type.
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
always use the provided functions on HookGroup.
'''
def __init__(self): def __init__(self):
self.hooks: list[Hook] = [] self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def add(self, hook: Hook): def add(self, hook: Hook):
if hook not in self.hooks: if hook not in self.hooks:
self.hooks.append(hook) self.hooks.append(hook)
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
def remove(self, hook: Hook):
if hook in self.hooks:
self.hooks.remove(hook)
self._hook_dict[hook.hook_type].remove(hook)
def get_type(self, hook_type: EnumHookType):
return self._hook_dict.get(hook_type, [])
def contains(self, hook: Hook): def contains(self, hook: Hook):
return hook in self.hooks return hook in self.hooks
@ -329,36 +332,29 @@ class HookGroup:
for hook in self.hooks: for hook in self.hooks:
hook.hook_keyframe = hook_kf hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self): def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks: # only care about WeightHooks, for now
# only care about WeightHooks, for now for hook in self.get_type(EnumHookType.Weight):
if hook.hook_type == EnumHookType.Weight: hook: WeightHook
hook_schedule = [] hook_schedule = []
# if no hook keyframes, assign default value # if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0: if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None)) hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples # hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = [] all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values(): for range_kfs in scheduled_hooks.values():

View File

@ -940,16 +940,16 @@ class ModelPatcher:
if reset_current_hooks: if reset_current_hooks:
self.patch_hooks(None) self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None):
self.restore_hook_patches() self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = [] registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided # handle TransformerOptionsHooks, if model_options provided
if model_options is not None: if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook.add_hook_patches(self, model_options, target_dict, registered_hooks) hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
# handle WeightHooks # handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches: if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook) weight_hooks_to_register.append(hook)
if len(weight_hooks_to_register) > 0: if len(weight_hooks_to_register) > 0:
@ -958,7 +958,7 @@ class ModelPatcher:
for hook in weight_hooks_to_register: for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target_dict, registered_hooks) hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target_dict) callback(self, hooks, target_dict)
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected(): with self.use_ejected():

View File

@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]] models += [c[model_type]]
return models return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks # get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = [] cnets: list[ControlBase] = []
for c in cond: for c in cond:
if 'hooks' in c: if 'hooks' in c:
for hook in c['hooks'].hooks: for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook full_hooks.add(hook)
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
if 'control' in c: if 'control' in c:
cnets.append(c['control']) cnets.append(c['control'])
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None: if extra_hooks is not None:
for hook in extra_hooks.hooks: for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {}) full_hooks.add(hook)
with_type[hook] = None
return hooks_dict return full_hooks
def convert_cond(cond): def convert_cond(cond):
out = [] out = []
@ -73,7 +70,7 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = [] cnets: list[ControlBase] = []
gligen = [] gligen = []
add_models = [] add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} hooks = comfy.hooks.HookGroup()
for k in conds: for k in conds:
cnets += get_models_from_cond(conds[k], "control") cnets += get_models_from_cond(conds[k], "control")
@ -90,7 +87,10 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype) inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen] gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] hook_models = []
for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
x: comfy.hooks.AddModelsHook
hook_models.extend(x.models)
models = control_models + gligen + add_models + hook_models models = control_models + gligen + add_models + hook_models
return models, inference_memory return models, inference_memory
@ -124,7 +124,7 @@ def cleanup_models(conds, models):
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# check for hooks in conds - if not registered, see if can be applied # check for hooks in conds - if not registered, see if can be applied
hooks = {} hooks = comfy.hooks.HookGroup()
for k in conds: for k in conds:
get_hooks_from_cond(conds[k], hooks) get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options # add wrappers and callbacks from ModelPatcher to transformer_options

View File

@ -255,7 +255,7 @@ class SetClipHooks:
clip.use_clip_schedule = schedule_clip clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule: if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None) clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
return (clip,) return (clip,)
class ConditioningTimestepsRange: class ConditioningTimestepsRange: