diff --git a/comfy/hooks.py b/comfy/hooks.py
index 3cb0f396..9d073107 100644
--- a/comfy/hooks.py
+++ b/comfy/hooks.py
@@ -16,91 +16,132 @@
 import comfy.model_management
 import comfy.patcher_extension
 from node_helpers import conditioning_set_values
 
+# #######################################################################################################
+# Hooks explanation
+# -------------------
+# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
+# make explicit special cases like it does for ControlNet and GLIGEN.
+#
+# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
+# that should run special code when a 'marked' cond is used in sampling.
+# #######################################################################################################
+
 class EnumHookMode(enum.Enum):
+    '''
+    Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
+
+    MinVram: No caching will occur for any operations related to hooks.
+    MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
+    '''
    MinVram = "minvram"
    MaxSpeed = "maxspeed"
 
 class EnumHookType(enum.Enum):
+    '''
+    Hook types, each of which has different expected behavior.
+    '''
     Weight = "weight"
-    Patch = "patch"
     ObjectPatch = "object_patch"
-    AddModels = "add_models"
-    Callbacks = "callbacks"
-    Wrappers = "wrappers"
-    SetInjections = "add_injections"
+    AdditionalModels = "add_models"
+    TransformerOptions = "transformer_options"
+    Injections = "add_injections"
 
 class EnumWeightTarget(enum.Enum):
     Model = "model"
     Clip = "clip"
 
+class EnumHookScope(enum.Enum):
+    '''
+    Determines if hook should be limited in its influence over sampling.
+
+    AllConditioning: hook will affect all conds used in sampling.
+    HookedOnly: hook will only affect the conds it was attached to.
+    '''
+    AllConditioning = "all_conditioning"
+    HookedOnly = "hooked_only"
+
+
 class _HookRef:
     pass
 
-# NOTE: this is an example of how the should_register function should look
-def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
+
+def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+    '''Example of how a custom_should_register function can look.'''
     return True
 
+def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
+    '''Creates base dictionary for use with Hooks' target param.'''
+    d = {}
+    if target is not None:
+        d['target'] = target
+    d.update(kwargs)
+    return d
+
+
 class Hook:
     def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
-                 hook_keyframe: 'HookKeyframeGroup'=None):
+                 hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
         self.hook_type = hook_type
+        '''Enum identifying the general class of this hook.'''
         self.hook_ref = hook_ref if hook_ref else _HookRef()
+        '''Reference shared between hook clones that have the same value. Should NOT be modified.'''
         self.hook_id = hook_id
+        '''Optional string ID to identify hook; useful if duplicates need to be consolidated at registration time.'''
         self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
+        '''Keyframe storage that can be referenced to get strength for the current sampling step.'''
+        self.hook_scope = hook_scope
+        '''Scope of where this hook should apply, in terms of the conds used in a sampling run.'''
         self.custom_should_register = default_should_register
-        self.auto_apply_to_nonpositive = False
+        '''Can be overridden with a compatible function to decide if this hook should be registered, without the need to override .should_register.'''
 
     @property
     def strength(self):
         return self.hook_keyframe.strength
 
-    def initialize_timesteps(self, model: 'BaseModel'):
+    def initialize_timesteps(self, model: BaseModel):
         self.reset()
         self.hook_keyframe.initialize_timesteps(model)
 
     def reset(self):
         self.hook_keyframe.reset()
 
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: Hook = subtype()
+    def clone(self):
+        c: Hook = self.__class__()
         c.hook_type = self.hook_type
         c.hook_ref = self.hook_ref
         c.hook_id = self.hook_id
         c.hook_keyframe = self.hook_keyframe
+        c.hook_scope = self.hook_scope
         c.custom_should_register = self.custom_should_register
-        # TODO: make this do something
-        c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
         return c
 
-    def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
-        return self.custom_should_register(self, model, model_options, target, registered)
+    def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        return self.custom_should_register(self, model, model_options, target_dict, registered)
 
-    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
 
-    def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
-        pass
-
-    def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
-        pass
-
-    def __eq__(self, other: 'Hook'):
+    def __eq__(self, other: Hook):
         return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
 
     def __hash__(self):
         return hash(self.hook_ref)
 
 class WeightHook(Hook):
+    '''
+    Hook responsible for tracking weights to be applied to some model/clip.
+
+    Note: the value of hook_scope is ignored and is always treated as HookedOnly.
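+
+    A minimal usage sketch (variable names are illustrative; create_hook_lora and
+    set_hooks_for_conditioning are defined later in this file):
+        hook_group = create_hook_lora(lora, strength_model=1.0, strength_clip=1.0)
+        conditioning = set_hooks_for_conditioning(conditioning, hook_group)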
+    '''
     def __init__(self, strength_model=1.0, strength_clip=1.0):
-        super().__init__(hook_type=EnumHookType.Weight)
+        super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
         self.weights: dict = None
         self.weights_clip: dict = None
         self.need_weight_init = True
         self._strength_model = strength_model
         self._strength_clip = strength_clip
+        self.hook_scope = EnumHookScope.HookedOnly # value is ignored for WeightHooks; set here only for documentation
 
     @property
     def strength_model(self):
@@ -110,36 +151,36 @@ class WeightHook(Hook):
     def strength_clip(self):
         return self._strength_clip * self.strength
 
-    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
-        if not self.should_register(model, model_options, target, registered):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        if not self.should_register(model, model_options, target_dict, registered):
             return False
         weights = None
-        if target == EnumWeightTarget.Model:
-            strength = self._strength_model
-        else:
+
+        target = target_dict.get('target', None)
+        if target == EnumWeightTarget.Clip:
             strength = self._strength_clip
+        else:
+            strength = self._strength_model
 
         if self.need_weight_init:
             key_map = {}
-            if target == EnumWeightTarget.Model:
-                key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
-            else:
+            if target == EnumWeightTarget.Clip:
                 key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
+            else:
+                key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
             weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
         else:
-            if target == EnumWeightTarget.Model:
-                weights = self.weights
-            else:
+            if target == EnumWeightTarget.Clip:
                 weights = self.weights_clip
+            else:
+                weights = self.weights
         model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
-        registered.append(self)
+        registered.add(self)
         return True
         # TODO: add logs about any keys that were not applied
 
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: WeightHook = super().clone(subtype)
+    def clone(self):
+        c: WeightHook = super().clone()
         c.weights = self.weights
         c.weights_clip = self.weights_clip
         c.need_weight_init = self.need_weight_init
@@ -147,127 +188,158 @@ class WeightHook(Hook):
         c._strength_clip = self._strength_clip
         return c
 
-class PatchHook(Hook):
-    def __init__(self):
-        super().__init__(hook_type=EnumHookType.Patch)
-        self.patches: dict = None
-
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: PatchHook = super().clone(subtype)
-        c.patches = self.patches
-        return c
-    # TODO: add functionality
-
 class ObjectPatchHook(Hook):
-    def __init__(self):
+    def __init__(self, object_patches: dict[str]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
         super().__init__(hook_type=EnumHookType.ObjectPatch)
-        self.object_patches: dict = None
+        self.object_patches = object_patches
+        self.hook_scope = hook_scope
 
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: ObjectPatchHook = super().clone(subtype)
+    def clone(self):
+        c: ObjectPatchHook = super().clone()
         c.object_patches = self.object_patches
         return c
-    # TODO: add functionality
 
-class AddModelsHook(Hook):
-    def __init__(self, key: str=None, models: list['ModelPatcher']=None):
-        super().__init__(hook_type=EnumHookType.AddModels)
-        self.key = key
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
+
+class AdditionalModelsHook(Hook):
+    '''
+    Hook responsible for telling model management about any additional models that should be loaded.
+
+    Note: the value of hook_scope is ignored and is always treated as AllConditioning.
+    '''
+    def __init__(self, models: list[ModelPatcher]=None, key: str=None):
+        super().__init__(hook_type=EnumHookType.AdditionalModels)
         self.models = models
-        self.append_when_same = True
-
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: AddModelsHook = super().clone(subtype)
-        c.key = self.key
-        c.models = self.models.copy() if self.models else self.models
-        c.append_when_same = self.append_when_same
-        return c
-    # TODO: add functionality
-
-class CallbackHook(Hook):
-    def __init__(self, key: str=None, callback: Callable=None):
-        super().__init__(hook_type=EnumHookType.Callbacks)
         self.key = key
-        self.callback = callback
 
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: CallbackHook = super().clone(subtype)
+    def clone(self):
+        c: AdditionalModelsHook = super().clone()
+        c.models = self.models.copy() if self.models else self.models
         c.key = self.key
-        c.callback = self.callback
-        return c
-    # TODO: add functionality
-
-class WrapperHook(Hook):
-    def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
-        super().__init__(hook_type=EnumHookType.Wrappers)
-        self.wrappers_dict = wrappers_dict
-
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: WrapperHook = super().clone(subtype)
-        c.wrappers_dict = self.wrappers_dict
         return c
 
-    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
-        if not self.should_register(model, model_options, target, registered):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        if not self.should_register(model, model_options, target_dict, registered):
             return False
-        add_model_options = {"transformer_options": self.wrappers_dict}
-        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
-        registered.append(self)
+        registered.add(self)
         return True
 
-class SetInjectionsHook(Hook):
-    def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
-        super().__init__(hook_type=EnumHookType.SetInjections)
+class TransformerOptionsHook(Hook):
+    '''
+    Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
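+
+    A hypothetical sketch of the expected transformers_dict shape (the wrapper-type and key
+    names below are illustrative, not taken from this patch):
+        {"wrappers": {"some_wrapper_type": {"my_wrapper_key": [my_wrapper_callable]}}}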
+    '''
+    def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
+        super().__init__(hook_type=EnumHookType.TransformerOptions)
+        self.transformers_dict = transformers_dict
+        self.hook_scope = hook_scope
+        self._skip_adding = False
+        '''Internal value used to avoid double loading of transformer_options when hook_scope is AllConditioning.'''
+
+    def clone(self):
+        c: TransformerOptionsHook = super().clone()
+        c.transformers_dict = self.transformers_dict
+        c._skip_adding = self._skip_adding
+        return c
+
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        if not self.should_register(model, model_options, target_dict, registered):
+            return False
+        # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
+        self._skip_adding = False
+        if self.hook_scope == EnumHookScope.AllConditioning:
+            add_model_options = {"transformer_options": self.transformers_dict,
+                                 "to_load_options": self.transformers_dict}
+            # skip_adding if included in AllConditioning to avoid double loading
+            self._skip_adding = True
+        else:
+            add_model_options = {"to_load_options": self.transformers_dict}
+        registered.add(self)
+        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
+        return True
+
+    def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
+        if not self._skip_adding:
+            comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
+
+WrapperHook = TransformerOptionsHook
+'''Only here for backwards compatibility; WrapperHook is identical to TransformerOptionsHook.'''
+
+class InjectionsHook(Hook):
+    def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
+                 hook_scope=EnumHookScope.AllConditioning):
+        super().__init__(hook_type=EnumHookType.Injections)
         self.key = key
         self.injections = injections
+        self.hook_scope = hook_scope
 
-    def clone(self, subtype: Callable=None):
-        if subtype is None:
-            subtype = type(self)
-        c: SetInjectionsHook = super().clone(subtype)
+    def clone(self):
+        c: InjectionsHook = super().clone()
         c.key = self.key
         c.injections = self.injections.copy() if self.injections else self.injections
         return c
 
-    def add_hook_injections(self, model: 'ModelPatcher'):
-        # TODO: add functionality
-        pass
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
+        raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
 
 class HookGroup:
+    '''
+    Stores groups of hooks, and allows them to be queried by type.
+
+    To avoid breaking functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
+    always use the provided functions on HookGroup.
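+
+    A minimal sketch of typical usage (strength values are illustrative):
+        group = HookGroup()
+        group.add(WeightHook(strength_model=0.8, strength_clip=1.0))
+        weight_hooks = group.get_type(EnumHookType.Weight)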
+    '''
     def __init__(self):
         self.hooks: list[Hook] = []
+        self._hook_dict: dict[EnumHookType, list[Hook]] = {}
+
+    def __len__(self):
+        return len(self.hooks)
 
     def add(self, hook: Hook):
         if hook not in self.hooks:
             self.hooks.append(hook)
+            self._hook_dict.setdefault(hook.hook_type, []).append(hook)
+
+    def remove(self, hook: Hook):
+        if hook in self.hooks:
+            self.hooks.remove(hook)
+            self._hook_dict[hook.hook_type].remove(hook)
+
+    def get_type(self, hook_type: EnumHookType):
+        return self._hook_dict.get(hook_type, [])
 
     def contains(self, hook: Hook):
         return hook in self.hooks
 
+    def is_subset_of(self, other: HookGroup):
+        self_hooks = set(self.hooks)
+        other_hooks = set(other.hooks)
+        return self_hooks.issubset(other_hooks)
+
+    def new_with_common_hooks(self, other: HookGroup):
+        c = HookGroup()
+        for hook in self.hooks:
+            if other.contains(hook):
+                c.add(hook.clone())
+        return c
+
     def clone(self):
         c = HookGroup()
         for hook in self.hooks:
             c.add(hook.clone())
         return c
 
-    def clone_and_combine(self, other: 'HookGroup'):
+    def clone_and_combine(self, other: HookGroup):
         c = self.clone()
         if other is not None:
             for hook in other.hooks:
                 c.add(hook.clone())
         return c
 
-    def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
+    def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
         if hook_kf is None:
             hook_kf = HookKeyframeGroup()
         else:
@@ -275,36 +347,29 @@ class HookGroup:
         for hook in self.hooks:
             hook.hook_keyframe = hook_kf
 
-    def get_dict_repr(self):
-        d: dict[EnumHookType, dict[Hook, None]] = {}
-        for hook in self.hooks:
-            with_type = d.setdefault(hook.hook_type, {})
-            with_type[hook] = None
-        return d
-
     def get_hooks_for_clip_schedule(self):
         scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
-        for hook in self.hooks:
-            # only care about WeightHooks, for now
-            if hook.hook_type == EnumHookType.Weight:
-                hook_schedule = []
-                # if no hook keyframes, assign default value
-                if len(hook.hook_keyframe.keyframes) == 0:
-                    hook_schedule.append(((0.0, 1.0), None))
-                    scheduled_hooks[hook] = hook_schedule
-                    continue
-                # find ranges of values
-                prev_keyframe = hook.hook_keyframe.keyframes[0]
-                for keyframe in hook.hook_keyframe.keyframes:
-                    if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
-                        hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
-                        prev_keyframe = keyframe
-                    elif keyframe.start_percent == prev_keyframe.start_percent:
-                        prev_keyframe = keyframe
-                # create final range, assuming last start_percent was not 1.0
-                if not math.isclose(prev_keyframe.start_percent, 1.0):
-                    hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
+        # only care about WeightHooks, for now
+        for hook in self.get_type(EnumHookType.Weight):
+            hook: WeightHook
+            hook_schedule = []
+            # if no hook keyframes, assign default value
+            if len(hook.hook_keyframe.keyframes) == 0:
+                hook_schedule.append(((0.0, 1.0), None))
                 scheduled_hooks[hook] = hook_schedule
+                continue
+            # find ranges of values
+            prev_keyframe = hook.hook_keyframe.keyframes[0]
+            for keyframe in hook.hook_keyframe.keyframes:
+                if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
+                    hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
+                    prev_keyframe = keyframe
+                elif keyframe.start_percent == prev_keyframe.start_percent:
+                    prev_keyframe = keyframe
+            # create final range, assuming last start_percent was not 1.0
+            if not math.isclose(prev_keyframe.start_percent, 1.0):
+                hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
+            scheduled_hooks[hook] = hook_schedule
         # hooks should not have their schedules in a list of tuples
         all_ranges: list[tuple[float, float]] = []
         for range_kfs in scheduled_hooks.values():
@@ -336,7 +401,7 @@ class HookGroup:
             hook.reset()
 
     @staticmethod
-    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
+    def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
         actual: list[HookGroup] = []
         for group in hooks_list:
             if group is not None:
@@ -433,7 +498,7 @@ class HookKeyframeGroup:
             c._set_first_as_current()
         return c
 
-    def initialize_timesteps(self, model: 'BaseModel'):
+    def initialize_timesteps(self, model: BaseModel):
         for keyframe in self.keyframes:
             keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
 
@@ -522,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
         sorted_list.extend(object_list)
     return sorted_list
 
+def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
+    # if no hooks or this is not a ModelPatcher used for sampling, return empty dict
+    if hooks is None or model.is_clip:
+        return {}
+    if transformer_options is None:
+        transformer_options = {}
+    for hook in hooks.get_type(EnumHookType.TransformerOptions):
+        hook: TransformerOptionsHook
+        hook.on_apply_hooks(model, transformer_options)
+    return transformer_options
+
 def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
     hook_group = HookGroup()
     hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
@@ -548,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
     hook.need_weight_init = False
     return hook_group
 
-def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
+def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
     if model is None:
         return None
     patches_model: dict[str, torch.Tensor] = model.model.state_dict()
@@ -560,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
     return patches_model
 
 # NOTE: this function shows how to register weight hooks directly on the ModelPatchers
-def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
+def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
                               strength_model: float, strength_clip: float):
     key_map = {}
     if model is not None:
@@ -612,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
     else:
         c_dict[hooks_key] = cache[hooks_tuple]
 
-def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
+def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
+                                       cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
     c = []
-    hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
+    if cache is None:
+        cache = {}
     for t in conditioning:
         n = [t[0], t[1].copy()]
         for k in values:
             if append_hooks and k == 'hooks':
-                _combine_hooks_from_values(n[1], values, hooks_combine_cache)
+                _combine_hooks_from_values(n[1], values, cache)
             else:
                 n[1][k] = values[k]
         c.append(n)
     return c
 
-def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
+def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
     if hooks is None:
         return cond
-    return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
+    return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
 
 def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
     if timestep_range is None:
@@ -664,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
 def set_conds_props(conds: list, strength: float, set_cond_area: str,
                     mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
     final_conds = []
+    cache = {}
     for c in conds:
         # first, apply lora_hook to conditioning, if provided
-        c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
+        c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
         # next, apply mask to conditioning
         c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
         # apply timesteps, if present
@@ -678,9 +757,10 @@
 def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
                                 mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
     combined_conds = []
+    cache = {}
     for c, masked_c in zip(conds, new_conds):
         # first, apply lora_hook to new conditioning, if provided
-        masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
+        masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
         # next, apply mask to new conditioning, if provided
        masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
         # apply timesteps, if present
@@ -692,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
 def set_default_conds_and_combine(conds: list, new_conds: list,
                                   hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
     combined_conds = []
+    cache = {}
     for c, new_c in zip(conds, new_conds):
         # first, apply lora_hook to new conditioning, if provided
-        new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
+        new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
         # next, add default_cond key to cond so that during sampling, it can be identified
         new_c = conditioning_set_values(new_c, {'default': True})
         # apply timesteps, if present
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index e886bdbb..0501f7b3 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -210,7 +210,7 @@ class ModelPatcher:
         self.injections: dict[str, list[PatcherInjection]] = {}
 
         self.hook_patches: dict[comfy.hooks._HookRef] = {}
-        self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
+        self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
         self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
         self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
         self.current_hooks: Optional[comfy.hooks.HookGroup] = None
@@ -282,7 +282,7 @@ class ModelPatcher:
             n.injections[k] = i.copy()
         # hooks
         n.hook_patches = create_hook_patches_clone(self.hook_patches)
-        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
+        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
         for group in self.cached_hook_patches:
             n.cached_hook_patches[group] = {}
             for k in self.cached_hook_patches[group]:
@@ -855,6 +855,9 @@ class ModelPatcher:
         if key in self.injections:
             self.injections.pop(key)
 
+    def get_injections(self, key: str):
+        return self.injections.get(key, None)
+
     def set_additional_models(self, key: str, models: list['ModelPatcher']):
         self.additional_models[key] = models
 
@@ -925,9 +928,9 @@ class ModelPatcher:
             callback(self, timestep)
 
     def restore_hook_patches(self):
-        if len(self.hook_patches_backup) > 0:
+        if self.hook_patches_backup is not None:
             self.hook_patches = self.hook_patches_backup
-            self.hook_patches_backup = {}
+            self.hook_patches_backup = None
 
     def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
         self.hook_mode = hook_mode
@@ -953,25 +956,26 @@ class ModelPatcher:
         if reset_current_hooks:
             self.patch_hooks(None)
 
-    def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
+    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
+                                  registered: comfy.hooks.HookGroup = None):
         self.restore_hook_patches()
-        registered_hooks: list[comfy.hooks.Hook] = []
-        # handle WrapperHooks, if model_options provided
-        if model_options is not None:
-            for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
-                hook.add_hook_patches(self, model_options, target, registered_hooks)
+        if registered is None:
+            registered = comfy.hooks.HookGroup()
         # handle WeightHooks
         weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
-        for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
+        for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
             if hook.hook_ref not in self.hook_patches:
                 weight_hooks_to_register.append(hook)
+            else:
+                registered.add(hook)
         if len(weight_hooks_to_register) > 0:
+            # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
             self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
         for hook in weight_hooks_to_register:
-            hook.add_hook_patches(self, model_options, target, registered_hooks)
+            hook.add_hook_patches(self, model_options, target_dict, registered)
         for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
-            callback(self, hooks_dict, target)
+            callback(self, hooks, target_dict, model_options, registered)
+        return registered
 
     def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
         with self.use_ejected():
@@ -1022,11 +1026,11 @@ class ModelPatcher:
     def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
         # TODO: return transformer_options dict with any additions from hooks
         if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
-            return {}
+            return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
         self.patch_hooks(hooks=hooks)
         for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
             callback(self, hooks)
-        return {}
+        return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
 
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index ac973536..b70e5e63 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
             models += [c[model_type]]
     return models
 
-def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
+def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
     # get hooks from conds, and collect cnets so they can be checked for extra_hooks
     cnets: list[ControlBase] = []
     for c in cond:
         if 'hooks' in c:
             for hook in c['hooks'].hooks:
-                hook: comfy.hooks.Hook
-                with_type = hooks_dict.setdefault(hook.hook_type, {})
-                with_type[hook] = None
+                full_hooks.add(hook)
         if 'control' in c:
             cnets.append(c['control'])
 
@@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
         extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
         if extra_hooks is not None:
             for hook in extra_hooks.hooks:
-                with_type = hooks_dict.setdefault(hook.hook_type, {})
-                with_type[hook] = None
+                full_hooks.add(hook)
 
-    return hooks_dict
+    return full_hooks
 
 def convert_cond(cond):
     out = []
@@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
     cnets: list[ControlBase] = []
     gligen = []
     add_models = []
-    hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
 
     for k in conds:
         cnets += get_models_from_cond(conds[k], "control")
         gligen += get_models_from_cond(conds[k], "gligen")
         add_models += get_models_from_cond(conds[k], "additional_models")
-        get_hooks_from_cond(conds[k], hooks)
 
     control_nets = set(cnets)
 
@@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
         inference_memory += m.inference_memory_requirements(dtype)
 
     gligen = [x[1] for x in gligen]
-    hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
-    models = control_models + gligen + add_models + hook_models
+    models = control_models + gligen + add_models
 
     return models, inference_memory
 
+def get_additional_models_from_model_options(model_options: dict[str]=None):
+    """loads additional models from registered AdditionalModels hooks"""
+    models = []
+    if model_options is not None and "registered_hooks" in model_options:
+        registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
+        for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
+            hook: comfy.hooks.AdditionalModelsHook
+            models.extend(hook.models)
+    return models
+
 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
     for m in models:
@@ -102,9 +106,10 @@
         m.cleanup()
 
-def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
-    real_model: 'BaseModel' = None
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+    real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
+    models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
     memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
     minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@@ -123,12 +128,35 @@ def cleanup_models(conds, models):
     cleanup_additional_models(set(control_cleanup))
 
 def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
+    '''
+    Registers hooks from conds.
+    '''
     # check for hooks in conds - if not registered, see if can be applied
-    hooks = {}
+    hooks = comfy.hooks.HookGroup()
     for k in conds:
         get_hooks_from_cond(conds[k], hooks)
     # add wrappers and callbacks from ModelPatcher to transformer_options
     model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
     model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
-    # register hooks on model/model_options
-    model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
+    # begin registering hooks
+    registered = comfy.hooks.HookGroup()
+    target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
+    # handle all TransformerOptionsHooks
+    for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
+        hook: comfy.hooks.TransformerOptionsHook
+        hook.add_hook_patches(model, model_options, target_dict, registered)
+    # handle all AdditionalModelsHooks
+    for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
+        hook: comfy.hooks.AdditionalModelsHook
+        hook.add_hook_patches(model, model_options, target_dict, registered)
+    # handle all WeightHooks by registering on ModelPatcher
+    model.register_all_hook_patches(hooks, target_dict, model_options, registered)
+    # add registered_hooks onto model_options for further reference
+    if len(registered) > 0:
+        model_options["registered_hooks"] = registered
+    # merge original wrappers and callbacks with hooked wrappers and callbacks
+    to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
+    for wc_name in ["wrappers", "callbacks"]:
+        comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
+                                                   copy_dict1=False)
+    return to_load_options
diff --git a/comfy/samplers.py b/comfy/samplers.py
index af2b8e11..5cc33a7d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
     for cond in conds_to_modify:
         cond['hooks'] = hooks
 
+def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
+    '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
+    HookGroups that have the same reference.'''
+    registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
+    # if None were registered, make sure all hooks are cleaned from conds
+    if registered is None:
+        for k in conds:
+            for kk in conds[k]:
+                kk.pop('hooks', None)
+        return
+    # find conds that contain hooks to be replaced - group by common HookGroup refs
+    hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
+    for k in conds:
+        for kk in conds[k]:
+            hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
+            if hooks is not None:
+                if not hooks.is_subset_of(registered):
+                    to_replace = hook_replacement.setdefault(hooks, [])
+                    to_replace.append(kk)
+    # for each hook to replace, create a new proper HookGroup and assign to all common conds
+    for hooks, conds_to_modify in hook_replacement.items():
+        new_hooks = hooks.new_with_common_hooks(registered)
+        if len(new_hooks) == 0:
+            new_hooks = None
+        for kk in conds_to_modify:
+            kk['hooks'] = new_hooks
+
 def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
     hooks_set = set()
@@ -819,9 +846,58 @@
     return len(hooks_set)
 
+def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
+    '''
+    Calls .to() on any hook-provided patches, wrappers, or callbacks that support it, casting them to the given device and/or dtype.
+    '''
+    if model_options is None:
+        return
+    to_load_options = model_options.get("to_load_options", None)
+    if to_load_options is None:
+        return
+
+    casts = []
+    if device is not None:
+        casts.append(device)
+    if dtype is not None:
+        casts.append(dtype)
+    # if nothing to apply, do nothing
+    if len(casts) == 0:
+        return
+
+    # try to call .to on patches
+    if "patches" in to_load_options:
+        patches = to_load_options["patches"]
+        for name in patches:
+            patch_list = patches[name]
+            for i in range(len(patch_list)):
+                if hasattr(patch_list[i], "to"):
+                    for cast in casts:
+                        patch_list[i] = patch_list[i].to(cast)
+    if "patches_replace" in to_load_options:
+        patches = to_load_options["patches_replace"]
+        for name in patches:
+            patch_list = patches[name]
+            for k in patch_list:
+                if hasattr(patch_list[k], "to"):
+                    for cast in casts:
+                        patch_list[k] = patch_list[k].to(cast)
+    # try to call .to on any wrappers/callbacks
+    wrappers_and_callbacks = ["wrappers", "callbacks"]
+    for wc_name in wrappers_and_callbacks:
+        if wc_name in to_load_options:
+            wc: dict[str, list] = to_load_options[wc_name]
+            for wc_dict in wc.values():
+                for wc_list in wc_dict.values():
+                    for i in range(len(wc_list)):
+                        if hasattr(wc_list[i], "to"):
+                            for cast in casts:
+                                wc_list[i] = wc_list[i].to(cast)
+
+
 class CFGGuider:
-    def __init__(self, model_patcher):
-        self.model_patcher: 'ModelPatcher' = model_patcher
+    def __init__(self, model_patcher: ModelPatcher):
+        self.model_patcher = model_patcher
         self.model_options = model_patcher.model_options
         self.original_conds = {}
         self.cfg = 1.0
@@ -861,7 +937,7 @@ class CFGGuider:
         return self.inner_model.process_latent_out(samples.to(torch.float32))
 
     def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
+        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device
 
         if denoise_mask is not None:
@@ -870,6 +946,7 @@ class CFGGuider:
         noise = noise.to(device)
         latent_image = latent_image.to(device)
         sigmas = sigmas.to(device)
+        cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
 
         try:
             self.model_patcher.pre_run()
@@ -899,6 +976,7 @@ class CFGGuider:
             if get_total_hook_groups_in_conds(self.conds) <= 1:
                 self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
             comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
+            filter_registered_hooks_on_conds(self.conds, self.model_options)
             executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
                 self.outer_sample,
                 self,
             )
             output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
         finally:
+            cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
             self.model_options = orig_model_options
             self.model_patcher.hook_mode = orig_hook_mode
             self.model_patcher.restore_hook_patches()
diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py
index 9d9d4837..1edc06f3 100644
--- a/comfy_extras/nodes_hooks.py
+++ b/comfy_extras/nodes_hooks.py
@@ -246,7 +246,7 @@ class SetClipHooks:
     CATEGORY = "advanced/hooks/clip"
     FUNCTION = "apply_hooks"
 
-    def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
+    def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
         if hooks is not None:
             clip = clip.clone()
             if apply_to_conds:
@@ -255,7 +255,7 @@ class SetClipHooks:
             clip.use_clip_schedule = schedule_clip
             if not clip.use_clip_schedule:
                 clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
-            clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
+            clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
         return (clip,)
 
 class ConditioningTimestepsRange:
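
# #######################################################################################################
# Usage sketch (editorial note, not part of the patch): how the registration pieces introduced above
# fit together. `clip` and `lora` are placeholder inputs; the calls are the ones added in this diff.
#
#     import comfy.hooks
#
#     # build a WeightHook group from an already-loaded LoRA state dict
#     hook_group = comfy.hooks.create_hook_lora(lora, strength_model=1.0, strength_clip=1.0)
#
#     # register the group's weight patches, e.g. on a CLIP patcher, via the new target_dict param
#     target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)
#     clip.patcher.register_all_hook_patches(hook_group, target_dict)
# #######################################################################################################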