mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type
This commit is contained in:
parent
776aa734e1
commit
111fd0cadf
@ -41,7 +41,6 @@ class EnumHookType(enum.Enum):
|
|||||||
Hook types, each of which has different expected behavior.
|
Hook types, each of which has different expected behavior.
|
||||||
'''
|
'''
|
||||||
Weight = "weight"
|
Weight = "weight"
|
||||||
Patch = "patch"
|
|
||||||
ObjectPatch = "object_patch"
|
ObjectPatch = "object_patch"
|
||||||
AddModels = "add_models"
|
AddModels = "add_models"
|
||||||
TransformerOptions = "transformer_options"
|
TransformerOptions = "transformer_options"
|
||||||
@ -194,19 +193,6 @@ class WeightHook(Hook):
|
|||||||
c._strength_clip = self._strength_clip
|
c._strength_clip = self._strength_clip
|
||||||
return c
|
return c
|
||||||
|
|
||||||
class PatchHook(Hook):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(hook_type=EnumHookType.Patch)
|
|
||||||
self.patches: dict = None
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: PatchHook = super().clone(subtype)
|
|
||||||
c.patches = self.patches
|
|
||||||
return c
|
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class ObjectPatchHook(Hook):
|
class ObjectPatchHook(Hook):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||||
@ -244,6 +230,7 @@ class AddModelsHook(Hook):
|
|||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||||
if not self.should_register(model, model_options, target_dict, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
class TransformerOptionsHook(Hook):
|
class TransformerOptionsHook(Hook):
|
||||||
'''
|
'''
|
||||||
@ -298,12 +285,28 @@ class SetInjectionsHook(Hook):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class HookGroup:
|
class HookGroup:
|
||||||
|
'''
|
||||||
|
Stores groups of hooks, and allows them to be queried by type.
|
||||||
|
|
||||||
|
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
||||||
|
always use the provided functions on HookGroup.
|
||||||
|
'''
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.hooks: list[Hook] = []
|
self.hooks: list[Hook] = []
|
||||||
|
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||||
|
|
||||||
def add(self, hook: Hook):
|
def add(self, hook: Hook):
|
||||||
if hook not in self.hooks:
|
if hook not in self.hooks:
|
||||||
self.hooks.append(hook)
|
self.hooks.append(hook)
|
||||||
|
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
||||||
|
|
||||||
|
def remove(self, hook: Hook):
|
||||||
|
if hook in self.hooks:
|
||||||
|
self.hooks.remove(hook)
|
||||||
|
self._hook_dict[hook.hook_type].remove(hook)
|
||||||
|
|
||||||
|
def get_type(self, hook_type: EnumHookType):
|
||||||
|
return self._hook_dict.get(hook_type, [])
|
||||||
|
|
||||||
def contains(self, hook: Hook):
|
def contains(self, hook: Hook):
|
||||||
return hook in self.hooks
|
return hook in self.hooks
|
||||||
@ -329,36 +332,29 @@ class HookGroup:
|
|||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
hook.hook_keyframe = hook_kf
|
hook.hook_keyframe = hook_kf
|
||||||
|
|
||||||
def get_dict_repr(self):
|
|
||||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
|
||||||
for hook in self.hooks:
|
|
||||||
with_type = d.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_hooks_for_clip_schedule(self):
|
def get_hooks_for_clip_schedule(self):
|
||||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||||
for hook in self.hooks:
|
# only care about WeightHooks, for now
|
||||||
# only care about WeightHooks, for now
|
for hook in self.get_type(EnumHookType.Weight):
|
||||||
if hook.hook_type == EnumHookType.Weight:
|
hook: WeightHook
|
||||||
hook_schedule = []
|
hook_schedule = []
|
||||||
# if no hook keyframes, assign default value
|
# if no hook keyframes, assign default value
|
||||||
if len(hook.hook_keyframe.keyframes) == 0:
|
if len(hook.hook_keyframe.keyframes) == 0:
|
||||||
hook_schedule.append(((0.0, 1.0), None))
|
hook_schedule.append(((0.0, 1.0), None))
|
||||||
scheduled_hooks[hook] = hook_schedule
|
|
||||||
continue
|
|
||||||
# find ranges of values
|
|
||||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
|
||||||
for keyframe in hook.hook_keyframe.keyframes:
|
|
||||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
# create final range, assuming last start_percent was not 1.0
|
|
||||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
|
||||||
scheduled_hooks[hook] = hook_schedule
|
scheduled_hooks[hook] = hook_schedule
|
||||||
|
continue
|
||||||
|
# find ranges of values
|
||||||
|
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||||
|
for keyframe in hook.hook_keyframe.keyframes:
|
||||||
|
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
# create final range, assuming last start_percent was not 1.0
|
||||||
|
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||||
|
scheduled_hooks[hook] = hook_schedule
|
||||||
# hooks should not have their schedules in a list of tuples
|
# hooks should not have their schedules in a list of tuples
|
||||||
all_ranges: list[tuple[float, float]] = []
|
all_ranges: list[tuple[float, float]] = []
|
||||||
for range_kfs in scheduled_hooks.values():
|
for range_kfs in scheduled_hooks.values():
|
||||||
|
@ -940,16 +940,16 @@ class ModelPatcher:
|
|||||||
if reset_current_hooks:
|
if reset_current_hooks:
|
||||||
self.patch_hooks(None)
|
self.patch_hooks(None)
|
||||||
|
|
||||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None):
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None):
|
||||||
self.restore_hook_patches()
|
self.restore_hook_patches()
|
||||||
registered_hooks: list[comfy.hooks.Hook] = []
|
registered_hooks: list[comfy.hooks.Hook] = []
|
||||||
# handle WrapperHooks, if model_options provided
|
# handle TransformerOptionsHooks, if model_options provided
|
||||||
if model_options is not None:
|
if model_options is not None:
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
|
||||||
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
||||||
# handle WeightHooks
|
# handle WeightHooks
|
||||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
||||||
if hook.hook_ref not in self.hook_patches:
|
if hook.hook_ref not in self.hook_patches:
|
||||||
weight_hooks_to_register.append(hook)
|
weight_hooks_to_register.append(hook)
|
||||||
if len(weight_hooks_to_register) > 0:
|
if len(weight_hooks_to_register) > 0:
|
||||||
@ -958,7 +958,7 @@ class ModelPatcher:
|
|||||||
for hook in weight_hooks_to_register:
|
for hook in weight_hooks_to_register:
|
||||||
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||||
callback(self, hooks_dict, target_dict)
|
callback(self, hooks, target_dict)
|
||||||
|
|
||||||
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
|
@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
|
|||||||
models += [c[model_type]]
|
models += [c[model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
|
||||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
if 'hooks' in c:
|
if 'hooks' in c:
|
||||||
for hook in c['hooks'].hooks:
|
for hook in c['hooks'].hooks:
|
||||||
hook: comfy.hooks.Hook
|
full_hooks.add(hook)
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
if 'control' in c:
|
if 'control' in c:
|
||||||
cnets.append(c['control'])
|
cnets.append(c['control'])
|
||||||
|
|
||||||
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
|||||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||||
if extra_hooks is not None:
|
if extra_hooks is not None:
|
||||||
for hook in extra_hooks.hooks:
|
for hook in extra_hooks.hooks:
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
full_hooks.add(hook)
|
||||||
with_type[hook] = None
|
|
||||||
|
|
||||||
return hooks_dict
|
return full_hooks
|
||||||
|
|
||||||
def convert_cond(cond):
|
def convert_cond(cond):
|
||||||
out = []
|
out = []
|
||||||
@ -73,7 +70,7 @@ def get_additional_models(conds, dtype):
|
|||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
gligen = []
|
gligen = []
|
||||||
add_models = []
|
add_models = []
|
||||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
hooks = comfy.hooks.HookGroup()
|
||||||
|
|
||||||
for k in conds:
|
for k in conds:
|
||||||
cnets += get_models_from_cond(conds[k], "control")
|
cnets += get_models_from_cond(conds[k], "control")
|
||||||
@ -90,7 +87,10 @@ def get_additional_models(conds, dtype):
|
|||||||
inference_memory += m.inference_memory_requirements(dtype)
|
inference_memory += m.inference_memory_requirements(dtype)
|
||||||
|
|
||||||
gligen = [x[1] for x in gligen]
|
gligen = [x[1] for x in gligen]
|
||||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
hook_models = []
|
||||||
|
for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
|
||||||
|
x: comfy.hooks.AddModelsHook
|
||||||
|
hook_models.extend(x.models)
|
||||||
models = control_models + gligen + add_models + hook_models
|
models = control_models + gligen + add_models + hook_models
|
||||||
|
|
||||||
return models, inference_memory
|
return models, inference_memory
|
||||||
@ -124,7 +124,7 @@ def cleanup_models(conds, models):
|
|||||||
|
|
||||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||||
# check for hooks in conds - if not registered, see if can be applied
|
# check for hooks in conds - if not registered, see if can be applied
|
||||||
hooks = {}
|
hooks = comfy.hooks.HookGroup()
|
||||||
for k in conds:
|
for k in conds:
|
||||||
get_hooks_from_cond(conds[k], hooks)
|
get_hooks_from_cond(conds[k], hooks)
|
||||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||||
|
@ -255,7 +255,7 @@ class SetClipHooks:
|
|||||||
clip.use_clip_schedule = schedule_clip
|
clip.use_clip_schedule = schedule_clip
|
||||||
if not clip.use_clip_schedule:
|
if not clip.use_clip_schedule:
|
||||||
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
||||||
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
|
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
class ConditioningTimestepsRange:
|
class ConditioningTimestepsRange:
|
||||||
|
Loading…
Reference in New Issue
Block a user