ModelPatcher Overhaul and Hook Support (#5583)

* Added hook_patches to ModelPatcher for weights (model)

* Initial changes to calc_cond_batch to eventually support hook_patches

* Added current_patcher property to BaseModel

* Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature

* Added call to initialize_timesteps on hooks in process_conds func, and added call prepare current keyframe on hooks in calc_cond_batch

* Added default_conds support in calc_cond_batch func

* Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring

* Made CLIP work with hook patches

* Added initial hook scheduling nodes, small renaming/refactoring

* Fixed MaxSpeed and default conds implementations

* Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time

* Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node

* Initial work on adding 'model_as_lora' lora type to calculate_weight

* Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models

* Fix incorrect ref to create_hook_patches_clone after moving function

* Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks

* Added wrappers to ModelPatcher to facilitate standardized function wrapping

* Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type

* Fix skip_until_exit logic bug breaking injection after first run of model

* Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load

* Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers

* Refactored callbacks+wrappers to allow storing lists by id

* Added forward_timestep_embed_patch type, added helper functions on ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks

* Added get_attachment func on ModelPatcher

* Implement basic MemoryCounter system for determing with cached weights due to hooks should be offloaded in hooks_backup

* Modified ControlNet/T2IAdapter get_control function to receive transformer_options as additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of same ref

* Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations

* Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE)

* Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo)

* Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending)

* Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling

* Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances

* Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks

* Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, rename old ones as 'direct' and will be removed prior to merge

* Added initial support within CLIP Text Encode (Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes

* Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges

* Optimized CLIP hook scheduling to treat same strength as same keyframe

* Less fragile memory management.

* Make encode_from_tokens_scheduled call cleaner, rollback change in model_patcher.py for hook_patches_backup dict

* Fix issue.

* Remove useless function.

* Prevent and detect some types of memory leaks.

* Run garbage collector when switching workflow if needed.

* Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py

* Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers

* Fix issue.

* Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming

* Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options

* Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired

* Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality

* Fixed existing weight hook_patches (pre-registered) not working properly for CLIP

* Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample time registered hooks is less hacky

* Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, flux, hunyuan, sd3)

* Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher

* Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate

* Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True

* Fix cached_hook_patches not respecting target_device/memory_counter results

* Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches

* Remove unnecessary torch.no_grad calls for hook patches

* Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists

* For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes

* Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call

* Made encode_from_tokens_scheduled work when no hooks are set on patcher

* Small cleanup of comments

* Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case

* On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs

* Allow both FLOATS and FLOAT for floats_strength input

* Revert change, does not work

* Made patch_hook_weight_to_device respect set_func and convert_func

* Make discard_model_sampling True by default

* Add changes manually from 'master' so merge conflict resolution goes more smoothly

* Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call

* Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip

* Made nodes in nodes_hooks be marked as experimental (beta)

* Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references

* Made finalize_default_conds area math consistent with other sampling code

* Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks'

* Remove a couple old TODO's and a no longer necessary workaround
This commit is contained in:
Jedrzej Kosinski 2024-12-02 13:51:02 -06:00 committed by GitHub
parent 79d5ceae6e
commit 0ee322ec5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 2692 additions and 304 deletions

View File

@ -36,6 +36,10 @@ import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
def broadcast_image_to(tensor, target_batch_size, batched_number): def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0] current_batch_size = tensor.shape[0]
@ -78,6 +82,7 @@ class ControlBase:
self.concat_mask = False self.concat_mask = False
self.extra_concat_orig = [] self.extra_concat_orig = []
self.extra_concat = None self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
@ -116,6 +121,14 @@ class ControlBase:
out += self.previous_controlnet.get_models() out += self.previous_controlnet.get_models()
return out return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c): def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
@ -130,6 +143,7 @@ class ControlBase:
c.strength_type = self.strength_type c.strength_type = self.strength_type
c.concat_mask = self.concat_mask c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy() c.extra_concat_orig = self.extra_concat_orig.copy()
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
c.preprocess_image = self.preprocess_image c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
@ -200,10 +214,10 @@ class ControlNet(ControlBase):
self.concat_mask = concat_mask self.concat_mask = concat_mask
self.preprocess_image = preprocess_image self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None control_prev = None
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None: if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@ -758,10 +772,10 @@ class T2IAdapter(ControlBase):
height = math.ceil(height / unshuffle_amount) * unshuffle_amount height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height return width, height
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None control_prev = None
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None: if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:

690
comfy/hooks.py Normal file
View File

@ -0,0 +1,690 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
class EnumHookMode(enum.Enum):
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return True
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
weights = self.weights_clip
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
c._strength_model = self._strength_model
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
class HookGroup:
def __init__(self):
self.hooks: list[Hook] = []
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
for t_range, keyframe in range_kfs:
all_ranges.append(t_range)
# turn list of ranges into boundaries
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
boundaries_set.add(0.0)
boundaries = sorted(boundaries_set)
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
# with real ranges defined, give appropriate hooks w/ keyframes for each range
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
for t_range in real_ranges:
hooks_schedule = []
for hook, val in scheduled_hooks.items():
keyframe = None
# check if is a keyframe that works for the current t_range
for stored_range, stored_kf in val:
# if stored start is less than current end, then fits - give it assigned keyframe
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
keyframe = stored_kf
break
hooks_schedule.append((hook, keyframe))
scheduled_keyframes.append((t_range, hooks_schedule))
return scheduled_keyframes
def reset(self):
for hook in self.hooks:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
class HookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class HookKeyframeGroup:
def __init__(self):
self.keyframes: list[HookKeyframe] = []
self._current_keyframe: HookKeyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
# properties shadow those of HookWeightsKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_strength = eval_c.strength
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on
self._curr_t = curr_t
# return True if keyframe changed, False if no change
return prev_index != self._current_index and prev_strength != self._current_strength
class InterpolationMethod:
LINEAR = "linear"
EASE_IN = "ease_in"
EASE_OUT = "ease_out"
EASE_IN_OUT = "ease_in_out"
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
@classmethod
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
diff = num_to - num_from
if method == cls.LINEAR:
weights = torch.linspace(num_from, num_to, length)
elif method == cls.EASE_IN:
index = torch.linspace(0, 1, length)
weights = diff * np.power(index, 2) + num_from
elif method == cls.EASE_OUT:
index = torch.linspace(0, 1, length)
weights = diff * (1 - np.power(1 - index, 2)) + num_from
elif method == cls.EASE_IN_OUT:
index = torch.linspace(0, 1, length)
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
else:
raise ValueError(f"Unrecognized interpolation method '{method}'.")
if reverse:
weights = weights.flip(dims=(0,))
return weights
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
unique_attrs = {}
for o in objects:
val_attr = getattr(o, attr)
attr_list: list = unique_attrs.get(val_attr, list())
attr_list.append(o)
if val_attr not in unique_attrs:
unique_attrs[val_attr] = attr_list
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
hook.weights = lora
return hook_group
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
patches_model = None
patches_clip = None
if weights_model is not None:
patches_model = {}
for key in weights_model:
patches_model[key] = ("model_as_lora", (weights_model[key],))
if weights_clip is not None:
patches_clip = {}
for key in weights_clip:
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
hook.weights = patches_model
hook.weights_clip = patches_clip
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
if discard_model_sampling:
# do not include ANY model_sampling components of the model that should act as a patch
for key in list(patches_model.keys()):
if key.startswith("model_sampling"):
patches_model.pop(key, None)
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
hook_group = HookGroup()
hook = WeightHook()
hook_group.add(hook)
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
return cond
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
"end_percent": timestep_range[1]})
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
if mask is None:
return cond
set_area_to_bounds = False
if set_cond_area != 'default':
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(cond, {'mask': mask,
'set_area_to_bounds': set_area_to_bounds,
'mask_strength': strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
return combined_conds
def combine_with_new_conds(conds: list, new_conds: list):
combined_conds = []
for c, new_c in zip(conds, new_conds):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
# finally, apply mask to conditioning and store
final_conds.append(c)
return final_conds
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds

View File

@ -15,6 +15,7 @@ from .util import (
) )
from ..attention import SpatialTransformer, SpatialVideoTransformer, default from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists from comfy.ldm.util import exists
import comfy.patcher_extension
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample): elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape) x = layer(x, output_shape=output_shape)
else: else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x) x = layer(x)
return x return x
@ -819,6 +829,13 @@ class UNetModel(nn.Module):
) )
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.

View File

@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
} }
def load_lora(lora, to_load): def load_lora(lora, to_load, log_missing=True):
patch_dict = {} patch_dict = {}
loaded_keys = set() loaded_keys = set()
for x in to_load: for x in to_load:
@ -213,9 +213,10 @@ def load_lora(lora, to_load):
patch_dict[to_load[x]] = ("set", (set_weight,)) patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name) loaded_keys.add(set_weight_name)
for x in lora.keys(): if log_missing:
if x not in loaded_keys: for x in lora.keys():
logging.warning("lora key not loaded: {}".format(x)) if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return patch_dict return patch_dict
@ -429,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches: for p in patches:
strength = p[0] strength = p[0]
v = p[1] v = p[1]
@ -471,6 +472,11 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set": elif patch_type == "set":
weight.copy_(v[0]) weight.copy_(v[0])
elif patch_type == "model_as_lora":
target_weight: torch.Tensor = v[0]
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)

View File

@ -33,12 +33,16 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model import comfy.ldm.lightricks.model
import comfy.model_management import comfy.model_management
import comfy.patcher_extension
import comfy.conds import comfy.conds
import comfy.ops import comfy.ops
from enum import Enum from enum import Enum
from . import utils from . import utils
import comfy.latent_formats import comfy.latent_formats
import math import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class ModelType(Enum): class ModelType(Enum):
EPS = 1 EPS = 1
@ -95,6 +99,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device self.device = device
self.current_patcher: 'ModelPatcher' = None
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
@ -120,6 +125,13 @@ class BaseModel(torch.nn.Module):
self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor = model_config.memory_usage_factor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._apply_model,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t sigma = t
xc = self.model_sampling.calculate_input(sigma, x) xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None: if c_concat is not None:

View File

@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
from typing import Optional, Callable
import torch import torch
import copy import copy
import inspect import inspect
@ -28,6 +30,9 @@ import comfy.utils
import comfy.float import comfy.float
import comfy.model_management import comfy.model_management
import comfy.lora import comfy.lora
import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
def string_to_seed(data): def string_to_seed(data):
@ -76,6 +81,17 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options["disable_cfg1_optimization"] = True model_options["disable_cfg1_optimization"] = True
return model_options return model_options
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
return new_hook_patches
def wipe_lowvram_weight(m): def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"): if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights m.comfy_cast_weights = m.prev_comfy_cast_weights
@ -119,6 +135,49 @@ def get_key_weight(model, key):
return weight, set_func, convert_func return weight, set_func, convert_func
class AutoPatcherEjector:
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
self.model = model
self.was_injected = False
self.prev_skip_injection = False
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
if self.skip_and_inject_on_exit_only:
self.model.skip_injection = True
if self.model.is_injected:
self.model.eject_model()
self.was_injected = True
def __exit__(self, *args):
if self.skip_and_inject_on_exit_only:
self.model.skip_injection = self.prev_skip_injection
self.model.inject_model()
if self.was_injected and not self.model.skip_injection:
self.model.inject_model()
self.model.skip_injection = self.prev_skip_injection
class MemoryCounter:
def __init__(self, initial: int, minimum=0):
self.value = initial
self.minimum = minimum
# TODO: add a safe limit besides 0
def use(self, weight: torch.Tensor):
weight_size = weight.nelement() * weight.element_size()
if self.is_useable(weight_size):
self.decrement(weight_size)
return True
return False
def is_useable(self, used: int):
return self.value - used > self.minimum
def decrement(self, used: int):
self.value -= used
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size self.size = size
@ -141,6 +200,24 @@ class ModelPatcher:
self.patches_uuid = uuid.uuid4() self.patches_uuid = uuid.uuid4()
self.parent = None self.parent = None
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers()
self.is_injected = False
self.skip_injection = False
self.injections: dict[str, list[PatcherInjection]] = {}
self.hook_patches: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
if not hasattr(self.model, 'model_loaded_weight_memory'): if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
@ -177,6 +254,47 @@ class ModelPatcher:
n.backup = self.backup n.backup = self.backup
n.object_patches_backup = self.object_patches_backup n.object_patches_backup = self.object_patches_backup
n.parent = self n.parent = self
# attachments
n.attachments = {}
for k in self.attachments:
if hasattr(self.attachments[k], "on_model_patcher_clone"):
n.attachments[k] = self.attachments[k].on_model_patcher_clone()
else:
n.attachments[k] = self.attachments[k]
# additional models
for k, c in self.additional_models.items():
n.additional_models[k] = [x.clone() for x in c]
# callbacks
for k, c in self.callbacks.items():
n.callbacks[k] = {}
for k1, c1 in c.items():
n.callbacks[k][k1] = c1.copy()
# sample wrappers
for k, w in self.wrappers.items():
n.wrappers[k] = {}
for k1, w1 in w.items():
n.wrappers[k][k1] = w1.copy()
# injection
n.is_injected = self.is_injected
n.skip_injection = self.skip_injection
for k, i in self.injections.items():
n.injections[k] = i.copy()
# hooks
n.hook_patches = create_hook_patches_clone(self.hook_patches)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
for group in self.cached_hook_patches:
n.cached_hook_patches[group] = {}
for k in self.cached_hook_patches[group]:
n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k]
n.hook_backup = self.hook_backup
n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n return n
def is_clone(self, other): def is_clone(self, other):
@ -184,10 +302,29 @@ class ModelPatcher:
return True return True
return False return False
def clone_has_same_weights(self, clone): def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone): if not self.is_clone(clone):
return False return False
if self.current_hooks != clone.current_hooks:
return False
if self.forced_hooks != clone.forced_hooks:
return False
if self.hook_patches.keys() != clone.hook_patches.keys():
return False
if self.attachments.keys() != clone.attachments.keys():
return False
if self.additional_models.keys() != clone.additional_models.keys():
return False
for key in self.callbacks:
if len(self.callbacks[key]) != len(clone.callbacks[key]):
return False
for key in self.wrappers:
if len(self.wrappers[key]) != len(clone.wrappers[key]):
return False
if self.injections.keys() != clone.injections.keys():
return False
if len(self.patches) == 0 and len(clone.patches) == 0: if len(self.patches) == 0 and len(clone.patches) == 0:
return True return True
@ -256,6 +393,12 @@ class ModelPatcher:
def set_model_output_block_patch(self, patch): def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch") self.set_model_patch(patch, "output_block_patch")
def set_model_emb_patch(self, patch):
self.set_model_patch(patch, "emb_patch")
def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch")
def add_object_patch(self, name, obj): def add_object_patch(self, name, obj):
self.object_patches[name] = obj self.object_patches[name] = obj
@ -294,27 +437,28 @@ class ModelPatcher:
return self.model.get_dtype() return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set() with self.use_ejected():
model_sd = self.model.state_dict() p = set()
for k in patches: model_sd = self.model.state_dict()
offset = None for k in patches:
function = None offset = None
if isinstance(k, str): function = None
key = k if isinstance(k, str):
else: key = k
offset = k[1] else:
key = k[0] offset = k[1]
if len(k) > 2: key = k[0]
function = k[2] if len(k) > 2:
function = k[2]
if key in model_sd: if key in model_sd:
p.add(k) p.add(k)
current_patches = self.patches.get(key, []) current_patches = self.patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function)) current_patches.append((strength_patch, patches[k], strength_model, offset, function))
self.patches[key] = current_patches self.patches[key] = current_patches
self.patches_uuid = uuid.uuid4() self.patches_uuid = uuid.uuid4()
return list(p) return list(p)
def get_key_patches(self, filter_prefix=None): def get_key_patches(self, filter_prefix=None):
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
@ -324,9 +468,12 @@ class ModelPatcher:
if not k.startswith(filter_prefix): if not k.startswith(filter_prefix):
continue continue
bk = self.backup.get(k, None) bk = self.backup.get(k, None)
hbk = self.hook_backup.get(k, None)
weight, set_func, convert_func = get_key_weight(self.model, k) weight, set_func, convert_func = get_key_weight(self.model, k)
if bk is not None: if bk is not None:
weight = bk.weight weight = bk.weight
if hbk is not None:
weight = hbk[0]
if convert_func is None: if convert_func is None:
convert_func = lambda a, **kwargs: a convert_func = lambda a, **kwargs: a
@ -337,13 +484,14 @@ class ModelPatcher:
return p return p
def model_state_dict(self, filter_prefix=None): def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict() with self.use_ejected():
keys = list(sd.keys()) sd = self.model.state_dict()
if filter_prefix is not None: keys = list(sd.keys())
for k in keys: if filter_prefix is not None:
if not k.startswith(filter_prefix): for k in keys:
sd.pop(k) if not k.startswith(filter_prefix):
return sd sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False): def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches: if key not in self.patches:
@ -388,106 +536,117 @@ class ModelPatcher:
return loading return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0 with self.use_ejected():
patch_counter = 0 self.unpatch_hooks()
lowvram_counter = 0 mem_counter = 0
loading = self._load_list() patch_counter = 0
lowvram_counter = 0
loading = self._load_list()
load_completely = [] load_completely = []
loading.sort(reverse=True) loading.sort(reverse=True)
for x in loading: for x in loading:
n = x[1] n = x[1]
m = x[2] m = x[2]
params = x[3] params = x[3]
module_mem = x[0] module_mem = x[0]
lowvram_weight = False lowvram_weight = False
if not full_load and hasattr(m, "comfy_cast_weights"): if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory: if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True lowvram_weight = True
lowvram_counter += 1 lowvram_counter += 1
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if lowvram_weight:
if weight_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "comfy_cast_weights"):
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
params = x[3]
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue continue
weight_key = "{}.weight".format(n) for param in params:
bias_key = "{}.bias".format(n) self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
if lowvram_weight: logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
if weight_key in self.patches: m.comfy_patched_weights = True
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights for x in load_completely:
m.comfy_cast_weights = True x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else: else:
if hasattr(m, "comfy_cast_weights"): logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
if m.comfy_cast_weights: self.model.model_lowvram = False
wipe_lowvram_weight(m) if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
if full_load or mem_counter + module_mem < lowvram_model_memory: self.model.lowvram_patch_counter += patch_counter
mem_counter += module_mem self.model.device = device_to
load_completely.append((module_mem, n, m, params)) self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
load_completely.sort(reverse=True) for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
for x in load_completely: callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
n = x[1]
m = x[2]
params = x[3]
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
for param in params: self.apply_hooks(self.forced_hooks, force_apply=True)
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches: with self.use_ejected():
old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) for k in self.object_patches:
if k not in self.object_patches_backup: old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
self.object_patches_backup[k] = old if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if lowvram_model_memory == 0: if lowvram_model_memory == 0:
full_load = True full_load = True
else: else:
full_load = False full_load = False
if load_weights: if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
self.inject_model()
return self.model return self.model
def unpatch_model(self, device_to=None, unpatch_weights=True): def unpatch_model(self, device_to=None, unpatch_weights=True):
self.eject_model()
if unpatch_weights: if unpatch_weights:
self.unpatch_hooks()
if self.model.model_lowvram: if self.model.model_lowvram:
for m in self.model.modules(): for m in self.model.modules():
wipe_lowvram_weight(m) wipe_lowvram_weight(m)
@ -523,85 +682,91 @@ class ModelPatcher:
self.object_patches_backup.clear() self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0): def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0 with self.use_ejected():
patch_counter = 0 memory_freed = 0
unload_list = self._load_list() patch_counter = 0
unload_list.sort() unload_list = self._load_list()
for unload in unload_list: unload_list.sort()
if memory_to_free < memory_freed: for unload in unload_list:
break if memory_to_free < memory_freed:
module_mem = unload[0] break
n = unload[1] module_mem = unload[0]
m = unload[2] n = unload[1]
params = unload[3] m = unload[2]
params = unload[3]
lowvram_possible = hasattr(m, "comfy_cast_weights") lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
move_weight = True move_weight = True
for param in params: for param in params:
key = "{}.{}".format(n, param) key = "{}.{}".format(n, param)
bk = self.backup.get(key, None) bk = self.backup.get(key, None)
if bk is not None: if bk is not None:
if not lowvram_possible: if not lowvram_possible:
move_weight = False move_weight = False
break break
if bk.inplace_update: if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight) comfy.utils.copy_to_param(self.model, key, bk.weight)
else: else:
comfy.utils.set_attr_param(self.model, key, bk.weight) comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key) self.backup.pop(key)
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n) bias_key = "{}.bias".format(n)
if move_weight: if move_weight:
m.to(device_to) m.to(device_to)
if lowvram_possible: if lowvram_possible:
if weight_key in self.patches: if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches) m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1 patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches) m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1 patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.comfy_patched_weights = False m.comfy_patched_weights = False
memory_freed += module_mem memory_freed += module_mem
logging.debug("freed {}".format(n)) logging.debug("freed {}".format(n))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed self.model.model_loaded_weight_memory -= memory_freed
return memory_freed return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) with self.use_ejected(skip_and_inject_on_exit_only=True):
# TODO: force_patch_weights should not unload + reload full model unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
used = self.model.model_loaded_weight_memory # TODO: force_patch_weights should not unload + reload full model
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) used = self.model.model_loaded_weight_memory
if unpatch_weights: self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
extra_memory += (used - self.model.model_loaded_weight_memory) if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False) self.patch_model(load_weights=False)
full_load = False full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
return 0 self.apply_hooks(self.forced_hooks, force_apply=True)
if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): return 0
full_load = True if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
current_used = self.model.model_loaded_weight_memory full_load = True
try: current_used = self.model.model_loaded_weight_memory
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) try:
except Exception as e: self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
self.detach() except Exception as e:
raise e self.detach()
raise e
return self.model.model_loaded_weight_memory - current_used return self.model.model_loaded_weight_memory - current_used
def detach(self, unpatch_all=True): def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device) self.model_patches_to(self.offload_device)
if unpatch_all: if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model return self.model
def current_loaded_device(self): def current_loaded_device(self):
@ -611,6 +776,345 @@ class ModelPatcher:
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self):
self.clean_hooks()
if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None
for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
callback(self)
def add_callback(self, call_type: str, callback: Callable):
self.add_callback_with_key(call_type, None, callback)
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def remove_callbacks_with_key(self, call_type: str, key: str):
c = self.callbacks.get(call_type, {})
if key in c:
c.pop(key)
def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])
def get_all_callbacks(self, call_type: str):
c_list = []
for c in self.callbacks.get(call_type, {}).values():
c_list.extend(c)
return c_list
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
self.add_wrapper_with_key(wrapper_type, None, wrapper)
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
w.pop(key)
def get_wrappers(self, wrapper_type: str, key: str):
return self.wrappers.get(wrapper_type, {}).get(key, [])
def get_all_wrappers(self, wrapper_type: str):
w_list = []
for w in self.wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list
def set_attachments(self, key: str, attachment):
self.attachments[key] = attachment
def remove_attachments(self, key: str):
if key in self.attachments:
self.attachments.pop(key)
def get_attachment(self, key: str):
return self.attachments.get(key, None)
def set_injections(self, key: str, injections: list[PatcherInjection]):
self.injections[key] = injections
def remove_injections(self, key: str):
if key in self.injections:
self.injections.pop(key)
def set_additional_models(self, key: str, models: list['ModelPatcher']):
self.additional_models[key] = models
def remove_additional_models(self, key: str):
if key in self.additional_models:
self.additional_models.pop(key)
def get_additional_models_with_key(self, key: str):
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
def get_nested_additional_models(self):
def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
'''Make sure circular references do not cause infinite recursion.'''
next_models = []
for model in prev_models:
candidates = model.get_additional_models()
for c in candidates:
if c not in cache_set:
next_models.append(c)
cache_set.add(c)
if len(next_models) == 0:
return prev_models
return prev_models + _evaluate_sub_additional_models(next_models, cache_set)
all_models = self.get_additional_models()
models_set = set(all_models)
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
return real_all_models
def use_ejected(self, skip_and_inject_on_exit_only=False):
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
def inject_model(self):
if self.is_injected or self.skip_injection:
return
for injections in self.injections.values():
for inj in injections:
inj.inject(self)
self.is_injected = True
if self.is_injected:
for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
callback(self)
def eject_model(self):
if not self.is_injected:
return
for injections in self.injections.values():
for inj in injections:
inj.eject(self)
self.is_injected = False
for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
callback(self)
def pre_run(self):
if hasattr(self.model, "current_patcher"):
self.model.current_patcher = self
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
def restore_hook_patches(self):
if len(self.hook_patches_backup) > 0:
self.hook_patches = self.hook_patches_backup
self.hook_patches_backup = {}
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
curr_t = t[0]
reset_current_hooks = False
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
hook.add_hook_patches(self, model_options, target, registered_hooks)
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
if len(weight_hooks_to_register) > 0:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target, registered_hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():
# NOTE: this mirrors behavior of add_patches func
current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {})
p = set()
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hook_patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
current_hook_patches[key] = current_patches
self.hook_patches[hook.hook_ref] = current_hook_patches
# since should care about these patches too to determine if same model, reroll patches_uuid
self.patches_uuid = uuid.uuid4()
return list(p)
def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
# combined_patches will contain weights of all relevant hooks, per key
combined_patches = {}
if hooks is not None:
for hook in hooks.hooks:
hook_patches: dict = self.hook_patches.get(hook.hook_ref, {})
for key in hook_patches.keys():
current_patches: list[tuple] = combined_patches.get(key, [])
if math.isclose(hook.strength, 1.0):
current_patches.extend(hook_patches[key])
else:
# patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model)
for patch in hook_patches[key]:
new_patch = list(patch)
new_patch[0] *= hook.strength
current_patches.append(tuple(new_patch))
combined_patches[key] = current_patches
return combined_patches
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)
return {}
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
self.unpatch_hooks()
if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
minimum=comfy.model_management.minimum_inference_memory()*2)
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
for key in cached_weights:
if key not in model_sd_keys:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
else:
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd_keys:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
if key not in self.hook_backup:
weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
target_device = self.offload_device
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
used = memory_counter.use(weight)
if used:
target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
def clear_cached_hook_weights(self):
self.cached_hook_patches.clear()
self.patch_hooks(None)
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
weight, set_func, convert_func = get_key_weight(self.model, key)
weight: torch.Tensor
if key not in self.hook_backup:
target_device = self.offload_device
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
used = memory_counter.use(weight)
if used:
target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
# TODO: properly handle LowVramPatch, if it ends up an issue
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(combined_patches[key],
temp_weight,
key, original_weights=original_weights)
del original_weights[key]
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so
target_device = self.offload_device
used = memory_counter.use(weight)
if used:
target_device = weight.device
self.cached_hook_patches.setdefault(hooks, {})
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
del temp_weight
del out_weight
del weight
def unpatch_hooks(self) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
self.current_hooks = None
return
keys = list(self.hook_backup.keys())
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None
def clean_hooks(self):
self.unpatch_hooks()
self.clear_cached_hook_weights()
def __del__(self): def __del__(self):
self.detach(unpatch_all=False) self.detach(unpatch_all=False)

156
comfy/patcher_extension.py Normal file
View File

@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
ON_PRE_RUN = "on_pre_run"
ON_PREPARE_STATE = "on_prepare_state"
ON_APPLY_HOOKS = "on_apply_hooks"
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
ON_INJECT_MODEL = "on_inject_model"
ON_EJECT_MODEL = "on_eject_model"
# callbacks dict is in the format:
# {"call_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
c_list.extend(callbacks.get(call_type, {}).get(key, []))
return c_list
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
for c in callbacks.get(call_type, {}).values():
c_list.extend(c)
return c_list
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
# wrappers dict is in the format:
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
return w_list
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
for w in wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list
class WrapperExecutor:
"""Handles call stack of wrappers around a function in an ordered manner."""
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
# NOTE: class_obj exists so that wrappers surrounding a class method can access
# the class instance at runtime via executor.class_obj
self.original = original
self.class_obj = class_obj
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
kwargs = dict(kwargs)
if self.is_last:
return self.original(*args, **kwargs)
return self.wrappers[self.idx](self, *args, **kwargs)
def _create_next_executor(self) -> 'WrapperExecutor':
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
if self.class_obj is None:
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)
class PatcherInjection:
def __init__(self, inject: Callable, eject: Callable):
self.inject = inject
self.eject = eject
def copy_nested_dicts(input_dict: dict):
new_dict = input_dict.copy()
for key, value in input_dict.items():
if isinstance(value, dict):
new_dict[key] = copy_nested_dicts(value)
elif isinstance(value, list):
new_dict[key] = value.copy()
return new_dict
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
if copy_dict1:
merged_dict = copy_nested_dicts(dict1)
else:
merged_dict = dict1
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(value, curr_value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:
merged_dict[key] = value
return merged_dict

View File

@ -1,7 +1,16 @@
from __future__ import annotations
import uuid
import torch import torch
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.utils import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device): def prepare_mask(noise_mask, shape, device):
return comfy.utils.reshape_mask(noise_mask, shape).to(device) return comfy.utils.reshape_mask(noise_mask, shape).to(device)
@ -10,9 +19,43 @@ def get_models_from_cond(cond, model_type):
models = [] models = []
for c in cond: for c in cond:
if model_type in c: if model_type in c:
models += [c[model_type]] if isinstance(c[model_type], list):
models += c[model_type]
else:
models += [c[model_type]]
return models return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
if 'control' in c:
cnets.append(c['control'])
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
if cnet.extra_hooks is not None:
_list.append(cnet.extra_hooks)
if cnet.previous_controlnet is None:
return _list
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
hooks_list = []
cnets = set(cnets)
for base_cnet in cnets:
get_extra_hooks_from_cnet(base_cnet, hooks_list)
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
return hooks_dict
def convert_cond(cond): def convert_cond(cond):
out = [] out = []
for c in cond: for c in cond:
@ -22,17 +65,22 @@ def convert_cond(cond):
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0] temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()
out.append(temp) out.append(temp)
return out return out
def get_additional_models(conds, dtype): def get_additional_models(conds, dtype):
"""loads additional models in conditioning""" """loads additional models in conditioning"""
cnets = [] cnets: list[ControlBase] = []
gligen = [] gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds: for k in conds:
cnets += get_models_from_cond(conds[k], "control") cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen") gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets) control_nets = set(cnets)
@ -43,7 +91,9 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype) inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen] gligen = [x[1] for x in gligen]
models = control_models + gligen hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models
return models, inference_memory return models, inference_memory
def cleanup_additional_models(models): def cleanup_additional_models(models):
@ -53,10 +103,11 @@ def cleanup_additional_models(models):
m.cleanup() m.cleanup()
def prepare_sampling(model, noise_shape, conds): def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
device = model.load_device device = model.load_device
real_model = None real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
@ -72,3 +123,14 @@ def cleanup_models(conds, models):
control_cleanup += get_models_from_cond(conds[k], "control") control_cleanup += get_models_from_cond(conds[k], "control")
cleanup_additional_models(set(control_cleanup)) cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)

View File

@ -1,11 +1,21 @@
from __future__ import annotations
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
import torch import torch
import collections import collections
from comfy import model_management from comfy import model_management
import math import math
import logging import logging
import comfy.samplers
import comfy.sampler_helpers import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import scipy.stats import scipy.stats
import numpy import numpy
@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
for c in model_conds: for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None) control = conds.get('control', None)
patches = None patches = None
@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
patches['middle_patch'] = [gligen_patch] patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
return cond_obj(input_x, mult, conditioning, area, control, patches) return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
def cond_equal_size(c1, c2): def cond_equal_size(c1, c2):
if c1 is c2: if c1 is c2:
@ -138,110 +149,184 @@ def cond_cat(c_list):
return out return out
def calc_cond_batch(model, conds, x_in, timestep, model_options): def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
# need to figure out remaining unmasked area for conds
default_mults = []
for _ in default_conds:
default_mults.append(torch.ones_like(x_in))
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
for lora_hooks, to_run in hooked_to_run.items():
for cond_obj, i in to_run:
# if no default_cond for cond_type, do nothing
if len(default_conds[i]) == 0:
continue
area: list[int] = cond_obj.area
if area is not None:
curr_default_mult: torch.Tensor = default_mults[i]
dims = len(area) // 2
for i in range(dims):
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
curr_default_mult -= cond_obj.mult
else:
default_mults[i] -= cond_obj.mult
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
for i, mult in enumerate(default_mults):
# if no default_cond for cond type, do nothing
if len(default_conds[i]) == 0:
continue
torch.nn.functional.relu(mult, inplace=True)
# if mult is all zeros, then don't add default_cond
if torch.max(mult) == 0.0:
continue
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
p = p._replace(mult=mult)
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = [] out_conds = []
out_counts = [] out_counts = []
to_run = [] # separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
for i in range(len(conds)): for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in)) out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37) out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i] cond = conds[i]
default_c = []
if cond is not None: if cond is not None:
for x in cond: for x in cond:
p = get_area_and_mult(x, x_in, timestep) if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None: if p is None:
continue continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
to_run += [(p, i)] if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
while len(to_run) > 0: model.current_patcher.prepare_state(timestep)
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse() # run every hooked_to_run separately
to_batch = to_batch_temp[:1] for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
free_memory = model_management.get_free_memory(x_in.device) to_batch_temp.reverse()
for i in range(1, len(to_batch_temp) + 1): to_batch = to_batch_temp[:1]
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
input_x = [] free_memory = model_management.get_free_memory(x_in.device)
mult = [] for i in range(1, len(to_batch_temp) + 1):
c = [] batch_amount = to_batch_temp[:len(to_batch_temp)//i]
cond_or_uncond = [] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
area = [] if model.memory_required(input_shape) * 1.5 < free_memory:
control = None to_batch = batch_amount
patches = None break
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond) input_x = []
input_x = torch.cat(input_x) mult = []
c = cond_cat(c) c = []
timestep_ = torch.cat([timestep] * batch_chunks) cond_or_uncond = []
uuids = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
if control is not None: batch_chunks = len(cond_or_uncond)
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = {} transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options: if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy() transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None: if patches is not None:
if "patches" in transformer_options: # TODO: replace with merge_nested_dicts function
cur_patches = transformer_options["patches"].copy() if "patches" in transformer_options:
for p in patches: cur_patches = transformer_options["patches"].copy()
if p in cur_patches: for p in patches:
cur_patches[p] = cur_patches[p] + patches[p] if p in cur_patches:
else: cur_patches[p] = cur_patches[p] + patches[p]
cur_patches[p] = patches[p] else:
transformer_options["patches"] = cur_patches cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else: else:
transformer_options["patches"] = patches output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
transformer_options["cond_or_uncond"] = cond_or_uncond[:] for o in range(batch_chunks):
transformer_options["sigmas"] = timestep cond_index = cond_or_uncond[o]
a = area[o]
c['transformer_options'] = transformer_options if a is None:
out_conds[cond_index] += output[o] * mult[o]
if 'model_function_wrapper' in model_options: out_counts[cond_index] += mult[o]
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else:
else: out_c = out_conds[cond_index]
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) out_cts = out_counts[cond_index]
dims = len(a) // 2
for o in range(batch_chunks): for i in range(dims):
cond_index = cond_or_uncond[o] out_c = out_c.narrow(i + 2, a[i + dims], a[i])
a = area[o] out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
if a is None: out_c += output[o] * mult[o]
out_conds[cond_index] += output[o] * mult[o] out_cts += mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)): for i in range(len(out_conds)):
out_conds[i] /= out_counts[i] out_conds[i] /= out_counts[i]
@ -500,10 +585,15 @@ def calculate_start_end_timesteps(model, conds):
timestep_start = None timestep_start = None
timestep_end = None timestep_end = None
if 'start_percent' in x: # handle clip hook schedule, if needed
timestep_start = s.percent_to_sigma(x['start_percent']) if 'clip_start_percent' in x:
if 'end_percent' in x: timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
timestep_end = s.percent_to_sigma(x['end_percent']) timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
else:
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None): if (timestep_start is not None) or (timestep_end is not None):
n = x.copy() n = x.copy()
@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if k != kk: if k != kk:
create_cond_with_same_area_if_none(conds[kk], c) create_cond_with_same_area_if_none(conds[kk], c)
for k in conds:
for c in conds[k]:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook.initialize_timesteps(model)
for k in conds: for k in conds:
pre_run_control(model, conds[k]) pre_run_control(model, conds[k])
@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return conds return conds
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
# determine which ControlNets have extra_hooks that should be combined with normal hooks
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
for k in conds:
for kk in conds[k]:
if 'control' in kk:
control: 'ControlBase' = kk['control']
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
to_replace = hook_replacement.setdefault((control, hooks), [])
to_replace.append(kk)
# if nothing to replace, do nothing
if len(hook_replacement) == 0:
return
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
# on the cond dicts
for key, conds_to_modify in hook_replacement.items():
control = key[0]
hooks = key[1]
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
# if combined hooks are not None, set as new hooks for all relevant conds
if hooks is not None:
for cond in conds_to_modify:
cond['hooks'] = hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
for k in conds:
for kk in conds[k]:
hooks_set.add(kk.get('hooks', None))
return len(hooks_set)
class CFGGuider: class CFGGuider:
def __init__(self, model_patcher): def __init__(self, model_patcher):
self.model_patcher = model_patcher self.model_patcher: 'ModelPatcher' = model_patcher
self.model_options = model_patcher.model_options self.model_options = model_patcher.model_options
self.original_conds = {} self.original_conds = {}
self.cfg = 1.0 self.cfg = 1.0
@ -714,19 +847,17 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": self.model_options, "seed":seed} extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
sampler.sample,
sampler,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
)
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32)) return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device device = self.model_patcher.load_device
@ -737,14 +868,48 @@ class CFGGuider:
latent_image = latent_image.to(device) latent_image = latent_image.to(device)
sigmas = sigmas.to(device) sigmas = sigmas.to(device)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model del self.inner_model
del self.conds
del self.loaded_models del self.loaded_models
return output return output
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
preprocess_conds_hooks(self.conds)
try:
orig_model_options = self.model_options
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
orig_hook_mode = self.model_patcher.hook_mode
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
del self.conds
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
cfg_guider = CFGGuider(model) cfg_guider = CFGGuider(model)

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import torch import torch
from enum import Enum from enum import Enum
import logging import logging
from comfy import model_management from comfy import model_management
from comfy.utils import ProgressBar
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.cascade.stage_c_coder import StageC_coder
@ -33,6 +35,7 @@ import comfy.text_encoders.lt
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
import comfy.lora_convert import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.taesd.taesd import comfy.taesd.taesd
@ -98,9 +101,13 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if params['device'] == load_device: if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True) model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None self.layer_idx = None
self.use_clip_schedule = False
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
def clone(self): def clone(self):
@ -109,6 +116,8 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx n.layer_idx = self.layer_idx
n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
@ -120,6 +129,69 @@ class CLIP:
def tokenize(self, text, return_word_ids=False): def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids) return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
pooled_dict["hooks"] = self.apply_hooks_to_conds
return pooled_dict
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
if all_hooks is None or not self.use_clip_schedule:
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
return_pooled = "unprojected" if unprojected else True
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
cond = pooled_dict.pop("cond")
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
all_cond_pooled.append([cond, pooled_dict])
else:
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
self.cond_stage_model.reset_clip_options() self.cond_stage_model.reset_clip_options()
@ -137,6 +209,7 @@ class CLIP:
if len(o) > 2: if len(o) > 2:
for k in o[2]: for k in o[2]:
out[k] = o[2][k] out[k] = o[2][k]
self.add_hooks_to_dict(out)
return out return out
if return_pooled: if return_pooled:

View File

@ -17,8 +17,7 @@ class CLIPTextEncodeSDXLRefiner:
def encode(self, clip, ascore, width, height, text): def encode(self, clip, ascore, width, height, text):
tokens = clip.tokenize(text) tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
class CLIPTextEncodeSDXL: class CLIPTextEncodeSDXL:
@classmethod @classmethod
@ -47,8 +46,7 @@ class CLIPTextEncodeSDXL:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,

View File

@ -18,10 +18,7 @@ class CLIPTextEncodeFlux:
tokens = clip.tokenize(clip_l) tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
cond = output.pop("cond")
output["guidance"] = guidance
return ([[cond, output]], )
class FluxGuidance: class FluxGuidance:
@classmethod @classmethod

697
comfy_extras/nodes_hooks.py Normal file
View File

@ -0,0 +1,697 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
import torch
from collections.abc import Iterable
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.sd import CLIP
import comfy.hooks
import comfy.sd
import comfy.utils
import folder_paths
###########################################
# Mask, Combine, and Hook Conditioning
#------------------------------------------
class PairConditioningSetProperties:
NodeId = 'PairConditioningSetProperties'
NodeName = 'Cond Pair Set Props'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive_NEW": ("CONDITIONING", ),
"negative_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "set_properties"
def set_properties(self, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class PairConditioningSetPropertiesAndCombine:
NodeId = 'PairConditioningSetPropertiesAndCombine'
NodeName = 'Cond Pair Set Props Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"positive_NEW": ("CONDITIONING", ),
"negative_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "set_properties"
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class ConditioningSetProperties:
NodeId = 'ConditioningSetProperties'
NodeName = 'Cond Set Props'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cond_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_properties"
def set_properties(self, cond_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
(final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_cond,)
class ConditioningSetPropertiesAndCombine:
NodeId = 'ConditioningSetPropertiesAndCombine'
NodeName = 'Cond Set Props Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cond": ("CONDITIONING", ),
"cond_NEW": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
},
"optional": {
"mask": ("MASK", ),
"hooks": ("HOOKS",),
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_properties"
def set_properties(self, cond, cond_NEW,
strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
(final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_cond,)
class PairConditioningCombine:
NodeId = 'PairConditioningCombine'
NodeName = 'Cond Pair Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive_A": ("CONDITIONING",),
"negative_A": ("CONDITIONING",),
"positive_B": ("CONDITIONING",),
"negative_B": ("CONDITIONING",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "combine"
def combine(self, positive_A, negative_A, positive_B, negative_B):
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
return (final_positive, final_negative,)
class PairConditioningSetDefaultAndCombine:
NodeId = 'PairConditioningSetDefaultCombine'
NodeName = 'Cond Pair Set Default Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"positive_DEFAULT": ("CONDITIONING",),
"negative_DEFAULT": ("CONDITIONING",),
},
"optional": {
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond pair"
FUNCTION = "set_default_and_combine"
def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
hooks: comfy.hooks.HookGroup=None):
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
hooks=hooks)
return (final_positive, final_negative)
class ConditioningSetDefaultAndCombine:
NodeId = 'ConditioningSetDefaultCombine'
NodeName = 'Cond Set Default Combine'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cond": ("CONDITIONING",),
"cond_DEFAULT": ("CONDITIONING",),
},
"optional": {
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_default_and_combine"
def set_default_and_combine(self, cond, cond_DEFAULT,
hooks: comfy.hooks.HookGroup=None):
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
hooks=hooks)
return (final_conditioning,)
class SetClipHooks:
NodeId = 'SetClipHooks'
NodeName = 'Set CLIP Hooks'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"clip": ("CLIP",),
"apply_to_conds": ("BOOLEAN", {"default": True}),
"schedule_clip": ("BOOLEAN", {"default": False})
},
"optional": {
"hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CLIP",)
CATEGORY = "advanced/hooks/clip"
FUNCTION = "apply_hooks"
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
if apply_to_conds:
clip.apply_hooks_to_conds = hooks
clip.patcher.forced_hooks = hooks.clone()
clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
return (clip,)
class ConditioningTimestepsRange:
NodeId = 'ConditioningTimestepsRange'
NodeName = 'Timesteps Range'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
CATEGORY = "advanced/hooks"
FUNCTION = "create_range"
def create_range(self, start_percent: float, end_percent: float):
return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
#------------------------------------------
###########################################
###########################################
# Create Hooks
#------------------------------------------
class CreateHookLora:
NodeId = 'CreateHookLora'
NodeName = 'Create Hook LoRA'
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook"
def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
if prev_hooks is None:
prev_hooks = comfy.hooks.HookGroup()
prev_hooks.clone()
if strength_model == 0 and strength_clip == 0:
return (prev_hooks,)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
return (prev_hooks.clone_and_combine(hooks),)
class CreateHookLoraModelOnly(CreateHookLora):
NodeId = 'CreateHookLoraModelOnly'
NodeName = 'Create Hook LoRA (MO)'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook_model_only"
def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
class CreateHookModelAsLora:
NodeId = 'CreateHookModelAsLora'
NodeName = 'Create Hook Model as LoRA'
def __init__(self):
# when not None, will be in following format:
# (ckpt_path: str, weights_model: dict, weights_clip: dict)
self.loaded_weights = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook"
def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
prev_hooks: comfy.hooks.HookGroup=None):
if prev_hooks is None:
prev_hooks = comfy.hooks.HookGroup()
prev_hooks.clone()
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
weights_model = None
weights_clip = None
if self.loaded_weights is not None:
if self.loaded_weights[0] == ckpt_path:
weights_model = self.loaded_weights[1]
weights_clip = self.loaded_weights[2]
else:
temp = self.loaded_weights
self.loaded_weights = None
del temp
if weights_model is None:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
self.loaded_weights = (ckpt_path, weights_model, weights_clip)
hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
strength_model=strength_model, strength_clip=strength_clip)
return (prev_hooks.clone_and_combine(hooks),)
class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
NodeId = 'CreateHookModelAsLoraModelOnly'
NodeName = 'Create Hook Model as LoRA (MO)'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
},
"optional": {
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
FUNCTION = "create_hook_model_only"
def create_hook_model_only(self, ckpt_name: str, strength_model: float,
prev_hooks: comfy.hooks.HookGroup=None):
return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
#------------------------------------------
###########################################
###########################################
# Schedule Hooks
#------------------------------------------
class SetHookKeyframes:
NodeId = 'SetHookKeyframes'
NodeName = 'Set Hook Keyframes'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"hooks": ("HOOKS",),
},
"optional": {
"hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "set_hook_keyframes"
def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
if hook_kf is not None:
hooks = hooks.clone()
hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
return (hooks,)
class CreateHookKeyframe:
NodeId = 'CreateHookKeyframe'
NodeName = 'Create Hook Keyframe'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "create_hook_keyframe"
def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
if prev_hook_kf is None:
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
prev_hook_kf = prev_hook_kf.clone()
keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
prev_hook_kf.add(keyframe)
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
NodeId = 'CreateHookKeyframesFromFloats'
NodeName = 'Create Hook Keyframes From Floats'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"print_keyframes": ("BOOLEAN", {"default": False}),
},
"optional": {
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
CATEGORY = "advanced/hooks/scheduling"
FUNCTION = "create_hook_keyframes"
def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
start_percent: float, end_percent: float,
prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
if prev_hook_kf is None:
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
prev_hook_kf = prev_hook_kf.clone()
if type(floats_strength) in (float, int):
floats_strength = [float(floats_strength)]
elif isinstance(floats_strength, Iterable):
pass
else:
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
method=comfy.hooks.InterpolationMethod.LINEAR)
is_first = True
for percent, strength in zip(percents, floats_strength):
guarantee_steps = 0
if is_first:
guarantee_steps = 1
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
#------------------------------------------
###########################################
class SetModelHooksOnCond:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"conditioning": ("CONDITIONING",),
"hooks": ("HOOKS",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/manual"
FUNCTION = "attach_hook"
def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
###########################################
# Combine Hooks
#------------------------------------------
class CombineHooks:
NodeId = 'CombineHooks2'
NodeName = 'Combine Hooks [2]'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"hooks_A": ("HOOKS",),
"hooks_B": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
FUNCTION = "combine_hooks"
def combine_hooks(self,
hooks_A: comfy.hooks.HookGroup=None,
hooks_B: comfy.hooks.HookGroup=None):
candidates = [hooks_A, hooks_B]
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
class CombineHooksFour:
NodeId = 'CombineHooks4'
NodeName = 'Combine Hooks [4]'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"hooks_A": ("HOOKS",),
"hooks_B": ("HOOKS",),
"hooks_C": ("HOOKS",),
"hooks_D": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
FUNCTION = "combine_hooks"
def combine_hooks(self,
hooks_A: comfy.hooks.HookGroup=None,
hooks_B: comfy.hooks.HookGroup=None,
hooks_C: comfy.hooks.HookGroup=None,
hooks_D: comfy.hooks.HookGroup=None):
candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
class CombineHooksEight:
NodeId = 'CombineHooks8'
NodeName = 'Combine Hooks [8]'
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"hooks_A": ("HOOKS",),
"hooks_B": ("HOOKS",),
"hooks_C": ("HOOKS",),
"hooks_D": ("HOOKS",),
"hooks_E": ("HOOKS",),
"hooks_F": ("HOOKS",),
"hooks_G": ("HOOKS",),
"hooks_H": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
FUNCTION = "combine_hooks"
def combine_hooks(self,
hooks_A: comfy.hooks.HookGroup=None,
hooks_B: comfy.hooks.HookGroup=None,
hooks_C: comfy.hooks.HookGroup=None,
hooks_D: comfy.hooks.HookGroup=None,
hooks_E: comfy.hooks.HookGroup=None,
hooks_F: comfy.hooks.HookGroup=None,
hooks_G: comfy.hooks.HookGroup=None,
hooks_H: comfy.hooks.HookGroup=None):
candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
#------------------------------------------
###########################################
node_list = [
# Create
CreateHookLora,
CreateHookLoraModelOnly,
CreateHookModelAsLora,
CreateHookModelAsLoraModelOnly,
# Scheduling
SetHookKeyframes,
CreateHookKeyframe,
CreateHookKeyframesFromFloats,
# Combine
CombineHooks,
CombineHooksFour,
CombineHooksEight,
# Attach
ConditioningSetProperties,
ConditioningSetPropertiesAndCombine,
PairConditioningSetProperties,
PairConditioningSetPropertiesAndCombine,
ConditioningSetDefaultAndCombine,
PairConditioningSetDefaultAndCombine,
PairConditioningCombine,
SetClipHooks,
# Other
ConditioningTimestepsRange,
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName

View File

@ -15,9 +15,7 @@ class CLIPTextEncodeHunyuanDiT:
tokens = clip.tokenize(bert) tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) return (clip.encode_from_tokens_scheduled(tokens), )
cond = output.pop("cond")
return ([[cond, output]], )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

View File

@ -82,8 +82,7 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) return (clip.encode_from_tokens_scheduled(tokens), )
return ([[cond, {"pooled_output": pooled}]], )
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):

View File

@ -62,9 +62,8 @@ class CLIPTextEncode:
def encode(self, clip, text): def encode(self, clip, text):
tokens = clip.tokenize(text) tokens = clip.tokenize(text)
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) return (clip.encode_from_tokens_scheduled(tokens), )
cond = output.pop("cond")
return ([[cond, output]], )
class ConditioningCombine: class ConditioningCombine:
@classmethod @classmethod
@ -2149,6 +2148,7 @@ def init_builtin_extra_nodes():
"nodes_mochi.py", "nodes_mochi.py",
"nodes_slg.py", "nodes_slg.py",
"nodes_lt.py", "nodes_lt.py",
"nodes_hooks.py",
] ]
import_failed = [] import_failed = []