diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090b..181c4996 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -16,46 +16,86 @@ import comfy.model_management import comfy.patcher_extension from node_helpers import conditioning_set_values +# ####################################################################################################### +# Hooks explanation +# ------------------- +# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to +# make explicit special cases like it does for ControlNet and GLIGEN. +# +# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those +# that should run special code when a 'marked' cond is used in sampling. +# ####################################################################################################### + class EnumHookMode(enum.Enum): + ''' + Priority of hook memory optimization vs. speed, mostly related to WeightHooks. + + MinVram: No caching will occur for any operations related to hooks. + MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups. + ''' MinVram = "minvram" MaxSpeed = "maxspeed" class EnumHookType(enum.Enum): + ''' + Hook types, each of which has different expected behavior. + ''' Weight = "weight" Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Callbacks = "callbacks" Wrappers = "wrappers" - SetInjections = "add_injections" + Injections = "add_injections" class EnumWeightTarget(enum.Enum): Model = "model" Clip = "clip" +class EnumHookScope(enum.Enum): + ''' + Determines if hook should be limited in its influence over sampling. + + AllConditioning: hook will affect all conds used in sampling. + HookedOnly: hook will only affect the conds it was attached to. + ''' + AllConditioning = "all_conditioning" + HookedOnly = "hooked_only" + + class _HookRef: pass -# NOTE: this is an example of how the should_register function should look -def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + '''Example for how should_register function should look like.''' return True +def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]: + '''Creates base dictionary for use with Hooks' target param.''' + d = {} + if target is not None: + d['target'] = target + d.update(kwargs) + return d + + class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, - hook_keyframe: 'HookKeyframeGroup'=None): + hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False + self.hook_scope = hook_scope @property def strength(self): return self.hook_keyframe.strength - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): self.reset() self.hook_keyframe.initialize_timesteps(model) @@ -75,27 +115,32 @@ class Hook: c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return self.custom_should_register(self, model, model_options, target, registered) + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def __eq__(self, other: 'Hook'): + def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref def __hash__(self): return hash(self.hook_ref) class WeightHook(Hook): + ''' + Hook responsible for tracking weights to be applied to some model/clip. + + Note, value of hook_scope is ignored and is treated as HookedOnly. + ''' def __init__(self, strength_model=1.0, strength_clip=1.0): - super().__init__(hook_type=EnumHookType.Weight) + super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly) self.weights: dict = None self.weights_clip: dict = None self.need_weight_init = True @@ -110,27 +155,29 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False weights = None - if target == EnumWeightTarget.Model: - strength = self._strength_model - else: + + target = target_dict.get('target', None) + if target == EnumWeightTarget.Clip: strength = self._strength_clip + else: + strength = self._strength_model if self.need_weight_init: key_map = {} - if target == EnumWeightTarget.Model: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - else: + if target == EnumWeightTarget.Clip: key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: - if target == EnumWeightTarget.Model: - weights = self.weights - else: + if target == EnumWeightTarget.Clip: weights = self.weights_clip + else: + weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) registered.append(self) return True @@ -174,7 +221,12 @@ class ObjectPatchHook(Hook): # TODO: add functionality class AddModelsHook(Hook): - def __init__(self, key: str=None, models: list['ModelPatcher']=None): + ''' + Hook responsible for telling model management any additional models that should be loaded. + + Note, value of hook_scope is ignored and is treated as AllConditioning. + ''' + def __init__(self, key: str=None, models: list[ModelPatcher]=None): super().__init__(hook_type=EnumHookType.AddModels) self.key = key self.models = models @@ -188,24 +240,15 @@ class AddModelsHook(Hook): c.models = self.models.copy() if self.models else self.models c.append_when_same = self.append_when_same return c - # TODO: add functionality -class CallbackHook(Hook): - def __init__(self, key: str=None, callback: Callable=None): - super().__init__(hook_type=EnumHookType.Callbacks) - self.key = key - self.callback = callback - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: CallbackHook = super().clone(subtype) - c.key = self.key - c.callback = self.callback - return c - # TODO: add functionality + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): + return False class WrapperHook(Hook): + ''' + Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): super().__init__(hook_type=EnumHookType.Wrappers) self.wrappers_dict = wrappers_dict @@ -217,17 +260,18 @@ class WrapperHook(Hook): c.wrappers_dict = self.wrappers_dict return c - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False add_model_options = {"transformer_options": self.wrappers_dict} - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + if self.hook_scope == EnumHookScope.AllConditioning: + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list['PatcherInjection']=None): - super().__init__(hook_type=EnumHookType.SetInjections) + def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections @@ -239,7 +283,7 @@ class SetInjectionsHook(Hook): c.injections = self.injections.copy() if self.injections else self.injections return c - def add_hook_injections(self, model: 'ModelPatcher'): + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass @@ -260,14 +304,14 @@ class HookGroup: c.add(hook.clone()) return c - def clone_and_combine(self, other: 'HookGroup'): + def clone_and_combine(self, other: HookGroup): c = self.clone() if other is not None: for hook in other.hooks: c.add(hook.clone()) return c - def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup): if hook_kf is None: hook_kf = HookKeyframeGroup() else: @@ -336,7 +380,7 @@ class HookGroup: hook.reset() @staticmethod - def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': + def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup: actual: list[HookGroup] = [] for group in hooks_list: if group is not None: @@ -433,7 +477,7 @@ class HookKeyframeGroup: c._set_first_as_current() return c - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) @@ -548,7 +592,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float hook.need_weight_init = False return hook_group -def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): +def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True): if model is None: return None patches_model: dict[str, torch.Tensor] = model.model.state_dict() @@ -560,7 +604,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T return patches_model # NOTE: this function shows how to register weight hooks directly on the ModelPatchers -def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], +def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): key_map = {} if model is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11..07153552 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,13 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): @@ -956,9 +956,9 @@ class ModelPatcher: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target) + callback(self, hooks_dict, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index ac973536..6f21ca3c 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -131,4 +131,4 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) + model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 9d9d4837..49b90b9d 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) + clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: