From 72bbf493494109d3c177ad5de378c6c4bbae61d1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 29 Dec 2024 15:49:09 -0600 Subject: [PATCH 01/15] Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' --- comfy/hooks.py | 21 +++++++++++++++++---- comfy/model_patcher.py | 5 +++-- comfy/samplers.py | 10 ++++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cf33598a..79a7090b 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -366,9 +366,15 @@ class HookKeyframe: self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): c = HookKeyframe(strength=self.strength, - start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) c.start_t = self.start_t return c @@ -408,6 +414,12 @@ class HookKeyframeGroup: else: self._current_keyframe = None + def has_guarantee_steps(self): + for kf in self.keyframes: + if kf.guarantee_steps > 0: + return True + return False + def has_index(self, index: int): return index >= 0 and index < len(self.keyframes) @@ -425,15 +437,16 @@ class HookKeyframeGroup: for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, curr_t: float) -> bool: + def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool: if self.is_empty(): return False if curr_t == self._curr_t: return False + max_sigma = torch.max(transformer_options["sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: + if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need to switch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.keyframes)): @@ -446,7 +459,7 @@ class HookKeyframeGroup: self._current_keyframe = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: + if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: break diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d89d9a6a..4597ce11 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -919,11 +919,12 @@ class ModelPatcher: def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode - def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: - changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) + changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: diff --git a/comfy/samplers.py b/comfy/samplers.py index 27686722..6a386511 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -144,7 +144,7 @@ def cond_cat(c_list): return out -def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): +def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options): # need to figure out remaining unmasked area for conds default_mults = [] for _ in default_conds: @@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H # replace p's mult with calculated mult p = p._replace(mult=mult) if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] @@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te if p is None: continue if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] default_conds.append(default_c) @@ -840,7 +840,9 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} + extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) + extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( sampler.sample, From 5a2ad032cb09afcaf7fadf5cdfa20c2b0498aee5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 3 Jan 2025 20:02:27 -0600 Subject: [PATCH 02/15] Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed --- comfy/hooks.py | 148 +++++++++++++++++++++++------------- comfy/model_patcher.py | 8 +- comfy/sampler_helpers.py | 2 +- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 102 insertions(+), 58 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090b..181c4996 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -16,46 +16,86 @@ import comfy.model_management import comfy.patcher_extension from node_helpers import conditioning_set_values +# ####################################################################################################### +# Hooks explanation +# ------------------- +# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to +# make explicit special cases like it does for ControlNet and GLIGEN. +# +# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those +# that should run special code when a 'marked' cond is used in sampling. +# ####################################################################################################### + class EnumHookMode(enum.Enum): + ''' + Priority of hook memory optimization vs. speed, mostly related to WeightHooks. + + MinVram: No caching will occur for any operations related to hooks. + MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups. + ''' MinVram = "minvram" MaxSpeed = "maxspeed" class EnumHookType(enum.Enum): + ''' + Hook types, each of which has different expected behavior. + ''' Weight = "weight" Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Callbacks = "callbacks" Wrappers = "wrappers" - SetInjections = "add_injections" + Injections = "add_injections" class EnumWeightTarget(enum.Enum): Model = "model" Clip = "clip" +class EnumHookScope(enum.Enum): + ''' + Determines if hook should be limited in its influence over sampling. + + AllConditioning: hook will affect all conds used in sampling. + HookedOnly: hook will only affect the conds it was attached to. + ''' + AllConditioning = "all_conditioning" + HookedOnly = "hooked_only" + + class _HookRef: pass -# NOTE: this is an example of how the should_register function should look -def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + '''Example for how should_register function should look like.''' return True +def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]: + '''Creates base dictionary for use with Hooks' target param.''' + d = {} + if target is not None: + d['target'] = target + d.update(kwargs) + return d + + class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, - hook_keyframe: 'HookKeyframeGroup'=None): + hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False + self.hook_scope = hook_scope @property def strength(self): return self.hook_keyframe.strength - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): self.reset() self.hook_keyframe.initialize_timesteps(model) @@ -75,27 +115,32 @@ class Hook: c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return self.custom_should_register(self, model, model_options, target, registered) + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def __eq__(self, other: 'Hook'): + def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref def __hash__(self): return hash(self.hook_ref) class WeightHook(Hook): + ''' + Hook responsible for tracking weights to be applied to some model/clip. + + Note, value of hook_scope is ignored and is treated as HookedOnly. + ''' def __init__(self, strength_model=1.0, strength_clip=1.0): - super().__init__(hook_type=EnumHookType.Weight) + super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly) self.weights: dict = None self.weights_clip: dict = None self.need_weight_init = True @@ -110,27 +155,29 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False weights = None - if target == EnumWeightTarget.Model: - strength = self._strength_model - else: + + target = target_dict.get('target', None) + if target == EnumWeightTarget.Clip: strength = self._strength_clip + else: + strength = self._strength_model if self.need_weight_init: key_map = {} - if target == EnumWeightTarget.Model: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - else: + if target == EnumWeightTarget.Clip: key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: - if target == EnumWeightTarget.Model: - weights = self.weights - else: + if target == EnumWeightTarget.Clip: weights = self.weights_clip + else: + weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) registered.append(self) return True @@ -174,7 +221,12 @@ class ObjectPatchHook(Hook): # TODO: add functionality class AddModelsHook(Hook): - def __init__(self, key: str=None, models: list['ModelPatcher']=None): + ''' + Hook responsible for telling model management any additional models that should be loaded. + + Note, value of hook_scope is ignored and is treated as AllConditioning. + ''' + def __init__(self, key: str=None, models: list[ModelPatcher]=None): super().__init__(hook_type=EnumHookType.AddModels) self.key = key self.models = models @@ -188,24 +240,15 @@ class AddModelsHook(Hook): c.models = self.models.copy() if self.models else self.models c.append_when_same = self.append_when_same return c - # TODO: add functionality -class CallbackHook(Hook): - def __init__(self, key: str=None, callback: Callable=None): - super().__init__(hook_type=EnumHookType.Callbacks) - self.key = key - self.callback = callback - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: CallbackHook = super().clone(subtype) - c.key = self.key - c.callback = self.callback - return c - # TODO: add functionality + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): + return False class WrapperHook(Hook): + ''' + Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): super().__init__(hook_type=EnumHookType.Wrappers) self.wrappers_dict = wrappers_dict @@ -217,17 +260,18 @@ class WrapperHook(Hook): c.wrappers_dict = self.wrappers_dict return c - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False add_model_options = {"transformer_options": self.wrappers_dict} - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + if self.hook_scope == EnumHookScope.AllConditioning: + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list['PatcherInjection']=None): - super().__init__(hook_type=EnumHookType.SetInjections) + def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections @@ -239,7 +283,7 @@ class SetInjectionsHook(Hook): c.injections = self.injections.copy() if self.injections else self.injections return c - def add_hook_injections(self, model: 'ModelPatcher'): + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass @@ -260,14 +304,14 @@ class HookGroup: c.add(hook.clone()) return c - def clone_and_combine(self, other: 'HookGroup'): + def clone_and_combine(self, other: HookGroup): c = self.clone() if other is not None: for hook in other.hooks: c.add(hook.clone()) return c - def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup): if hook_kf is None: hook_kf = HookKeyframeGroup() else: @@ -336,7 +380,7 @@ class HookGroup: hook.reset() @staticmethod - def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': + def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup: actual: list[HookGroup] = [] for group in hooks_list: if group is not None: @@ -433,7 +477,7 @@ class HookKeyframeGroup: c._set_first_as_current() return c - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) @@ -548,7 +592,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float hook.need_weight_init = False return hook_group -def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): +def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True): if model is None: return None patches_model: dict[str, torch.Tensor] = model.model.state_dict() @@ -560,7 +604,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T return patches_model # NOTE: this function shows how to register weight hooks directly on the ModelPatchers -def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], +def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): key_map = {} if model is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11..07153552 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,13 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): @@ -956,9 +956,9 @@ class ModelPatcher: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target) + callback(self, hooks_dict, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index ac973536..6f21ca3c 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -131,4 +131,4 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) + model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 9d9d4837..49b90b9d 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) + clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 776aa734e1ac0a46fefef6abcc5ad29763003a7e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 01:02:21 -0600 Subject: [PATCH 03/15] Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) --- comfy/hooks.py | 24 +++++++++++++++++------- comfy/model_patcher.py | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 181c4996..7ca3a8a1 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -44,7 +44,7 @@ class EnumHookType(enum.Enum): Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Wrappers = "wrappers" + TransformerOptions = "transformer_options" Injections = "add_injections" class EnumWeightTarget(enum.Enum): @@ -245,29 +245,39 @@ class AddModelsHook(Hook): if not self.should_register(model, model_options, target_dict, registered): return False -class WrapperHook(Hook): +class TransformerOptionsHook(Hook): ''' - Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict + super().__init__(hook_type=EnumHookType.TransformerOptions) + self.transformers_dict = wrappers_dict def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict + c.transformers_dict = self.transformers_dict return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.wrappers_dict} + add_model_options = {"transformer_options": self.transformers_dict} + # TODO: call .to on patches/anything else in transformer_options that is expected to do something if self.hook_scope == EnumHookScope.AllConditioning: comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True + + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + +class WrapperHook(TransformerOptionsHook): + ''' + For backwards compatibility, this hook is identical to TransformerOptionsHook. + ''' + pass class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 07153552..2db21bdc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -945,7 +945,7 @@ class ModelPatcher: registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] From 111fd0cadfe83cdda7a1a775f89e0dd675a58d66 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 02:04:07 -0600 Subject: [PATCH 04/15] Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type --- comfy/hooks.py | 78 ++++++++++++++++++------------------- comfy/model_patcher.py | 10 ++--- comfy/sampler_helpers.py | 20 +++++----- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 53 insertions(+), 57 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7ca3a8a1..9ccfaa6d 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -41,7 +41,6 @@ class EnumHookType(enum.Enum): Hook types, each of which has different expected behavior. ''' Weight = "weight" - Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" TransformerOptions = "transformer_options" @@ -194,19 +193,6 @@ class WeightHook(Hook): c._strength_clip = self._strength_clip return c -class PatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.Patch) - self.patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: PatchHook = super().clone(subtype) - c.patches = self.patches - return c - # TODO: add functionality - class ObjectPatchHook(Hook): def __init__(self): super().__init__(hook_type=EnumHookType.ObjectPatch) @@ -244,6 +230,7 @@ class AddModelsHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False + return True class TransformerOptionsHook(Hook): ''' @@ -298,12 +285,28 @@ class SetInjectionsHook(Hook): pass class HookGroup: + ''' + Stores groups of hooks, and allows them to be queried by type. + + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; + always use the provided functions on HookGroup. + ''' def __init__(self): self.hooks: list[Hook] = [] + self._hook_dict: dict[EnumHookType, list[Hook]] = {} def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) + self._hook_dict.setdefault(hook.hook_type, []).append(hook) + + def remove(self, hook: Hook): + if hook in self.hooks: + self.hooks.remove(hook) + self._hook_dict[hook.hook_type].remove(hook) + + def get_type(self, hook_type: EnumHookType): + return self._hook_dict.get(hook_type, []) def contains(self, hook: Hook): return hook in self.hooks @@ -329,36 +332,29 @@ class HookGroup: for hook in self.hooks: hook.hook_keyframe = hook_kf - def get_dict_repr(self): - d: dict[EnumHookType, dict[Hook, None]] = {} - for hook in self.hooks: - with_type = d.setdefault(hook.hook_type, {}) - with_type[hook] = None - return d - def get_hooks_for_clip_schedule(self): scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} - for hook in self.hooks: - # only care about WeightHooks, for now - if hook.hook_type == EnumHookType.Weight: - hook_schedule = [] - # if no hook keyframes, assign default value - if len(hook.hook_keyframe.keyframes) == 0: - hook_schedule.append(((0.0, 1.0), None)) - scheduled_hooks[hook] = hook_schedule - continue - # find ranges of values - prev_keyframe = hook.hook_keyframe.keyframes[0] - for keyframe in hook.hook_keyframe.keyframes: - if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): - hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) - prev_keyframe = keyframe - elif keyframe.start_percent == prev_keyframe.start_percent: - prev_keyframe = keyframe - # create final range, assuming last start_percent was not 1.0 - if not math.isclose(prev_keyframe.start_percent, 1.0): - hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + # only care about WeightHooks, for now + for hook in self.get_type(EnumHookType.Weight): + hook: WeightHook + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) scheduled_hooks[hook] = hook_schedule + continue + # find ranges of values + prev_keyframe = hook.hook_keyframe.keyframes[0] + for keyframe in hook.hook_keyframe.keyframes: + if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): + hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) + prev_keyframe = keyframe + elif keyframe.start_percent == prev_keyframe.start_percent: + prev_keyframe = keyframe + # create final range, assuming last start_percent was not 1.0 + if not math.isclose(prev_keyframe.start_percent, 1.0): + hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + scheduled_hooks[hook] = hook_schedule # hooks should not have their schedules in a list of tuples all_ranges: list[tuple[float, float]] = [] for range_kfs in scheduled_hooks.values(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2db21bdc..0430430e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,16 +940,16 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] - # handle WrapperHooks, if model_options provided + # handle TransformerOptionsHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) if len(weight_hooks_to_register) > 0: @@ -958,7 +958,7 @@ class ModelPatcher: for hook in weight_hooks_to_register: hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target_dict) + callback(self, hooks, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 6f21ca3c..abd44cf6 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type): models += [c[model_type]] return models -def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): +def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup): # get hooks from conds, and collect cnets so they can be checked for extra_hooks cnets: list[ControlBase] = [] for c in cond: if 'hooks' in c: for hook in c['hooks'].hooks: - hook: comfy.hooks.Hook - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) if 'control' in c: cnets.append(c['control']) @@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) if extra_hooks is not None: for hook in extra_hooks.hooks: - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) - return hooks_dict + return full_hooks def convert_cond(cond): out = [] @@ -73,7 +70,7 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} + hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") @@ -90,7 +87,10 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] + hook_models = [] + for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + x: comfy.hooks.AddModelsHook + hook_models.extend(x.models) models = control_models + gligen + add_models + hook_models return models, inference_memory @@ -124,7 +124,7 @@ def cleanup_models(conds, models): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # check for hooks in conds - if not registered, see if can be applied - hooks = {} + hooks = comfy.hooks.HookGroup() for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 49b90b9d..64223834 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) + clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 6620d86318d19562a4410eabc78c27538d54e445 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 15:26:22 -0600 Subject: [PATCH 05/15] In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch --- comfy/hooks.py | 2 +- comfy/samplers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090b..3cb0f396 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -442,7 +442,7 @@ class HookKeyframeGroup: return False if curr_t == self._curr_t: return False - max_sigma = torch.max(transformer_options["sigmas"]) + max_sigma = torch.max(transformer_options["sample_sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch diff --git a/comfy/samplers.py b/comfy/samplers.py index 89464a42..af2b8e11 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -849,7 +849,7 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) - extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( From 8270ff312f7aefc4d29aeeed667296b2a56628ce Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 21:07:02 -0600 Subject: [PATCH 06/15] Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time --- comfy/hooks.py | 34 +++++++++++++++--------- comfy/model_patcher.py | 15 +++++------ comfy/sampler_helpers.py | 48 +++++++++++++++++++++++++-------- comfy/samplers.py | 57 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 119 insertions(+), 35 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 3ead8c96..25d67b86 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -65,7 +65,7 @@ class _HookRef: pass -def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): '''Example for how should_register function should look like.''' return True @@ -114,10 +114,10 @@ class Hook: c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): @@ -154,7 +154,7 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False weights = None @@ -178,7 +178,7 @@ class WeightHook(Hook): else: weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) - registered.append(self) + registered.add(self) return True # TODO: add logs about any keys that were not applied @@ -212,11 +212,12 @@ class AddModelsHook(Hook): Note, value of hook_scope is ignored and is treated as AllConditioning. ''' - def __init__(self, key: str=None, models: list[ModelPatcher]=None): + def __init__(self, models: list[ModelPatcher]=None, key: str=None): super().__init__(hook_type=EnumHookType.AddModels) - self.key = key self.models = models + self.key = key self.append_when_same = True + '''Curently does nothing.''' def clone(self, subtype: Callable=None): if subtype is None: @@ -227,9 +228,10 @@ class AddModelsHook(Hook): c.append_when_same = self.append_when_same return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False + registered.add(self) return True class TransformerOptionsHook(Hook): @@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook): c.transformers_dict = self.transformers_dict return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.transformers_dict} - # TODO: call .to on patches/anything else in transformer_options that is expected to do something + # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks if self.hook_scope == EnumHookScope.AllConditioning: - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.append(self) + add_model_options = {"transformer_options": self.transformers_dict, + "to_load_options": self.transformers_dict} + else: + add_model_options = {"to_load_options": self.transformers_dict} + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -295,6 +300,9 @@ class HookGroup: self.hooks: list[Hook] = [] self._hook_dict: dict[EnumHookType, list[Hook]] = {} + def __len__(self): + return len(self.hooks) + def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0430430e..2a551087 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,11 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, + registered: comfy.hooks.HookGroup = None): self.restore_hook_patches() - registered_hooks: list[comfy.hooks.Hook] = [] - # handle TransformerOptionsHooks, if model_options provided - if model_options is not None: - for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + if registered is None: + registered = comfy.hooks.HookGroup() # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): @@ -956,9 +954,10 @@ class ModelPatcher: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks, target_dict) + callback(self, hooks, target_dict, model_options, registered) + return registered def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index abd44cf6..cb938851 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -70,13 +70,11 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -87,14 +85,20 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [] - for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - x: comfy.hooks.AddModelsHook - hook_models.extend(x.models) - models = control_models + gligen + add_models + hook_models + models = control_models + gligen + add_models return models, inference_memory +def get_additional_models_from_model_options(model_options: dict[str]=None): + """loads additional models from registered AddModels hooks""" + models = [] + if model_options is not None and "registered_hooks" in model_options: + registered: comfy.hooks.HookGroup = model_options["registered_hooks"] + for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + models.extend(hook.models) + return models + def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: @@ -102,9 +106,10 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): - real_model: 'BaseModel' = None +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory @@ -130,5 +135,26 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # add wrappers and callbacks from ModelPatcher to transformer_options model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) - # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) + # begin registering hooks + registered = comfy.hooks.HookGroup() + target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) + # handle all TransformerOptionsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): + hook: comfy.hooks.TransformerOptionsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all AddModelsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all WeightHooks by registering on ModelPatcher + model.register_all_hook_patches(hooks, target_dict, model_options, registered) + # add registered_hooks onto model_options for further reference + if len(registered) > 0: + model_options["registered_hooks"] = registered + # merge original wrappers and callbacks with hooked wrappers and callbacks + to_load_options: dict[str] = model_options.setdefault("to_load_options", {}) + for wc_name in ["wrappers", "callbacks"]: + comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], + copy_dict1=False) + return to_load_options + diff --git a/comfy/samplers.py b/comfy/samplers.py index af2b8e11..8f8345ab 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): return len(hooks_set) +def cast_to_load_options(model_options: dict[str], device=None, dtype=None): + ''' + If any patches from hooks, wrappers, or callbacks have .to to be called, call it. + ''' + if model_options is None: + return + to_load_options = model_options.get("to_load_options", None) + if to_load_options is None: + return + + casts = [] + if device is not None: + casts.append(device) + if dtype is not None: + casts.append(dtype) + # if nothing to apply, do nothing + if len(casts) == 0: + return + + # Try to call .to on patches + if "patches" in to_load_options: + patches = to_load_options["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + for cast in casts: + patch_list[i] = patch_list[i].to(cast) + if "patches_replace" in to_load_options: + patches = to_load_options["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + for cast in casts: + patch_list[k] = patch_list[k].to(cast) + # Try to call .to on any wrappers/callbacks + wrappers_and_callbacks = ["wrappers", "callbacks"] + for wc_name in wrappers_and_callbacks: + if wc_name in to_load_options: + wc: dict[str, list] = to_load_options[wc_name] + for wc_dict in wc.values(): + for wc_list in wc_dict.values(): + for i in range(len(wc_list)): + if hasattr(wc_list[i], "to"): + for cast in casts: + wc_list[i] = wc_list[i].to(cast) + + class CFGGuider: - def __init__(self, model_patcher): - self.model_patcher: 'ModelPatcher' = model_patcher + def __init__(self, model_patcher: ModelPatcher): + self.model_patcher = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -861,7 +910,7 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device if denoise_mask is not None: @@ -870,6 +919,7 @@ class CFGGuider: noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) try: self.model_patcher.pre_run() @@ -906,6 +956,7 @@ class CFGGuider: ) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: + cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() From 4446c86052bd9a00b72205b761b3744dd51f90eb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 22:25:51 -0600 Subject: [PATCH 07/15] Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational --- comfy/hooks.py | 64 +++++++++++++++++++--------------------- comfy/sampler_helpers.py | 1 - 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 25d67b86..b62092cc 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -101,10 +101,8 @@ class Hook: def reset(self): self.hook_keyframe.reset() - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: Hook = subtype() + def clone(self): + c: Hook = self.__class__() c.hook_type = self.hook_type c.hook_ref = self.hook_ref c.hook_id = self.hook_id @@ -182,10 +180,8 @@ class WeightHook(Hook): return True # TODO: add logs about any keys that were not applied - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WeightHook = super().clone(subtype) + def clone(self): + c: WeightHook = super().clone() c.weights = self.weights c.weights_clip = self.weights_clip c.need_weight_init = self.need_weight_init @@ -194,17 +190,21 @@ class WeightHook(Hook): return c class ObjectPatchHook(Hook): - def __init__(self): + def __init__(self, object_patches: dict[str]=None): super().__init__(hook_type=EnumHookType.ObjectPatch) - self.object_patches: dict = None + self.object_patches = object_patches - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: ObjectPatchHook = super().clone(subtype) + def clone(self): + c: ObjectPatchHook = super().clone() c.object_patches = self.object_patches return c - # TODO: add functionality + + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True class AddModelsHook(Hook): ''' @@ -219,12 +219,10 @@ class AddModelsHook(Hook): self.append_when_same = True '''Curently does nothing.''' - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: AddModelsHook = super().clone(subtype) - c.key = self.key + def clone(self): + c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models + c.key = self.key c.append_when_same = self.append_when_same return c @@ -242,10 +240,8 @@ class TransformerOptionsHook(Hook): super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = wrappers_dict - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WrapperHook = super().clone(subtype) + def clone(self): + c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict return c @@ -265,11 +261,8 @@ class TransformerOptionsHook(Hook): def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) -class WrapperHook(TransformerOptionsHook): - ''' - For backwards compatibility, this hook is identical to TransformerOptionsHook. - ''' - pass +WrapperHook = TransformerOptionsHook +'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): @@ -277,14 +270,19 @@ class SetInjectionsHook(Hook): self.key = key self.injections = injections - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: SetInjectionsHook = super().clone(subtype) + def clone(self): + c: SetInjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index cb938851..d43280fe 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -157,4 +157,3 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options - From 03a97b604a3e8ca9f54c711ed3b007f07c9115ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 01:03:59 -0600 Subject: [PATCH 08/15] Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) --- comfy/hooks.py | 33 +++++++++++++++++++++++++-------- comfy/model_patcher.py | 10 ++++++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index b62092cc..dde3e8bc 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -317,6 +317,18 @@ class HookGroup: def contains(self, hook: Hook): return hook in self.hooks + def is_subset_of(self, other: HookGroup): + self_hooks = set(self.hooks) + other_hooks = set(other.hooks) + return self_hooks.issubset(other_hooks) + + def new_with_common_hooks(self, other: HookGroup): + c = HookGroup() + for hook in self.hooks: + if other.contains(hook): + c.add(hook.clone()) + return c + def clone(self): c = HookGroup() for hook in self.hooks: @@ -668,24 +680,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H else: c_dict[hooks_key] = cache[hooks_tuple] -def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True, + cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): c = [] - hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + if cache is None: + cache = {} for t in conditioning: n = [t[0], t[1].copy()] for k in values: if append_hooks and k == 'hooks': - _combine_hooks_from_values(n[1], values, hooks_combine_cache) + _combine_hooks_from_values(n[1], values, cache) else: n[1][k] = values[k] c.append(n) return c -def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): if hooks is None: return cond - return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache) def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): if timestep_range is None: @@ -720,9 +734,10 @@ def combine_with_new_conds(conds: list, new_conds: list): def set_conds_props(conds: list, strength: float, set_cond_area: str, mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): final_conds = [] + cache = {} for c in conds: # first, apply lora_hook to conditioning, if provided - c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to conditioning c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) # apply timesteps, if present @@ -734,9 +749,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, masked_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to new conditioning, if provided masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) # apply timesteps, if present @@ -748,9 +764,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1. def set_default_conds_and_combine(conds: list, new_conds: list, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, new_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache) # next, add default_cond key to cond so that during sampling, it can be identified new_c = conditioning_set_values(new_c, {'default': True}) # apply timesteps, if present diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2a551087..57a843b8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -210,7 +210,7 @@ class ModelPatcher: self.injections: dict[str, list[PatcherInjection]] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {} - self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = None self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.current_hooks: Optional[comfy.hooks.HookGroup] = None @@ -282,7 +282,7 @@ class ModelPatcher: n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) - n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup for group in self.cached_hook_patches: n.cached_hook_patches[group] = {} for k in self.cached_hook_patches[group]: @@ -912,9 +912,9 @@ class ModelPatcher: callback(self, timestep) def restore_hook_patches(self): - if len(self.hook_patches_backup) > 0: + if self.hook_patches_backup is not None: self.hook_patches = self.hook_patches_backup - self.hook_patches_backup = {} + self.hook_patches_backup = None def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode @@ -950,6 +950,8 @@ class ModelPatcher: for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) + else: + registered.add(hook) if len(weight_hooks_to_register) > 0: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) From 0a7e2ae787b81035798ad2ef1ade8cf882d67b69 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 01:04:29 -0600 Subject: [PATCH 09/15] Filter only registered hooks on self.conds in CFGGuider.sample --- comfy/sampler_helpers.py | 3 +++ comfy/samplers.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index d43280fe..1433d185 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -128,6 +128,9 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): + ''' + Registers hooks from conds. + ''' # check for hooks in conds - if not registered, see if can be applied hooks = comfy.hooks.HookGroup() for k in conds: diff --git a/comfy/samplers.py b/comfy/samplers.py index 8f8345ab..43a735c6 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): for cond in conds_to_modify: cond['hooks'] = hooks +def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]): + '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for + HookGroups that have the same reference.''' + registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None) + # if None were registered, make sure all hooks are cleaned from conds + if registered is None: + for k in conds: + for kk in conds[k]: + kk.pop('hooks', None) + return + # find conds that contain hooks to be replaced - group by common HookGroup refs + hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {} + for k in conds: + for kk in conds[k]: + hooks: comfy.hooks.HookGroup = kk.get('hooks', None) + if hooks is not None: + if not hooks.is_subset_of(registered): + to_replace = hook_replacement.setdefault(hooks, []) + to_replace.append(kk) + # for each hook to replace, create a new proper HookGroup and assign to all common conds + for hooks, conds_to_modify in hook_replacement.items(): + new_hooks = hooks.new_with_common_hooks(registered) + if len(new_hooks) == 0: + new_hooks = None + for kk in conds_to_modify: + kk['hooks'] = new_hooks + def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): hooks_set = set() @@ -949,6 +976,7 @@ class CFGGuider: if get_total_hook_groups_in_conds(self.conds) <= 1: self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) + filter_registered_hooks_on_conds(self.conds, self.model_options) executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( self.outer_sample, self, From f48f90e471fc5440135e7886d712518467c59c00 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 02:23:04 -0600 Subject: [PATCH 10/15] Make hook_scope functional for TransformerOptionsHook --- comfy/hooks.py | 41 ++++++++++++++++++++++++++--------------- comfy/model_patcher.py | 4 ++-- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index dde3e8bc..cc9f6cd5 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -86,9 +86,9 @@ class Hook: self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + self.hook_scope = hook_scope self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False - self.hook_scope = hook_scope @property def strength(self): @@ -107,6 +107,7 @@ class Hook: c.hook_ref = self.hook_ref c.hook_id = self.hook_id c.hook_keyframe = self.hook_keyframe + c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register # TODO: make this do something c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive @@ -118,12 +119,6 @@ class Hook: def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - - def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref @@ -143,6 +138,7 @@ class WeightHook(Hook): self.need_weight_init = True self._strength_model = strength_model self._strength_clip = strength_clip + self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs @property def strength_model(self): @@ -190,9 +186,11 @@ class WeightHook(Hook): return c class ObjectPatchHook(Hook): - def __init__(self, object_patches: dict[str]=None): + def __init__(self, object_patches: dict[str]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.ObjectPatch) self.object_patches = object_patches + self.hook_scope = hook_scope def clone(self): c: ObjectPatchHook = super().clone() @@ -216,14 +214,11 @@ class AddModelsHook(Hook): super().__init__(hook_type=EnumHookType.AddModels) self.models = models self.key = key - self.append_when_same = True - '''Curently does nothing.''' def clone(self): c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key - c.append_when_same = self.append_when_same return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook): ''' Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' - def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): + def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.TransformerOptions) - self.transformers_dict = wrappers_dict + self.transformers_dict = transformers_dict + self.hook_scope = hook_scope def clone(self): c: TransformerOptionsHook = super().clone() @@ -254,8 +251,9 @@ class TransformerOptionsHook(Hook): "to_load_options": self.transformers_dict} else: add_model_options = {"to_load_options": self.transformers_dict} + # only register if will not be included in AllConditioning to avoid double loading + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -265,10 +263,12 @@ WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + def __init__(self, key: str=None, injections: list[PatcherInjection]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections + self.hook_scope = hook_scope def clone(self): c: SetInjectionsHook = super().clone() @@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list: sorted_list.extend(object_list) return sorted_list +def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None): + # if no hooks or is not a ModelPatcher for sampling, return empty dict + if hooks is None or model.is_clip: + return {} + if transformer_options is None: + transformer_options = {} + for hook in hooks.get_type(EnumHookType.TransformerOptions): + hook: TransformerOptionsHook + hook.on_apply_hooks(model, transformer_options) + return transformer_options + def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): hook_group = HookGroup() hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57a843b8..51a62e04 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1010,11 +1010,11 @@ class ModelPatcher: def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): # TODO: return transformer_options dict with any additions from hooks if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) self.patch_hooks(hooks=hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): callback(self, hooks) - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) def patch_hooks(self, hooks: comfy.hooks.HookGroup): with self.use_ejected(): From 1b38f5bf57ca07490e616dd58ec3004d05de0655 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 17:11:12 -0600 Subject: [PATCH 11/15] removed 4 whitespace lines to satisfy Ruff, --- comfy/hooks.py | 4 ++-- comfy/samplers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cc9f6cd5..46fc06bd 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -255,7 +255,7 @@ class TransformerOptionsHook(Hook): registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True - + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) @@ -290,7 +290,7 @@ class SetInjectionsHook(Hook): class HookGroup: ''' Stores groups of hooks, and allows them to be queried by type. - + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; always use the provided functions on HookGroup. ''' diff --git a/comfy/samplers.py b/comfy/samplers.py index 43a735c6..a725d518 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -855,7 +855,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return - + casts = [] if device is not None: casts.append(device) @@ -864,7 +864,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - + # Try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] From 58bf8815c84b67ab26b0f08b8530a822b9899b10 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 20:34:30 -0600 Subject: [PATCH 12/15] Add a get_injections function to ModelPatcher --- comfy/model_patcher.py | 3 +++ comfy/samplers.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 51a62e04..7d7977c1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -842,6 +842,9 @@ class ModelPatcher: if key in self.injections: self.injections.pop(key) + def get_injections(self, key: str): + return self.injections.get(key, None) + def set_additional_models(self, key: str, models: list['ModelPatcher']): self.additional_models[key] = models diff --git a/comfy/samplers.py b/comfy/samplers.py index a725d518..5cc33a7d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -865,7 +865,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if len(casts) == 0: return - # Try to call .to on patches + # try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] for name in patches: @@ -882,7 +882,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if hasattr(patch_list[k], "to"): for cast in casts: patch_list[k] = patch_list[k].to(cast) - # Try to call .to on any wrappers/callbacks + # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: if wc_name in to_load_options: From 216fea15ee033d3301241a5ceb0e193b4924de04 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 00:59:18 -0600 Subject: [PATCH 13/15] Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable --- comfy/hooks.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 46fc06bd..7c2f6689 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -66,7 +66,7 @@ class _HookRef: def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - '''Example for how should_register function should look like.''' + '''Example for how custom_should_register function can look like.''' return True @@ -83,12 +83,17 @@ class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type + '''Enum identifying the general class of this hook.''' self.hook_ref = hook_ref if hook_ref else _HookRef() + '''Reference shared between hook clones that have the same value. Should NOT be modified.''' self.hook_id = hook_id + '''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.''' self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + '''Keyframe storage that can be referenced to get strength for current sampling step.''' self.hook_scope = hook_scope + '''Scope of where this hook should apply in terms of the conds used in sampling run.''' self.custom_should_register = default_should_register - self.auto_apply_to_nonpositive = False + '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' @property def strength(self): @@ -109,8 +114,6 @@ class Hook: c.hook_keyframe = self.hook_keyframe c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register - # TODO: make this do something - c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,28 +239,34 @@ class TransformerOptionsHook(Hook): super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = transformers_dict self.hook_scope = hook_scope + self._skip_adding = False + '''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.''' def clone(self): c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict + c._skip_adding = self._skip_adding return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks + self._skip_adding = False if self.hook_scope == EnumHookScope.AllConditioning: add_model_options = {"transformer_options": self.transformers_dict, "to_load_options": self.transformers_dict} + # skip_adding if included in AllConditioning to avoid double loading + self._skip_adding = True else: add_model_options = {"to_load_options": self.transformers_dict} - # only register if will not be included in AllConditioning to avoid double loading - registered.add(self) + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): - comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + if not self._skip_adding: + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' From 3cd4c5cb0a9d4f4f944ee1382e074d3a41e18874 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 02:22:49 -0600 Subject: [PATCH 14/15] Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) --- comfy/hooks.py | 26 +++++++------------------- comfy/sampler_helpers.py | 8 ++++---- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7c2f6689..9d073107 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -42,7 +42,7 @@ class EnumHookType(enum.Enum): ''' Weight = "weight" ObjectPatch = "object_patch" - AddModels = "add_models" + AdditionalModels = "add_models" TransformerOptions = "transformer_options" Injections = "add_injections" @@ -202,24 +202,20 @@ class ObjectPatchHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True -class AddModelsHook(Hook): +class AdditionalModelsHook(Hook): ''' Hook responsible for telling model management any additional models that should be loaded. Note, value of hook_scope is ignored and is treated as AllConditioning. ''' def __init__(self, models: list[ModelPatcher]=None, key: str=None): - super().__init__(hook_type=EnumHookType.AddModels) + super().__init__(hook_type=EnumHookType.AdditionalModels) self.models = models self.key = key def clone(self): - c: AddModelsHook = super().clone() + c: AdditionalModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key return c @@ -271,7 +267,7 @@ class TransformerOptionsHook(Hook): WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' -class SetInjectionsHook(Hook): +class InjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None, hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) @@ -280,21 +276,13 @@ class SetInjectionsHook(Hook): self.hook_scope = hook_scope def clone(self): - c: SetInjectionsHook = super().clone() + c: InjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True - - def add_hook_injections(self, model: ModelPatcher): - # TODO: add functionality - pass + raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.") class HookGroup: ''' diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1433d185..b70e5e63 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -94,8 +94,8 @@ def get_additional_models_from_model_options(model_options: dict[str]=None): models = [] if model_options is not None and "registered_hooks" in model_options: registered: comfy.hooks.HookGroup = model_options["registered_hooks"] - for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook models.extend(hook.models) return models @@ -146,8 +146,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): hook: comfy.hooks.TransformerOptionsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all AddModelsHooks - for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all WeightHooks by registering on ModelPatcher model.register_all_hook_patches(hooks, target_dict, model_options, registered) From 733328169868b9f4120cbfc59af2b00683df8563 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 02:58:59 -0600 Subject: [PATCH 15/15] Clean up a typehint --- comfy_extras/nodes_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 64223834..1edc06f3 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -246,7 +246,7 @@ class SetClipHooks: CATEGORY = "advanced/hooks/clip" FUNCTION = "apply_hooks" - def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): + def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: clip = clip.clone() if apply_to_conds: