mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Compare commits
23 Commits
23f8622a82
...
c3e8262e5e
Author | SHA1 | Date | |
---|---|---|---|
|
c3e8262e5e | ||
|
ff838657fa | ||
|
7333281698 | ||
|
3cd4c5cb0a | ||
|
11c6d56037 | ||
|
216fea15ee | ||
|
58bf8815c8 | ||
|
1b38f5bf57 | ||
|
2724ac4a60 | ||
|
f48f90e471 | ||
|
6463c39ce0 | ||
|
0a7e2ae787 | ||
|
03a97b604a | ||
|
4446c86052 | ||
|
8270ff312f | ||
|
db2d7ad9ba | ||
|
6620d86318 | ||
|
111fd0cadf | ||
|
776aa734e1 | ||
|
5a2ad032cb | ||
|
d44295ef71 | ||
|
bf21be066f | ||
|
72bbf49349 |
359
comfy/hooks.py
359
comfy/hooks.py
@ -16,91 +16,132 @@ import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
# Hooks explanation
|
||||
# -------------------
|
||||
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
|
||||
# make explicit special cases like it does for ControlNet and GLIGEN.
|
||||
#
|
||||
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
|
||||
# that should run special code when a 'marked' cond is used in sampling.
|
||||
# #######################################################################################################
|
||||
|
||||
class EnumHookMode(enum.Enum):
|
||||
'''
|
||||
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
|
||||
|
||||
MinVram: No caching will occur for any operations related to hooks.
|
||||
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
|
||||
'''
|
||||
MinVram = "minvram"
|
||||
MaxSpeed = "maxspeed"
|
||||
|
||||
class EnumHookType(enum.Enum):
|
||||
'''
|
||||
Hook types, each of which has different expected behavior.
|
||||
'''
|
||||
Weight = "weight"
|
||||
Patch = "patch"
|
||||
ObjectPatch = "object_patch"
|
||||
AddModels = "add_models"
|
||||
Callbacks = "callbacks"
|
||||
Wrappers = "wrappers"
|
||||
SetInjections = "add_injections"
|
||||
AdditionalModels = "add_models"
|
||||
TransformerOptions = "transformer_options"
|
||||
Injections = "add_injections"
|
||||
|
||||
class EnumWeightTarget(enum.Enum):
|
||||
Model = "model"
|
||||
Clip = "clip"
|
||||
|
||||
class EnumHookScope(enum.Enum):
|
||||
'''
|
||||
Determines if hook should be limited in its influence over sampling.
|
||||
|
||||
AllConditioning: hook will affect all conds used in sampling.
|
||||
HookedOnly: hook will only affect the conds it was attached to.
|
||||
'''
|
||||
AllConditioning = "all_conditioning"
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
|
||||
# NOTE: this is an example of how the should_register function should look
|
||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
'''Example for how custom_should_register function can look like.'''
|
||||
return True
|
||||
|
||||
|
||||
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
|
||||
'''Creates base dictionary for use with Hooks' target param.'''
|
||||
d = {}
|
||||
if target is not None:
|
||||
d['target'] = target
|
||||
d.update(kwargs)
|
||||
return d
|
||||
|
||||
|
||||
class Hook:
|
||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
||||
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
||||
self.hook_type = hook_type
|
||||
'''Enum identifying the general class of this hook.'''
|
||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
||||
self.hook_id = hook_id
|
||||
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
||||
self.hook_scope = hook_scope
|
||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||
self.custom_should_register = default_should_register
|
||||
self.auto_apply_to_nonpositive = False
|
||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
return self.hook_keyframe.strength
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
def initialize_timesteps(self, model: BaseModel):
|
||||
self.reset()
|
||||
self.hook_keyframe.initialize_timesteps(model)
|
||||
|
||||
def reset(self):
|
||||
self.hook_keyframe.reset()
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: Hook = subtype()
|
||||
def clone(self):
|
||||
c: Hook = self.__class__()
|
||||
c.hook_type = self.hook_type
|
||||
c.hook_ref = self.hook_ref
|
||||
c.hook_id = self.hook_id
|
||||
c.hook_keyframe = self.hook_keyframe
|
||||
c.hook_scope = self.hook_scope
|
||||
c.custom_should_register = self.custom_should_register
|
||||
# TODO: make this do something
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
return self.custom_should_register(self, model, model_options, target, registered)
|
||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
||||
pass
|
||||
|
||||
def __eq__(self, other: 'Hook'):
|
||||
def __eq__(self, other: Hook):
|
||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.hook_ref)
|
||||
|
||||
class WeightHook(Hook):
|
||||
'''
|
||||
Hook responsible for tracking weights to be applied to some model/clip.
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as HookedOnly.
|
||||
'''
|
||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||
super().__init__(hook_type=EnumHookType.Weight)
|
||||
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
|
||||
self.weights: dict = None
|
||||
self.weights_clip: dict = None
|
||||
self.need_weight_init = True
|
||||
self._strength_model = strength_model
|
||||
self._strength_clip = strength_clip
|
||||
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
||||
|
||||
@property
|
||||
def strength_model(self):
|
||||
@ -110,36 +151,36 @@ class WeightHook(Hook):
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
weights = None
|
||||
if target == EnumWeightTarget.Model:
|
||||
strength = self._strength_model
|
||||
else:
|
||||
|
||||
target = target_dict.get('target', None)
|
||||
if target == EnumWeightTarget.Clip:
|
||||
strength = self._strength_clip
|
||||
else:
|
||||
strength = self._strength_model
|
||||
|
||||
if self.need_weight_init:
|
||||
key_map = {}
|
||||
if target == EnumWeightTarget.Model:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Model:
|
||||
weights = self.weights
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
weights = self.weights_clip
|
||||
else:
|
||||
weights = self.weights
|
||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
registered.add(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WeightHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: WeightHook = super().clone()
|
||||
c.weights = self.weights
|
||||
c.weights_clip = self.weights_clip
|
||||
c.need_weight_init = self.need_weight_init
|
||||
@ -147,127 +188,158 @@ class WeightHook(Hook):
|
||||
c._strength_clip = self._strength_clip
|
||||
return c
|
||||
|
||||
class PatchHook(Hook):
|
||||
def __init__(self):
|
||||
super().__init__(hook_type=EnumHookType.Patch)
|
||||
self.patches: dict = None
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: PatchHook = super().clone(subtype)
|
||||
c.patches = self.patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class ObjectPatchHook(Hook):
|
||||
def __init__(self):
|
||||
def __init__(self, object_patches: dict[str]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||
self.object_patches: dict = None
|
||||
self.object_patches = object_patches
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: ObjectPatchHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: ObjectPatchHook = super().clone()
|
||||
c.object_patches = self.object_patches
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class AddModelsHook(Hook):
|
||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
|
||||
|
||||
class AdditionalModelsHook(Hook):
|
||||
'''
|
||||
Hook responsible for telling model management any additional models that should be loaded.
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
||||
'''
|
||||
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
||||
super().__init__(hook_type=EnumHookType.AdditionalModels)
|
||||
self.models = models
|
||||
self.append_when_same = True
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: AddModelsHook = super().clone(subtype)
|
||||
c.key = self.key
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class CallbackHook(Hook):
|
||||
def __init__(self, key: str=None, callback: Callable=None):
|
||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: CallbackHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: AdditionalModelsHook = super().clone()
|
||||
c.models = self.models.copy() if self.models else self.models
|
||||
c.key = self.key
|
||||
c.callback = self.callback
|
||||
return c
|
||||
# TODO: add functionality
|
||||
|
||||
class WrapperHook(Hook):
|
||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
||||
self.wrappers_dict = wrappers_dict
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: WrapperHook = super().clone(subtype)
|
||||
c.wrappers_dict = self.wrappers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||
if not self.should_register(model, model_options, target, registered):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
registered.add(self)
|
||||
return True
|
||||
|
||||
class SetInjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||
class TransformerOptionsHook(Hook):
|
||||
'''
|
||||
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||
'''
|
||||
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||
self.transformers_dict = transformers_dict
|
||||
self.hook_scope = hook_scope
|
||||
self._skip_adding = False
|
||||
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
||||
|
||||
def clone(self):
|
||||
c: TransformerOptionsHook = super().clone()
|
||||
c.transformers_dict = self.transformers_dict
|
||||
c._skip_adding = self._skip_adding
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||
self._skip_adding = False
|
||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||
add_model_options = {"transformer_options": self.transformers_dict,
|
||||
"to_load_options": self.transformers_dict}
|
||||
# skip_adding if included in AllConditioning to avoid double loading
|
||||
self._skip_adding = True
|
||||
else:
|
||||
add_model_options = {"to_load_options": self.transformers_dict}
|
||||
registered.add(self)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
return True
|
||||
|
||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
if not self._skip_adding:
|
||||
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||
|
||||
WrapperHook = TransformerOptionsHook
|
||||
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||
|
||||
class InjectionsHook(Hook):
|
||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
||||
hook_scope=EnumHookScope.AllConditioning):
|
||||
super().__init__(hook_type=EnumHookType.Injections)
|
||||
self.key = key
|
||||
self.injections = injections
|
||||
self.hook_scope = hook_scope
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
subtype = type(self)
|
||||
c: SetInjectionsHook = super().clone(subtype)
|
||||
def clone(self):
|
||||
c: InjectionsHook = super().clone()
|
||||
c.key = self.key
|
||||
c.injections = self.injections.copy() if self.injections else self.injections
|
||||
return c
|
||||
|
||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||
# TODO: add functionality
|
||||
pass
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
|
||||
|
||||
class HookGroup:
|
||||
'''
|
||||
Stores groups of hooks, and allows them to be queried by type.
|
||||
|
||||
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
||||
always use the provided functions on HookGroup.
|
||||
'''
|
||||
def __init__(self):
|
||||
self.hooks: list[Hook] = []
|
||||
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.hooks)
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
||||
|
||||
def remove(self, hook: Hook):
|
||||
if hook in self.hooks:
|
||||
self.hooks.remove(hook)
|
||||
self._hook_dict[hook.hook_type].remove(hook)
|
||||
|
||||
def get_type(self, hook_type: EnumHookType):
|
||||
return self._hook_dict.get(hook_type, [])
|
||||
|
||||
def contains(self, hook: Hook):
|
||||
return hook in self.hooks
|
||||
|
||||
def is_subset_of(self, other: HookGroup):
|
||||
self_hooks = set(self.hooks)
|
||||
other_hooks = set(other.hooks)
|
||||
return self_hooks.issubset(other_hooks)
|
||||
|
||||
def new_with_common_hooks(self, other: HookGroup):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
if other.contains(hook):
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone(self):
|
||||
c = HookGroup()
|
||||
for hook in self.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def clone_and_combine(self, other: 'HookGroup'):
|
||||
def clone_and_combine(self, other: HookGroup):
|
||||
c = self.clone()
|
||||
if other is not None:
|
||||
for hook in other.hooks:
|
||||
c.add(hook.clone())
|
||||
return c
|
||||
|
||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
|
||||
if hook_kf is None:
|
||||
hook_kf = HookKeyframeGroup()
|
||||
else:
|
||||
@ -275,18 +347,11 @@ class HookGroup:
|
||||
for hook in self.hooks:
|
||||
hook.hook_keyframe = hook_kf
|
||||
|
||||
def get_dict_repr(self):
|
||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
||||
for hook in self.hooks:
|
||||
with_type = d.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
return d
|
||||
|
||||
def get_hooks_for_clip_schedule(self):
|
||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||
for hook in self.hooks:
|
||||
# only care about WeightHooks, for now
|
||||
if hook.hook_type == EnumHookType.Weight:
|
||||
for hook in self.get_type(EnumHookType.Weight):
|
||||
hook: WeightHook
|
||||
hook_schedule = []
|
||||
# if no hook keyframes, assign default value
|
||||
if len(hook.hook_keyframe.keyframes) == 0:
|
||||
@ -336,7 +401,7 @@ class HookGroup:
|
||||
hook.reset()
|
||||
|
||||
@staticmethod
|
||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
||||
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
|
||||
actual: list[HookGroup] = []
|
||||
for group in hooks_list:
|
||||
if group is not None:
|
||||
@ -433,7 +498,7 @@ class HookKeyframeGroup:
|
||||
c._set_first_as_current()
|
||||
return c
|
||||
|
||||
def initialize_timesteps(self, model: 'BaseModel'):
|
||||
def initialize_timesteps(self, model: BaseModel):
|
||||
for keyframe in self.keyframes:
|
||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||
|
||||
@ -522,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
||||
sorted_list.extend(object_list)
|
||||
return sorted_list
|
||||
|
||||
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||
if hooks is None or model.is_clip:
|
||||
return {}
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||
hook: TransformerOptionsHook
|
||||
hook.on_apply_hooks(model, transformer_options)
|
||||
return transformer_options
|
||||
|
||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||
hook_group = HookGroup()
|
||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||
@ -548,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
|
||||
hook.need_weight_init = False
|
||||
return hook_group
|
||||
|
||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
||||
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
||||
if model is None:
|
||||
return None
|
||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||
@ -560,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
|
||||
return patches_model
|
||||
|
||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
||||
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
||||
strength_model: float, strength_clip: float):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@ -612,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
|
||||
else:
|
||||
c_dict[hooks_key] = cache[hooks_tuple]
|
||||
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
||||
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||
c = []
|
||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
||||
if cache is None:
|
||||
cache = {}
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
if append_hooks and k == 'hooks':
|
||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
||||
_combine_hooks_from_values(n[1], values, cache)
|
||||
else:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||
if hooks is None:
|
||||
return cond
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
||||
|
||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||
if timestep_range is None:
|
||||
@ -664,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
|
||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
final_conds = []
|
||||
cache = {}
|
||||
for c in conds:
|
||||
# first, apply lora_hook to conditioning, if provided
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, apply mask to conditioning
|
||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||
# apply timesteps, if present
|
||||
@ -678,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
cache = {}
|
||||
for c, masked_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, apply mask to new conditioning, if provided
|
||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||
# apply timesteps, if present
|
||||
@ -692,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
|
||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||
combined_conds = []
|
||||
cache = {}
|
||||
for c, new_c in zip(conds, new_conds):
|
||||
# first, apply lora_hook to new conditioning, if provided
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||
new_c = conditioning_set_values(new_c, {'default': True})
|
||||
# apply timesteps, if present
|
||||
|
@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
|
@ -210,7 +210,7 @@ class ModelPatcher:
|
||||
self.injections: dict[str, list[PatcherInjection]] = {}
|
||||
|
||||
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
||||
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
|
||||
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
|
||||
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
||||
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
||||
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
||||
@ -282,7 +282,7 @@ class ModelPatcher:
|
||||
n.injections[k] = i.copy()
|
||||
# hooks
|
||||
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
||||
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
|
||||
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
|
||||
for group in self.cached_hook_patches:
|
||||
n.cached_hook_patches[group] = {}
|
||||
for k in self.cached_hook_patches[group]:
|
||||
@ -855,6 +855,9 @@ class ModelPatcher:
|
||||
if key in self.injections:
|
||||
self.injections.pop(key)
|
||||
|
||||
def get_injections(self, key: str):
|
||||
return self.injections.get(key, None)
|
||||
|
||||
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
||||
self.additional_models[key] = models
|
||||
|
||||
@ -925,9 +928,9 @@ class ModelPatcher:
|
||||
callback(self, timestep)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if len(self.hook_patches_backup) > 0:
|
||||
if self.hook_patches_backup is not None:
|
||||
self.hook_patches = self.hook_patches_backup
|
||||
self.hook_patches_backup = {}
|
||||
self.hook_patches_backup = None
|
||||
|
||||
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||
self.hook_mode = hook_mode
|
||||
@ -953,25 +956,26 @@ class ModelPatcher:
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
self.restore_hook_patches()
|
||||
registered_hooks: list[comfy.hooks.Hook] = []
|
||||
# handle WrapperHooks, if model_options provided
|
||||
if model_options is not None:
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
||||
if registered is None:
|
||||
registered = comfy.hooks.HookGroup()
|
||||
# handle WeightHooks
|
||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
||||
if hook.hook_ref not in self.hook_patches:
|
||||
weight_hooks_to_register.append(hook)
|
||||
else:
|
||||
registered.add(hook)
|
||||
if len(weight_hooks_to_register) > 0:
|
||||
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
||||
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
||||
for hook in weight_hooks_to_register:
|
||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
||||
hook.add_hook_patches(self, model_options, target_dict, registered)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||
callback(self, hooks_dict, target)
|
||||
callback(self, hooks, target_dict, model_options, registered)
|
||||
return registered
|
||||
|
||||
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||
with self.use_ejected():
|
||||
@ -1022,11 +1026,11 @@ class ModelPatcher:
|
||||
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
||||
# TODO: return transformer_options dict with any additions from hooks
|
||||
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||
return {}
|
||||
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||
self.patch_hooks(hooks=hooks)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||
callback(self, hooks)
|
||||
return {}
|
||||
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||
|
||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||
with self.use_ejected():
|
||||
|
@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
||||
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
|
||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||
cnets: list[ControlBase] = []
|
||||
for c in cond:
|
||||
if 'hooks' in c:
|
||||
for hook in c['hooks'].hooks:
|
||||
hook: comfy.hooks.Hook
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
full_hooks.add(hook)
|
||||
if 'control' in c:
|
||||
cnets.append(c['control'])
|
||||
|
||||
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||
if extra_hooks is not None:
|
||||
for hook in extra_hooks.hooks:
|
||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
||||
with_type[hook] = None
|
||||
full_hooks.add(hook)
|
||||
|
||||
return hooks_dict
|
||||
return full_hooks
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
|
||||
cnets: list[ControlBase] = []
|
||||
gligen = []
|
||||
add_models = []
|
||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
||||
models = control_models + gligen + add_models + hook_models
|
||||
models = control_models + gligen + add_models
|
||||
|
||||
return models, inference_memory
|
||||
|
||||
def get_additional_models_from_model_options(model_options: dict[str]=None):
|
||||
"""loads additional models from registered AddModels hooks"""
|
||||
models = []
|
||||
if model_options is not None and "registered_hooks" in model_options:
|
||||
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
|
||||
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||
hook: comfy.hooks.AdditionalModelsHook
|
||||
models.extend(hook.models)
|
||||
return models
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
@ -102,9 +106,10 @@ def cleanup_additional_models(models):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||
real_model: 'BaseModel' = None
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
@ -123,12 +128,35 @@ def cleanup_models(conds, models):
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
# check for hooks in conds - if not registered, see if can be applied
|
||||
hooks = {}
|
||||
hooks = comfy.hooks.HookGroup()
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
# register hooks on model/model_options
|
||||
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
||||
# begin registering hooks
|
||||
registered = comfy.hooks.HookGroup()
|
||||
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||
# handle all TransformerOptionsHooks
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
|
||||
hook: comfy.hooks.TransformerOptionsHook
|
||||
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||
# handle all AddModelsHooks
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||
hook: comfy.hooks.AdditionalModelsHook
|
||||
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||
# handle all WeightHooks by registering on ModelPatcher
|
||||
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
|
||||
# add registered_hooks onto model_options for further reference
|
||||
if len(registered) > 0:
|
||||
model_options["registered_hooks"] = registered
|
||||
# merge original wrappers and callbacks with hooked wrappers and callbacks
|
||||
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
|
||||
for wc_name in ["wrappers", "callbacks"]:
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
||||
for cond in conds_to_modify:
|
||||
cond['hooks'] = hooks
|
||||
|
||||
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
|
||||
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
|
||||
HookGroups that have the same reference.'''
|
||||
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
|
||||
# if None were registered, make sure all hooks are cleaned from conds
|
||||
if registered is None:
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
kk.pop('hooks', None)
|
||||
return
|
||||
# find conds that contain hooks to be replaced - group by common HookGroup refs
|
||||
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||
if hooks is not None:
|
||||
if not hooks.is_subset_of(registered):
|
||||
to_replace = hook_replacement.setdefault(hooks, [])
|
||||
to_replace.append(kk)
|
||||
# for each hook to replace, create a new proper HookGroup and assign to all common conds
|
||||
for hooks, conds_to_modify in hook_replacement.items():
|
||||
new_hooks = hooks.new_with_common_hooks(registered)
|
||||
if len(new_hooks) == 0:
|
||||
new_hooks = None
|
||||
for kk in conds_to_modify:
|
||||
kk['hooks'] = new_hooks
|
||||
|
||||
|
||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
hooks_set = set()
|
||||
@ -819,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
return len(hooks_set)
|
||||
|
||||
|
||||
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
'''
|
||||
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
|
||||
'''
|
||||
if model_options is None:
|
||||
return
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
if dtype is not None:
|
||||
casts.append(dtype)
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
for cast in casts:
|
||||
patch_list[k] = patch_list[k].to(cast)
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
if hasattr(wc_list[i], "to"):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
@ -861,7 +937,7 @@ class CFGGuider:
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
@ -870,6 +946,7 @@ class CFGGuider:
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
@ -899,6 +976,7 @@ class CFGGuider:
|
||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||
filter_registered_hooks_on_conds(self.conds, self.model_options)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.outer_sample,
|
||||
self,
|
||||
@ -906,6 +984,7 @@ class CFGGuider:
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||
self.model_options = orig_model_options
|
||||
self.model_patcher.hook_mode = orig_hook_mode
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
@ -246,7 +246,7 @@ class SetClipHooks:
|
||||
CATEGORY = "advanced/hooks/clip"
|
||||
FUNCTION = "apply_hooks"
|
||||
|
||||
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||
if hooks is not None:
|
||||
clip = clip.clone()
|
||||
if apply_to_conds:
|
||||
@ -255,7 +255,7 @@ class SetClipHooks:
|
||||
clip.use_clip_schedule = schedule_clip
|
||||
if not clip.use_clip_schedule:
|
||||
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
||||
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
|
||||
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
|
||||
return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
|
||||
|
Loading…
Reference in New Issue
Block a user