This commit is contained in:
Jedrzej Kosinski 2025-01-09 07:12:35 -05:00 committed by GitHub
commit c3e8262e5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 385 additions and 193 deletions

View File

@ -16,91 +16,132 @@ import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
from node_helpers import conditioning_set_values from node_helpers import conditioning_set_values
# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################
class EnumHookMode(enum.Enum): class EnumHookMode(enum.Enum):
'''
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
MinVram: No caching will occur for any operations related to hooks.
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
'''
MinVram = "minvram" MinVram = "minvram"
MaxSpeed = "maxspeed" MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum): class EnumHookType(enum.Enum):
'''
Hook types, each of which has different expected behavior.
'''
Weight = "weight" Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch" ObjectPatch = "object_patch"
AddModels = "add_models" AdditionalModels = "add_models"
Callbacks = "callbacks" TransformerOptions = "transformer_options"
Wrappers = "wrappers" Injections = "add_injections"
SetInjections = "add_injections"
class EnumWeightTarget(enum.Enum): class EnumWeightTarget(enum.Enum):
Model = "model" Model = "model"
Clip = "clip" Clip = "clip"
class EnumHookScope(enum.Enum):
'''
Determines if hook should be limited in its influence over sampling.
AllConditioning: hook will affect all conds used in sampling.
HookedOnly: hook will only affect the conds it was attached to.
'''
AllConditioning = "all_conditioning"
HookedOnly = "hooked_only"
class _HookRef: class _HookRef:
pass pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how custom_should_register function can look like.'''
return True return True
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
'''Creates base dictionary for use with Hooks' target param.'''
d = {}
if target is not None:
d['target'] = target
d.update(kwargs)
return d
class Hook: class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None): hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type self.hook_type = hook_type
'''Enum identifying the general class of this hook.'''
self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_ref = hook_ref if hook_ref else _HookRef()
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
self.hook_id = hook_id self.hook_id = hook_id
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property @property
def strength(self): def strength(self):
return self.hook_keyframe.strength return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'): def initialize_timesteps(self, model: BaseModel):
self.reset() self.reset()
self.hook_keyframe.initialize_timesteps(model) self.hook_keyframe.initialize_timesteps(model)
def reset(self): def reset(self):
self.hook_keyframe.reset() self.hook_keyframe.reset()
def clone(self, subtype: Callable=None): def clone(self):
if subtype is None: c: Hook = self.__class__()
subtype = type(self)
c: Hook = subtype()
c.hook_type = self.hook_type c.hook_type = self.hook_type
c.hook_ref = self.hook_ref c.hook_ref = self.hook_ref
c.hook_id = self.hook_id c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target, registered) return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): def __eq__(self, other: Hook):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self): def __hash__(self):
return hash(self.hook_ref) return hash(self.hook_ref)
class WeightHook(Hook): class WeightHook(Hook):
'''
Hook responsible for tracking weights to be applied to some model/clip.
Note, value of hook_scope is ignored and is treated as HookedOnly.
'''
def __init__(self, strength_model=1.0, strength_clip=1.0): def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight) super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
self.weights: dict = None self.weights: dict = None
self.weights_clip: dict = None self.weights_clip: dict = None
self.need_weight_init = True self.need_weight_init = True
self._strength_model = strength_model self._strength_model = strength_model
self._strength_clip = strength_clip self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
@property @property
def strength_model(self): def strength_model(self):
@ -110,36 +151,36 @@ class WeightHook(Hook):
def strength_clip(self): def strength_clip(self):
return self._strength_clip * self.strength return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target, registered): if not self.should_register(model, model_options, target_dict, registered):
return False return False
weights = None weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model target = target_dict.get('target', None)
else: if target == EnumWeightTarget.Clip:
strength = self._strength_clip strength = self._strength_clip
else:
strength = self._strength_model
if self.need_weight_init: if self.need_weight_init:
key_map = {} key_map = {}
if target == EnumWeightTarget.Model: if target == EnumWeightTarget.Clip:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else: else:
if target == EnumWeightTarget.Model: if target == EnumWeightTarget.Clip:
weights = self.weights
else:
weights = self.weights_clip weights = self.weights_clip
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self) registered.add(self)
return True return True
# TODO: add logs about any keys that were not applied # TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None): def clone(self):
if subtype is None: c: WeightHook = super().clone()
subtype = type(self)
c: WeightHook = super().clone(subtype)
c.weights = self.weights c.weights = self.weights
c.weights_clip = self.weights_clip c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init c.need_weight_init = self.need_weight_init
@ -147,127 +188,158 @@ class WeightHook(Hook):
c._strength_clip = self._strength_clip c._strength_clip = self._strength_clip
return c return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook): class ObjectPatchHook(Hook):
def __init__(self): def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch) super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None self.object_patches = object_patches
self.hook_scope = hook_scope
def clone(self, subtype: Callable=None): def clone(self):
if subtype is None: c: ObjectPatchHook = super().clone()
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches c.object_patches = self.object_patches
return c return c
# TODO: add functionality
class AddModelsHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
def __init__(self, key: str=None, models: list['ModelPatcher']=None): raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key class AdditionalModelsHook(Hook):
'''
Hook responsible for telling model management any additional models that should be loaded.
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AdditionalModels)
self.models = models self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key self.key = key
self.callback = callback
def clone(self, subtype: Callable=None): def clone(self):
if subtype is None: c: AdditionalModelsHook = super().clone()
subtype = type(self) c.models = self.models.copy() if self.models else self.models
c: CallbackHook = super().clone(subtype)
c.key = self.key c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target, registered): if not self.should_register(model, model_options, target_dict, registered):
return False return False
add_model_options = {"transformer_options": self.wrappers_dict} registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True return True
class SetInjectionsHook(Hook): class TransformerOptionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None): '''
super().__init__(hook_type=EnumHookType.SetInjections) Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = transformers_dict
self.hook_scope = hook_scope
self._skip_adding = False
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
def clone(self):
c: TransformerOptionsHook = super().clone()
c.transformers_dict = self.transformers_dict
c._skip_adding = self._skip_adding
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
self._skip_adding = False
if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
# skip_adding if included in AllConditioning to avoid double loading
self._skip_adding = True
else:
add_model_options = {"to_load_options": self.transformers_dict}
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
if not self._skip_adding:
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
class InjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key self.key = key
self.injections = injections self.injections = injections
self.hook_scope = hook_scope
def clone(self, subtype: Callable=None): def clone(self):
if subtype is None: c: InjectionsHook = super().clone()
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections c.injections = self.injections.copy() if self.injections else self.injections
return c return c
def add_hook_injections(self, model: 'ModelPatcher'): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
# TODO: add functionality raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
pass
class HookGroup: class HookGroup:
'''
Stores groups of hooks, and allows them to be queried by type.
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
always use the provided functions on HookGroup.
'''
def __init__(self): def __init__(self):
self.hooks: list[Hook] = [] self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook): def add(self, hook: Hook):
if hook not in self.hooks: if hook not in self.hooks:
self.hooks.append(hook) self.hooks.append(hook)
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
def remove(self, hook: Hook):
if hook in self.hooks:
self.hooks.remove(hook)
self._hook_dict[hook.hook_type].remove(hook)
def get_type(self, hook_type: EnumHookType):
return self._hook_dict.get(hook_type, [])
def contains(self, hook: Hook): def contains(self, hook: Hook):
return hook in self.hooks return hook in self.hooks
def is_subset_of(self, other: HookGroup):
self_hooks = set(self.hooks)
other_hooks = set(other.hooks)
return self_hooks.issubset(other_hooks)
def new_with_common_hooks(self, other: HookGroup):
c = HookGroup()
for hook in self.hooks:
if other.contains(hook):
c.add(hook.clone())
return c
def clone(self): def clone(self):
c = HookGroup() c = HookGroup()
for hook in self.hooks: for hook in self.hooks:
c.add(hook.clone()) c.add(hook.clone())
return c return c
def clone_and_combine(self, other: 'HookGroup'): def clone_and_combine(self, other: HookGroup):
c = self.clone() c = self.clone()
if other is not None: if other is not None:
for hook in other.hooks: for hook in other.hooks:
c.add(hook.clone()) c.add(hook.clone())
return c return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
if hook_kf is None: if hook_kf is None:
hook_kf = HookKeyframeGroup() hook_kf = HookKeyframeGroup()
else: else:
@ -275,36 +347,29 @@ class HookGroup:
for hook in self.hooks: for hook in self.hooks:
hook.hook_keyframe = hook_kf hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self): def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks: # only care about WeightHooks, for now
# only care about WeightHooks, for now for hook in self.get_type(EnumHookType.Weight):
if hook.hook_type == EnumHookType.Weight: hook: WeightHook
hook_schedule = [] hook_schedule = []
# if no hook keyframes, assign default value # if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0: if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None)) hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples # hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = [] all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values(): for range_kfs in scheduled_hooks.values():
@ -336,7 +401,7 @@ class HookGroup:
hook.reset() hook.reset()
@staticmethod @staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
actual: list[HookGroup] = [] actual: list[HookGroup] = []
for group in hooks_list: for group in hooks_list:
if group is not None: if group is not None:
@ -433,7 +498,7 @@ class HookKeyframeGroup:
c._set_first_as_current() c._set_first_as_current()
return c return c
def initialize_timesteps(self, model: 'BaseModel'): def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes: for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
@ -522,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
sorted_list.extend(object_list) sorted_list.extend(object_list)
return sorted_list return sorted_list
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup() hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
@ -548,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
hook.need_weight_init = False hook.need_weight_init = False
return hook_group return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
if model is None: if model is None:
return None return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict() patches_model: dict[str, torch.Tensor] = model.model.state_dict()
@ -560,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
return patches_model return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers # NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float): strength_model: float, strength_clip: float):
key_map = {} key_map = {}
if model is not None: if model is not None:
@ -612,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
else: else:
c_dict[hooks_key] = cache[hooks_tuple] c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
c = [] c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} if cache is None:
cache = {}
for t in conditioning: for t in conditioning:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
for k in values: for k in values:
if append_hooks and k == 'hooks': if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache) _combine_hooks_from_values(n[1], values, cache)
else: else:
n[1][k] = values[k] n[1][k] = values[k]
c.append(n) c.append(n)
return c return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
if hooks is None: if hooks is None:
return cond return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None: if timestep_range is None:
@ -664,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = [] final_conds = []
cache = {}
for c in conds: for c in conds:
# first, apply lora_hook to conditioning, if provided # first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to conditioning # next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present # apply timesteps, if present
@ -678,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = [] combined_conds = []
cache = {}
for c, masked_c in zip(conds, new_conds): for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided # first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to new conditioning, if provided # next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present # apply timesteps, if present
@ -692,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
def set_default_conds_and_combine(conds: list, new_conds: list, def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = [] combined_conds = []
cache = {}
for c, new_c in zip(conds, new_conds): for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided # first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
# next, add default_cond key to cond so that during sampling, it can be identified # next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True}) new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present # apply timesteps, if present

View File

@ -210,7 +210,7 @@ class ModelPatcher:
self.injections: dict[str, list[PatcherInjection]] = {} self.injections: dict[str, list[PatcherInjection]] = {}
self.hook_patches: dict[comfy.hooks._HookRef] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None self.current_hooks: Optional[comfy.hooks.HookGroup] = None
@ -282,7 +282,7 @@ class ModelPatcher:
n.injections[k] = i.copy() n.injections[k] = i.copy()
# hooks # hooks
n.hook_patches = create_hook_patches_clone(self.hook_patches) n.hook_patches = create_hook_patches_clone(self.hook_patches)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
for group in self.cached_hook_patches: for group in self.cached_hook_patches:
n.cached_hook_patches[group] = {} n.cached_hook_patches[group] = {}
for k in self.cached_hook_patches[group]: for k in self.cached_hook_patches[group]:
@ -855,6 +855,9 @@ class ModelPatcher:
if key in self.injections: if key in self.injections:
self.injections.pop(key) self.injections.pop(key)
def get_injections(self, key: str):
return self.injections.get(key, None)
def set_additional_models(self, key: str, models: list['ModelPatcher']): def set_additional_models(self, key: str, models: list['ModelPatcher']):
self.additional_models[key] = models self.additional_models[key] = models
@ -925,9 +928,9 @@ class ModelPatcher:
callback(self, timestep) callback(self, timestep)
def restore_hook_patches(self): def restore_hook_patches(self):
if len(self.hook_patches_backup) > 0: if self.hook_patches_backup is not None:
self.hook_patches = self.hook_patches_backup self.hook_patches = self.hook_patches_backup
self.hook_patches_backup = {} self.hook_patches_backup = None
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode self.hook_mode = hook_mode
@ -953,25 +956,26 @@ class ModelPatcher:
if reset_current_hooks: if reset_current_hooks:
self.patch_hooks(None) self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
self.restore_hook_patches() self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = [] if registered is None:
# handle WrapperHooks, if model_options provided registered = comfy.hooks.HookGroup()
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
hook.add_hook_patches(self, model_options, target, registered_hooks)
# handle WeightHooks # handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches: if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook) weight_hooks_to_register.append(hook)
else:
registered.add(hook)
if len(weight_hooks_to_register) > 0: if len(weight_hooks_to_register) > 0:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register: for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target, registered_hooks) hook.add_hook_patches(self, model_options, target_dict, registered)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target) callback(self, hooks, target_dict, model_options, registered)
return registered
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected(): with self.use_ejected():
@ -1022,11 +1026,11 @@ class ModelPatcher:
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks # TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {} return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
self.patch_hooks(hooks=hooks) self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks) callback(self, hooks)
return {} return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
def patch_hooks(self, hooks: comfy.hooks.HookGroup): def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected(): with self.use_ejected():

View File

@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]] models += [c[model_type]]
return models return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks # get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = [] cnets: list[ControlBase] = []
for c in cond: for c in cond:
if 'hooks' in c: if 'hooks' in c:
for hook in c['hooks'].hooks: for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook full_hooks.add(hook)
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
if 'control' in c: if 'control' in c:
cnets.append(c['control']) cnets.append(c['control'])
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None: if extra_hooks is not None:
for hook in extra_hooks.hooks: for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {}) full_hooks.add(hook)
with_type[hook] = None
return hooks_dict return full_hooks
def convert_cond(cond): def convert_cond(cond):
out = [] out = []
@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = [] cnets: list[ControlBase] = []
gligen = [] gligen = []
add_models = [] add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds: for k in conds:
cnets += get_models_from_cond(conds[k], "control") cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen") gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models") add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets) control_nets = set(cnets)
@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype) inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen] gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] models = control_models + gligen + add_models
models = control_models + gligen + add_models + hook_models
return models, inference_memory return models, inference_memory
def get_additional_models_from_model_options(model_options: dict[str]=None):
"""loads additional models from registered AddModels hooks"""
models = []
if model_options is not None and "registered_hooks" in model_options:
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
models.extend(hook.models)
return models
def cleanup_additional_models(models): def cleanup_additional_models(models):
"""cleanup additional models that were loaded""" """cleanup additional models that were loaded"""
for m in models: for m in models:
@ -102,9 +106,10 @@ def cleanup_additional_models(models):
m.cleanup() m.cleanup()
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: 'BaseModel' = None real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update? models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@ -123,12 +128,35 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup)) cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
'''
Registers hooks from conds.
'''
# check for hooks in conds - if not registered, see if can be applied # check for hooks in conds - if not registered, see if can be applied
hooks = {} hooks = comfy.hooks.HookGroup()
for k in conds: for k in conds:
get_hooks_from_cond(conds[k], hooks) get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options # add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options # begin registering hooks
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
# handle all TransformerOptionsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook: comfy.hooks.TransformerOptionsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all AddModelsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all WeightHooks by registering on ModelPatcher
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
# add registered_hooks onto model_options for further reference
if len(registered) > 0:
model_options["registered_hooks"] = registered
# merge original wrappers and callbacks with hooked wrappers and callbacks
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
for wc_name in ["wrappers", "callbacks"]:
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options

View File

@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
for cond in conds_to_modify: for cond in conds_to_modify:
cond['hooks'] = hooks cond['hooks'] = hooks
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
HookGroups that have the same reference.'''
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
# if None were registered, make sure all hooks are cleaned from conds
if registered is None:
for k in conds:
for kk in conds[k]:
kk.pop('hooks', None)
return
# find conds that contain hooks to be replaced - group by common HookGroup refs
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
for k in conds:
for kk in conds[k]:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
if hooks is not None:
if not hooks.is_subset_of(registered):
to_replace = hook_replacement.setdefault(hooks, [])
to_replace.append(kk)
# for each hook to replace, create a new proper HookGroup and assign to all common conds
for hooks, conds_to_modify in hook_replacement.items():
new_hooks = hooks.new_with_common_hooks(registered)
if len(new_hooks) == 0:
new_hooks = None
for kk in conds_to_modify:
kk['hooks'] = new_hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set() hooks_set = set()
@ -819,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
return len(hooks_set) return len(hooks_set)
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
'''
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
'''
if model_options is None:
return
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
casts = []
if device is not None:
casts.append(device)
if dtype is not None:
casts.append(dtype)
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
if hasattr(wc_list[i], "to"):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider: class CFGGuider:
def __init__(self, model_patcher): def __init__(self, model_patcher: ModelPatcher):
self.model_patcher: 'ModelPatcher' = model_patcher self.model_patcher = model_patcher
self.model_options = model_patcher.model_options self.model_options = model_patcher.model_options
self.original_conds = {} self.original_conds = {}
self.cfg = 1.0 self.cfg = 1.0
@ -861,7 +937,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32)) return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device
if denoise_mask is not None: if denoise_mask is not None:
@ -870,6 +946,7 @@ class CFGGuider:
noise = noise.to(device) noise = noise.to(device)
latent_image = latent_image.to(device) latent_image = latent_image.to(device)
sigmas = sigmas.to(device) sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try: try:
self.model_patcher.pre_run() self.model_patcher.pre_run()
@ -899,6 +976,7 @@ class CFGGuider:
if get_total_hook_groups_in_conds(self.conds) <= 1: if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
filter_registered_hooks_on_conds(self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample, self.outer_sample,
self, self,
@ -906,6 +984,7 @@ class CFGGuider:
) )
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally: finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches() self.model_patcher.restore_hook_patches()

View File

@ -246,7 +246,7 @@ class SetClipHooks:
CATEGORY = "advanced/hooks/clip" CATEGORY = "advanced/hooks/clip"
FUNCTION = "apply_hooks" FUNCTION = "apply_hooks"
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None: if hooks is not None:
clip = clip.clone() clip = clip.clone()
if apply_to_conds: if apply_to_conds:
@ -255,7 +255,7 @@ class SetClipHooks:
clip.use_clip_schedule = schedule_clip clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule: if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None) clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
return (clip,) return (clip,)
class ConditioningTimestepsRange: class ConditioningTimestepsRange: