mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Compare commits
23 Commits
de7d85d894
...
23f8622a82
Author | SHA1 | Date | |
---|---|---|---|
|
23f8622a82 | ||
|
2307ff6746 | ||
|
7333281698 | ||
|
3cd4c5cb0a | ||
|
11c6d56037 | ||
|
216fea15ee | ||
|
58bf8815c8 | ||
|
1b38f5bf57 | ||
|
2724ac4a60 | ||
|
f48f90e471 | ||
|
6463c39ce0 | ||
|
0a7e2ae787 | ||
|
03a97b604a | ||
|
4446c86052 | ||
|
8270ff312f | ||
|
db2d7ad9ba | ||
|
6620d86318 | ||
|
111fd0cadf | ||
|
776aa734e1 | ||
|
5a2ad032cb | ||
|
d44295ef71 | ||
|
bf21be066f | ||
|
72bbf49349 |
395
comfy/hooks.py
395
comfy/hooks.py
@ -16,91 +16,132 @@ import comfy.model_management
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from node_helpers import conditioning_set_values
|
from node_helpers import conditioning_set_values
|
||||||
|
|
||||||
|
# #######################################################################################################
|
||||||
|
# Hooks explanation
|
||||||
|
# -------------------
|
||||||
|
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
|
||||||
|
# make explicit special cases like it does for ControlNet and GLIGEN.
|
||||||
|
#
|
||||||
|
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
|
||||||
|
# that should run special code when a 'marked' cond is used in sampling.
|
||||||
|
# #######################################################################################################
|
||||||
|
|
||||||
class EnumHookMode(enum.Enum):
|
class EnumHookMode(enum.Enum):
|
||||||
|
'''
|
||||||
|
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
|
||||||
|
|
||||||
|
MinVram: No caching will occur for any operations related to hooks.
|
||||||
|
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
|
||||||
|
'''
|
||||||
MinVram = "minvram"
|
MinVram = "minvram"
|
||||||
MaxSpeed = "maxspeed"
|
MaxSpeed = "maxspeed"
|
||||||
|
|
||||||
class EnumHookType(enum.Enum):
|
class EnumHookType(enum.Enum):
|
||||||
|
'''
|
||||||
|
Hook types, each of which has different expected behavior.
|
||||||
|
'''
|
||||||
Weight = "weight"
|
Weight = "weight"
|
||||||
Patch = "patch"
|
|
||||||
ObjectPatch = "object_patch"
|
ObjectPatch = "object_patch"
|
||||||
AddModels = "add_models"
|
AdditionalModels = "add_models"
|
||||||
Callbacks = "callbacks"
|
TransformerOptions = "transformer_options"
|
||||||
Wrappers = "wrappers"
|
Injections = "add_injections"
|
||||||
SetInjections = "add_injections"
|
|
||||||
|
|
||||||
class EnumWeightTarget(enum.Enum):
|
class EnumWeightTarget(enum.Enum):
|
||||||
Model = "model"
|
Model = "model"
|
||||||
Clip = "clip"
|
Clip = "clip"
|
||||||
|
|
||||||
|
class EnumHookScope(enum.Enum):
|
||||||
|
'''
|
||||||
|
Determines if hook should be limited in its influence over sampling.
|
||||||
|
|
||||||
|
AllConditioning: hook will affect all conds used in sampling.
|
||||||
|
HookedOnly: hook will only affect the conds it was attached to.
|
||||||
|
'''
|
||||||
|
AllConditioning = "all_conditioning"
|
||||||
|
HookedOnly = "hooked_only"
|
||||||
|
|
||||||
|
|
||||||
class _HookRef:
|
class _HookRef:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# NOTE: this is an example of how the should_register function should look
|
|
||||||
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
|
'''Example for how custom_should_register function can look like.'''
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
|
||||||
|
'''Creates base dictionary for use with Hooks' target param.'''
|
||||||
|
d = {}
|
||||||
|
if target is not None:
|
||||||
|
d['target'] = target
|
||||||
|
d.update(kwargs)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
class Hook:
|
class Hook:
|
||||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
||||||
hook_keyframe: 'HookKeyframeGroup'=None):
|
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
||||||
self.hook_type = hook_type
|
self.hook_type = hook_type
|
||||||
|
'''Enum identifying the general class of this hook.'''
|
||||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
||||||
|
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
||||||
self.hook_id = hook_id
|
self.hook_id = hook_id
|
||||||
|
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
||||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
||||||
|
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
||||||
|
self.hook_scope = hook_scope
|
||||||
|
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||||
self.custom_should_register = default_should_register
|
self.custom_should_register = default_should_register
|
||||||
self.auto_apply_to_nonpositive = False
|
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength(self):
|
def strength(self):
|
||||||
return self.hook_keyframe.strength
|
return self.hook_keyframe.strength
|
||||||
|
|
||||||
def initialize_timesteps(self, model: 'BaseModel'):
|
def initialize_timesteps(self, model: BaseModel):
|
||||||
self.reset()
|
self.reset()
|
||||||
self.hook_keyframe.initialize_timesteps(model)
|
self.hook_keyframe.initialize_timesteps(model)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.hook_keyframe.reset()
|
self.hook_keyframe.reset()
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: Hook = self.__class__()
|
||||||
subtype = type(self)
|
|
||||||
c: Hook = subtype()
|
|
||||||
c.hook_type = self.hook_type
|
c.hook_type = self.hook_type
|
||||||
c.hook_ref = self.hook_ref
|
c.hook_ref = self.hook_ref
|
||||||
c.hook_id = self.hook_id
|
c.hook_id = self.hook_id
|
||||||
c.hook_keyframe = self.hook_keyframe
|
c.hook_keyframe = self.hook_keyframe
|
||||||
|
c.hook_scope = self.hook_scope
|
||||||
c.custom_should_register = self.custom_should_register
|
c.custom_should_register = self.custom_should_register
|
||||||
# TODO: make this do something
|
|
||||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
return self.custom_should_register(self, model, model_options, target, registered)
|
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||||
|
|
||||||
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
def __eq__(self, other: Hook):
|
||||||
pass
|
|
||||||
|
|
||||||
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __eq__(self, other: 'Hook'):
|
|
||||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.hook_ref)
|
return hash(self.hook_ref)
|
||||||
|
|
||||||
class WeightHook(Hook):
|
class WeightHook(Hook):
|
||||||
|
'''
|
||||||
|
Hook responsible for tracking weights to be applied to some model/clip.
|
||||||
|
|
||||||
|
Note, value of hook_scope is ignored and is treated as HookedOnly.
|
||||||
|
'''
|
||||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
||||||
super().__init__(hook_type=EnumHookType.Weight)
|
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
|
||||||
self.weights: dict = None
|
self.weights: dict = None
|
||||||
self.weights_clip: dict = None
|
self.weights_clip: dict = None
|
||||||
self.need_weight_init = True
|
self.need_weight_init = True
|
||||||
self._strength_model = strength_model
|
self._strength_model = strength_model
|
||||||
self._strength_clip = strength_clip
|
self._strength_clip = strength_clip
|
||||||
|
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength_model(self):
|
def strength_model(self):
|
||||||
@ -110,36 +151,36 @@ class WeightHook(Hook):
|
|||||||
def strength_clip(self):
|
def strength_clip(self):
|
||||||
return self._strength_clip * self.strength
|
return self._strength_clip * self.strength
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
if not self.should_register(model, model_options, target, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
weights = None
|
weights = None
|
||||||
if target == EnumWeightTarget.Model:
|
|
||||||
strength = self._strength_model
|
target = target_dict.get('target', None)
|
||||||
else:
|
if target == EnumWeightTarget.Clip:
|
||||||
strength = self._strength_clip
|
strength = self._strength_clip
|
||||||
|
else:
|
||||||
|
strength = self._strength_model
|
||||||
|
|
||||||
if self.need_weight_init:
|
if self.need_weight_init:
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if target == EnumWeightTarget.Model:
|
if target == EnumWeightTarget.Clip:
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
|
||||||
else:
|
|
||||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||||
|
else:
|
||||||
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||||
else:
|
else:
|
||||||
if target == EnumWeightTarget.Model:
|
if target == EnumWeightTarget.Clip:
|
||||||
weights = self.weights
|
|
||||||
else:
|
|
||||||
weights = self.weights_clip
|
weights = self.weights_clip
|
||||||
|
else:
|
||||||
|
weights = self.weights
|
||||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||||
registered.append(self)
|
registered.add(self)
|
||||||
return True
|
return True
|
||||||
# TODO: add logs about any keys that were not applied
|
# TODO: add logs about any keys that were not applied
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: WeightHook = super().clone()
|
||||||
subtype = type(self)
|
|
||||||
c: WeightHook = super().clone(subtype)
|
|
||||||
c.weights = self.weights
|
c.weights = self.weights
|
||||||
c.weights_clip = self.weights_clip
|
c.weights_clip = self.weights_clip
|
||||||
c.need_weight_init = self.need_weight_init
|
c.need_weight_init = self.need_weight_init
|
||||||
@ -147,127 +188,158 @@ class WeightHook(Hook):
|
|||||||
c._strength_clip = self._strength_clip
|
c._strength_clip = self._strength_clip
|
||||||
return c
|
return c
|
||||||
|
|
||||||
class PatchHook(Hook):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(hook_type=EnumHookType.Patch)
|
|
||||||
self.patches: dict = None
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: PatchHook = super().clone(subtype)
|
|
||||||
c.patches = self.patches
|
|
||||||
return c
|
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class ObjectPatchHook(Hook):
|
class ObjectPatchHook(Hook):
|
||||||
def __init__(self):
|
def __init__(self, object_patches: dict[str]=None,
|
||||||
|
hook_scope=EnumHookScope.AllConditioning):
|
||||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||||
self.object_patches: dict = None
|
self.object_patches = object_patches
|
||||||
|
self.hook_scope = hook_scope
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: ObjectPatchHook = super().clone()
|
||||||
subtype = type(self)
|
|
||||||
c: ObjectPatchHook = super().clone(subtype)
|
|
||||||
c.object_patches = self.object_patches
|
c.object_patches = self.object_patches
|
||||||
return c
|
return c
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class AddModelsHook(Hook):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
|
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
|
||||||
super().__init__(hook_type=EnumHookType.AddModels)
|
|
||||||
self.key = key
|
class AdditionalModelsHook(Hook):
|
||||||
|
'''
|
||||||
|
Hook responsible for telling model management any additional models that should be loaded.
|
||||||
|
|
||||||
|
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
||||||
|
'''
|
||||||
|
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
||||||
|
super().__init__(hook_type=EnumHookType.AdditionalModels)
|
||||||
self.models = models
|
self.models = models
|
||||||
self.append_when_same = True
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: AddModelsHook = super().clone(subtype)
|
|
||||||
c.key = self.key
|
|
||||||
c.models = self.models.copy() if self.models else self.models
|
|
||||||
c.append_when_same = self.append_when_same
|
|
||||||
return c
|
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class CallbackHook(Hook):
|
|
||||||
def __init__(self, key: str=None, callback: Callable=None):
|
|
||||||
super().__init__(hook_type=EnumHookType.Callbacks)
|
|
||||||
self.key = key
|
self.key = key
|
||||||
self.callback = callback
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: AdditionalModelsHook = super().clone()
|
||||||
subtype = type(self)
|
c.models = self.models.copy() if self.models else self.models
|
||||||
c: CallbackHook = super().clone(subtype)
|
|
||||||
c.key = self.key
|
c.key = self.key
|
||||||
c.callback = self.callback
|
|
||||||
return c
|
|
||||||
# TODO: add functionality
|
|
||||||
|
|
||||||
class WrapperHook(Hook):
|
|
||||||
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
|
|
||||||
super().__init__(hook_type=EnumHookType.Wrappers)
|
|
||||||
self.wrappers_dict = wrappers_dict
|
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
|
||||||
if subtype is None:
|
|
||||||
subtype = type(self)
|
|
||||||
c: WrapperHook = super().clone(subtype)
|
|
||||||
c.wrappers_dict = self.wrappers_dict
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
if not self.should_register(model, model_options, target, registered):
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
return False
|
return False
|
||||||
add_model_options = {"transformer_options": self.wrappers_dict}
|
registered.add(self)
|
||||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
|
||||||
registered.append(self)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
class SetInjectionsHook(Hook):
|
class TransformerOptionsHook(Hook):
|
||||||
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
|
'''
|
||||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
||||||
|
'''
|
||||||
|
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
||||||
|
hook_scope=EnumHookScope.AllConditioning):
|
||||||
|
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
||||||
|
self.transformers_dict = transformers_dict
|
||||||
|
self.hook_scope = hook_scope
|
||||||
|
self._skip_adding = False
|
||||||
|
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c: TransformerOptionsHook = super().clone()
|
||||||
|
c.transformers_dict = self.transformers_dict
|
||||||
|
c._skip_adding = self._skip_adding
|
||||||
|
return c
|
||||||
|
|
||||||
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
|
if not self.should_register(model, model_options, target_dict, registered):
|
||||||
|
return False
|
||||||
|
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||||
|
self._skip_adding = False
|
||||||
|
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||||
|
add_model_options = {"transformer_options": self.transformers_dict,
|
||||||
|
"to_load_options": self.transformers_dict}
|
||||||
|
# skip_adding if included in AllConditioning to avoid double loading
|
||||||
|
self._skip_adding = True
|
||||||
|
else:
|
||||||
|
add_model_options = {"to_load_options": self.transformers_dict}
|
||||||
|
registered.add(self)
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||||
|
if not self._skip_adding:
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
||||||
|
|
||||||
|
WrapperHook = TransformerOptionsHook
|
||||||
|
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
||||||
|
|
||||||
|
class InjectionsHook(Hook):
|
||||||
|
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
||||||
|
hook_scope=EnumHookScope.AllConditioning):
|
||||||
|
super().__init__(hook_type=EnumHookType.Injections)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.injections = injections
|
self.injections = injections
|
||||||
|
self.hook_scope = hook_scope
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self):
|
||||||
if subtype is None:
|
c: InjectionsHook = super().clone()
|
||||||
subtype = type(self)
|
|
||||||
c: SetInjectionsHook = super().clone(subtype)
|
|
||||||
c.key = self.key
|
c.key = self.key
|
||||||
c.injections = self.injections.copy() if self.injections else self.injections
|
c.injections = self.injections.copy() if self.injections else self.injections
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
# TODO: add functionality
|
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
|
||||||
pass
|
|
||||||
|
|
||||||
class HookGroup:
|
class HookGroup:
|
||||||
|
'''
|
||||||
|
Stores groups of hooks, and allows them to be queried by type.
|
||||||
|
|
||||||
|
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
||||||
|
always use the provided functions on HookGroup.
|
||||||
|
'''
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.hooks: list[Hook] = []
|
self.hooks: list[Hook] = []
|
||||||
|
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.hooks)
|
||||||
|
|
||||||
def add(self, hook: Hook):
|
def add(self, hook: Hook):
|
||||||
if hook not in self.hooks:
|
if hook not in self.hooks:
|
||||||
self.hooks.append(hook)
|
self.hooks.append(hook)
|
||||||
|
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
||||||
|
|
||||||
|
def remove(self, hook: Hook):
|
||||||
|
if hook in self.hooks:
|
||||||
|
self.hooks.remove(hook)
|
||||||
|
self._hook_dict[hook.hook_type].remove(hook)
|
||||||
|
|
||||||
|
def get_type(self, hook_type: EnumHookType):
|
||||||
|
return self._hook_dict.get(hook_type, [])
|
||||||
|
|
||||||
def contains(self, hook: Hook):
|
def contains(self, hook: Hook):
|
||||||
return hook in self.hooks
|
return hook in self.hooks
|
||||||
|
|
||||||
|
def is_subset_of(self, other: HookGroup):
|
||||||
|
self_hooks = set(self.hooks)
|
||||||
|
other_hooks = set(other.hooks)
|
||||||
|
return self_hooks.issubset(other_hooks)
|
||||||
|
|
||||||
|
def new_with_common_hooks(self, other: HookGroup):
|
||||||
|
c = HookGroup()
|
||||||
|
for hook in self.hooks:
|
||||||
|
if other.contains(hook):
|
||||||
|
c.add(hook.clone())
|
||||||
|
return c
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookGroup()
|
c = HookGroup()
|
||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
c.add(hook.clone())
|
c.add(hook.clone())
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def clone_and_combine(self, other: 'HookGroup'):
|
def clone_and_combine(self, other: HookGroup):
|
||||||
c = self.clone()
|
c = self.clone()
|
||||||
if other is not None:
|
if other is not None:
|
||||||
for hook in other.hooks:
|
for hook in other.hooks:
|
||||||
c.add(hook.clone())
|
c.add(hook.clone())
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
|
||||||
if hook_kf is None:
|
if hook_kf is None:
|
||||||
hook_kf = HookKeyframeGroup()
|
hook_kf = HookKeyframeGroup()
|
||||||
else:
|
else:
|
||||||
@ -275,36 +347,29 @@ class HookGroup:
|
|||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
hook.hook_keyframe = hook_kf
|
hook.hook_keyframe = hook_kf
|
||||||
|
|
||||||
def get_dict_repr(self):
|
|
||||||
d: dict[EnumHookType, dict[Hook, None]] = {}
|
|
||||||
for hook in self.hooks:
|
|
||||||
with_type = d.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_hooks_for_clip_schedule(self):
|
def get_hooks_for_clip_schedule(self):
|
||||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
||||||
for hook in self.hooks:
|
# only care about WeightHooks, for now
|
||||||
# only care about WeightHooks, for now
|
for hook in self.get_type(EnumHookType.Weight):
|
||||||
if hook.hook_type == EnumHookType.Weight:
|
hook: WeightHook
|
||||||
hook_schedule = []
|
hook_schedule = []
|
||||||
# if no hook keyframes, assign default value
|
# if no hook keyframes, assign default value
|
||||||
if len(hook.hook_keyframe.keyframes) == 0:
|
if len(hook.hook_keyframe.keyframes) == 0:
|
||||||
hook_schedule.append(((0.0, 1.0), None))
|
hook_schedule.append(((0.0, 1.0), None))
|
||||||
scheduled_hooks[hook] = hook_schedule
|
|
||||||
continue
|
|
||||||
# find ranges of values
|
|
||||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
|
||||||
for keyframe in hook.hook_keyframe.keyframes:
|
|
||||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
# create final range, assuming last start_percent was not 1.0
|
|
||||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
|
||||||
scheduled_hooks[hook] = hook_schedule
|
scheduled_hooks[hook] = hook_schedule
|
||||||
|
continue
|
||||||
|
# find ranges of values
|
||||||
|
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
||||||
|
for keyframe in hook.hook_keyframe.keyframes:
|
||||||
|
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
elif keyframe.start_percent == prev_keyframe.start_percent:
|
||||||
|
prev_keyframe = keyframe
|
||||||
|
# create final range, assuming last start_percent was not 1.0
|
||||||
|
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
||||||
|
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
||||||
|
scheduled_hooks[hook] = hook_schedule
|
||||||
# hooks should not have their schedules in a list of tuples
|
# hooks should not have their schedules in a list of tuples
|
||||||
all_ranges: list[tuple[float, float]] = []
|
all_ranges: list[tuple[float, float]] = []
|
||||||
for range_kfs in scheduled_hooks.values():
|
for range_kfs in scheduled_hooks.values():
|
||||||
@ -336,7 +401,7 @@ class HookGroup:
|
|||||||
hook.reset()
|
hook.reset()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
|
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
|
||||||
actual: list[HookGroup] = []
|
actual: list[HookGroup] = []
|
||||||
for group in hooks_list:
|
for group in hooks_list:
|
||||||
if group is not None:
|
if group is not None:
|
||||||
@ -433,7 +498,7 @@ class HookKeyframeGroup:
|
|||||||
c._set_first_as_current()
|
c._set_first_as_current()
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def initialize_timesteps(self, model: 'BaseModel'):
|
def initialize_timesteps(self, model: BaseModel):
|
||||||
for keyframe in self.keyframes:
|
for keyframe in self.keyframes:
|
||||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||||
|
|
||||||
@ -522,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
|||||||
sorted_list.extend(object_list)
|
sorted_list.extend(object_list)
|
||||||
return sorted_list
|
return sorted_list
|
||||||
|
|
||||||
|
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
||||||
|
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
||||||
|
if hooks is None or model.is_clip:
|
||||||
|
return {}
|
||||||
|
if transformer_options is None:
|
||||||
|
transformer_options = {}
|
||||||
|
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
||||||
|
hook: TransformerOptionsHook
|
||||||
|
hook.on_apply_hooks(model, transformer_options)
|
||||||
|
return transformer_options
|
||||||
|
|
||||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
||||||
hook_group = HookGroup()
|
hook_group = HookGroup()
|
||||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
||||||
@ -548,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
|
|||||||
hook.need_weight_init = False
|
hook.need_weight_init = False
|
||||||
return hook_group
|
return hook_group
|
||||||
|
|
||||||
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
|
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
||||||
if model is None:
|
if model is None:
|
||||||
return None
|
return None
|
||||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
||||||
@ -560,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
|
|||||||
return patches_model
|
return patches_model
|
||||||
|
|
||||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
||||||
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
|
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
||||||
strength_model: float, strength_clip: float):
|
strength_model: float, strength_clip: float):
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if model is not None:
|
if model is not None:
|
||||||
@ -612,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
|
|||||||
else:
|
else:
|
||||||
c_dict[hooks_key] = cache[hooks_tuple]
|
c_dict[hooks_key] = cache[hooks_tuple]
|
||||||
|
|
||||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
|
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
||||||
|
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||||
c = []
|
c = []
|
||||||
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
|
if cache is None:
|
||||||
|
cache = {}
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
for k in values:
|
for k in values:
|
||||||
if append_hooks and k == 'hooks':
|
if append_hooks and k == 'hooks':
|
||||||
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
|
_combine_hooks_from_values(n[1], values, cache)
|
||||||
else:
|
else:
|
||||||
n[1][k] = values[k]
|
n[1][k] = values[k]
|
||||||
c.append(n)
|
c.append(n)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
|
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
||||||
if hooks is None:
|
if hooks is None:
|
||||||
return cond
|
return cond
|
||||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
|
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
||||||
|
|
||||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
||||||
if timestep_range is None:
|
if timestep_range is None:
|
||||||
@ -664,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
|
|||||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
final_conds = []
|
final_conds = []
|
||||||
|
cache = {}
|
||||||
for c in conds:
|
for c in conds:
|
||||||
# first, apply lora_hook to conditioning, if provided
|
# first, apply lora_hook to conditioning, if provided
|
||||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
|
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, apply mask to conditioning
|
# next, apply mask to conditioning
|
||||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
@ -678,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
|||||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
combined_conds = []
|
combined_conds = []
|
||||||
|
cache = {}
|
||||||
for c, masked_c in zip(conds, new_conds):
|
for c, masked_c in zip(conds, new_conds):
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
|
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, apply mask to new conditioning, if provided
|
# next, apply mask to new conditioning, if provided
|
||||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
@ -692,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
|
|||||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
def set_default_conds_and_combine(conds: list, new_conds: list,
|
||||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
||||||
combined_conds = []
|
combined_conds = []
|
||||||
|
cache = {}
|
||||||
for c, new_c in zip(conds, new_conds):
|
for c, new_c in zip(conds, new_conds):
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
# first, apply lora_hook to new conditioning, if provided
|
||||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
|
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
||||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
# next, add default_cond key to cond so that during sampling, it can be identified
|
||||||
new_c = conditioning_set_values(new_c, {'default': True})
|
new_c = conditioning_set_values(new_c, {'default': True})
|
||||||
# apply timesteps, if present
|
# apply timesteps, if present
|
||||||
|
@ -210,7 +210,7 @@ class ModelPatcher:
|
|||||||
self.injections: dict[str, list[PatcherInjection]] = {}
|
self.injections: dict[str, list[PatcherInjection]] = {}
|
||||||
|
|
||||||
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
self.hook_patches: dict[comfy.hooks._HookRef] = {}
|
||||||
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
|
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
|
||||||
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
|
||||||
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
|
||||||
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
|
||||||
@ -282,7 +282,7 @@ class ModelPatcher:
|
|||||||
n.injections[k] = i.copy()
|
n.injections[k] = i.copy()
|
||||||
# hooks
|
# hooks
|
||||||
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
n.hook_patches = create_hook_patches_clone(self.hook_patches)
|
||||||
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
|
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
|
||||||
for group in self.cached_hook_patches:
|
for group in self.cached_hook_patches:
|
||||||
n.cached_hook_patches[group] = {}
|
n.cached_hook_patches[group] = {}
|
||||||
for k in self.cached_hook_patches[group]:
|
for k in self.cached_hook_patches[group]:
|
||||||
@ -855,6 +855,9 @@ class ModelPatcher:
|
|||||||
if key in self.injections:
|
if key in self.injections:
|
||||||
self.injections.pop(key)
|
self.injections.pop(key)
|
||||||
|
|
||||||
|
def get_injections(self, key: str):
|
||||||
|
return self.injections.get(key, None)
|
||||||
|
|
||||||
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
||||||
self.additional_models[key] = models
|
self.additional_models[key] = models
|
||||||
|
|
||||||
@ -925,9 +928,9 @@ class ModelPatcher:
|
|||||||
callback(self, timestep)
|
callback(self, timestep)
|
||||||
|
|
||||||
def restore_hook_patches(self):
|
def restore_hook_patches(self):
|
||||||
if len(self.hook_patches_backup) > 0:
|
if self.hook_patches_backup is not None:
|
||||||
self.hook_patches = self.hook_patches_backup
|
self.hook_patches = self.hook_patches_backup
|
||||||
self.hook_patches_backup = {}
|
self.hook_patches_backup = None
|
||||||
|
|
||||||
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||||
self.hook_mode = hook_mode
|
self.hook_mode = hook_mode
|
||||||
@ -953,25 +956,26 @@ class ModelPatcher:
|
|||||||
if reset_current_hooks:
|
if reset_current_hooks:
|
||||||
self.patch_hooks(None)
|
self.patch_hooks(None)
|
||||||
|
|
||||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||||
|
registered: comfy.hooks.HookGroup = None):
|
||||||
self.restore_hook_patches()
|
self.restore_hook_patches()
|
||||||
registered_hooks: list[comfy.hooks.Hook] = []
|
if registered is None:
|
||||||
# handle WrapperHooks, if model_options provided
|
registered = comfy.hooks.HookGroup()
|
||||||
if model_options is not None:
|
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
|
|
||||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
|
||||||
# handle WeightHooks
|
# handle WeightHooks
|
||||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
||||||
if hook.hook_ref not in self.hook_patches:
|
if hook.hook_ref not in self.hook_patches:
|
||||||
weight_hooks_to_register.append(hook)
|
weight_hooks_to_register.append(hook)
|
||||||
|
else:
|
||||||
|
registered.add(hook)
|
||||||
if len(weight_hooks_to_register) > 0:
|
if len(weight_hooks_to_register) > 0:
|
||||||
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
|
||||||
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
|
||||||
for hook in weight_hooks_to_register:
|
for hook in weight_hooks_to_register:
|
||||||
hook.add_hook_patches(self, model_options, target, registered_hooks)
|
hook.add_hook_patches(self, model_options, target_dict, registered)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||||
callback(self, hooks_dict, target)
|
callback(self, hooks, target_dict, model_options, registered)
|
||||||
|
return registered
|
||||||
|
|
||||||
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
@ -1022,11 +1026,11 @@ class ModelPatcher:
|
|||||||
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
|
||||||
# TODO: return transformer_options dict with any additions from hooks
|
# TODO: return transformer_options dict with any additions from hooks
|
||||||
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
|
||||||
return {}
|
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||||
self.patch_hooks(hooks=hooks)
|
self.patch_hooks(hooks=hooks)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
|
||||||
callback(self, hooks)
|
callback(self, hooks)
|
||||||
return {}
|
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
||||||
|
|
||||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
|
@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
|
|||||||
models += [c[model_type]]
|
models += [c[model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
|
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
|
||||||
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
|
||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
if 'hooks' in c:
|
if 'hooks' in c:
|
||||||
for hook in c['hooks'].hooks:
|
for hook in c['hooks'].hooks:
|
||||||
hook: comfy.hooks.Hook
|
full_hooks.add(hook)
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
|
||||||
with_type[hook] = None
|
|
||||||
if 'control' in c:
|
if 'control' in c:
|
||||||
cnets.append(c['control'])
|
cnets.append(c['control'])
|
||||||
|
|
||||||
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
|
|||||||
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
|
||||||
if extra_hooks is not None:
|
if extra_hooks is not None:
|
||||||
for hook in extra_hooks.hooks:
|
for hook in extra_hooks.hooks:
|
||||||
with_type = hooks_dict.setdefault(hook.hook_type, {})
|
full_hooks.add(hook)
|
||||||
with_type[hook] = None
|
|
||||||
|
|
||||||
return hooks_dict
|
return full_hooks
|
||||||
|
|
||||||
def convert_cond(cond):
|
def convert_cond(cond):
|
||||||
out = []
|
out = []
|
||||||
@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
|
|||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
gligen = []
|
gligen = []
|
||||||
add_models = []
|
add_models = []
|
||||||
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
|
|
||||||
|
|
||||||
for k in conds:
|
for k in conds:
|
||||||
cnets += get_models_from_cond(conds[k], "control")
|
cnets += get_models_from_cond(conds[k], "control")
|
||||||
gligen += get_models_from_cond(conds[k], "gligen")
|
gligen += get_models_from_cond(conds[k], "gligen")
|
||||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||||
get_hooks_from_cond(conds[k], hooks)
|
|
||||||
|
|
||||||
control_nets = set(cnets)
|
control_nets = set(cnets)
|
||||||
|
|
||||||
@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
|
|||||||
inference_memory += m.inference_memory_requirements(dtype)
|
inference_memory += m.inference_memory_requirements(dtype)
|
||||||
|
|
||||||
gligen = [x[1] for x in gligen]
|
gligen = [x[1] for x in gligen]
|
||||||
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
|
models = control_models + gligen + add_models
|
||||||
models = control_models + gligen + add_models + hook_models
|
|
||||||
|
|
||||||
return models, inference_memory
|
return models, inference_memory
|
||||||
|
|
||||||
|
def get_additional_models_from_model_options(model_options: dict[str]=None):
|
||||||
|
"""loads additional models from registered AddModels hooks"""
|
||||||
|
models = []
|
||||||
|
if model_options is not None and "registered_hooks" in model_options:
|
||||||
|
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
|
||||||
|
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||||
|
hook: comfy.hooks.AdditionalModelsHook
|
||||||
|
models.extend(hook.models)
|
||||||
|
return models
|
||||||
|
|
||||||
def cleanup_additional_models(models):
|
def cleanup_additional_models(models):
|
||||||
"""cleanup additional models that were loaded"""
|
"""cleanup additional models that were loaded"""
|
||||||
for m in models:
|
for m in models:
|
||||||
@ -102,9 +106,10 @@ def cleanup_additional_models(models):
|
|||||||
m.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||||
real_model: 'BaseModel' = None
|
real_model: BaseModel = None
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||||
@ -123,12 +128,35 @@ def cleanup_models(conds, models):
|
|||||||
cleanup_additional_models(set(control_cleanup))
|
cleanup_additional_models(set(control_cleanup))
|
||||||
|
|
||||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||||
|
'''
|
||||||
|
Registers hooks from conds.
|
||||||
|
'''
|
||||||
# check for hooks in conds - if not registered, see if can be applied
|
# check for hooks in conds - if not registered, see if can be applied
|
||||||
hooks = {}
|
hooks = comfy.hooks.HookGroup()
|
||||||
for k in conds:
|
for k in conds:
|
||||||
get_hooks_from_cond(conds[k], hooks)
|
get_hooks_from_cond(conds[k], hooks)
|
||||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||||
# register hooks on model/model_options
|
# begin registering hooks
|
||||||
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
|
registered = comfy.hooks.HookGroup()
|
||||||
|
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||||
|
# handle all TransformerOptionsHooks
|
||||||
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
|
||||||
|
hook: comfy.hooks.TransformerOptionsHook
|
||||||
|
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||||
|
# handle all AddModelsHooks
|
||||||
|
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
|
||||||
|
hook: comfy.hooks.AdditionalModelsHook
|
||||||
|
hook.add_hook_patches(model, model_options, target_dict, registered)
|
||||||
|
# handle all WeightHooks by registering on ModelPatcher
|
||||||
|
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
|
||||||
|
# add registered_hooks onto model_options for further reference
|
||||||
|
if len(registered) > 0:
|
||||||
|
model_options["registered_hooks"] = registered
|
||||||
|
# merge original wrappers and callbacks with hooked wrappers and callbacks
|
||||||
|
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
|
||||||
|
for wc_name in ["wrappers", "callbacks"]:
|
||||||
|
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||||
|
copy_dict1=False)
|
||||||
|
return to_load_options
|
||||||
|
@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
|||||||
for cond in conds_to_modify:
|
for cond in conds_to_modify:
|
||||||
cond['hooks'] = hooks
|
cond['hooks'] = hooks
|
||||||
|
|
||||||
|
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
|
||||||
|
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
|
||||||
|
HookGroups that have the same reference.'''
|
||||||
|
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
|
||||||
|
# if None were registered, make sure all hooks are cleaned from conds
|
||||||
|
if registered is None:
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
kk.pop('hooks', None)
|
||||||
|
return
|
||||||
|
# find conds that contain hooks to be replaced - group by common HookGroup refs
|
||||||
|
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
|
||||||
|
if hooks is not None:
|
||||||
|
if not hooks.is_subset_of(registered):
|
||||||
|
to_replace = hook_replacement.setdefault(hooks, [])
|
||||||
|
to_replace.append(kk)
|
||||||
|
# for each hook to replace, create a new proper HookGroup and assign to all common conds
|
||||||
|
for hooks, conds_to_modify in hook_replacement.items():
|
||||||
|
new_hooks = hooks.new_with_common_hooks(registered)
|
||||||
|
if len(new_hooks) == 0:
|
||||||
|
new_hooks = None
|
||||||
|
for kk in conds_to_modify:
|
||||||
|
kk['hooks'] = new_hooks
|
||||||
|
|
||||||
|
|
||||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||||
hooks_set = set()
|
hooks_set = set()
|
||||||
@ -819,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
|||||||
return len(hooks_set)
|
return len(hooks_set)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||||
|
'''
|
||||||
|
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
|
||||||
|
'''
|
||||||
|
if model_options is None:
|
||||||
|
return
|
||||||
|
to_load_options = model_options.get("to_load_options", None)
|
||||||
|
if to_load_options is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
casts = []
|
||||||
|
if device is not None:
|
||||||
|
casts.append(device)
|
||||||
|
if dtype is not None:
|
||||||
|
casts.append(dtype)
|
||||||
|
# if nothing to apply, do nothing
|
||||||
|
if len(casts) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# try to call .to on patches
|
||||||
|
if "patches" in to_load_options:
|
||||||
|
patches = to_load_options["patches"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for i in range(len(patch_list)):
|
||||||
|
if hasattr(patch_list[i], "to"):
|
||||||
|
for cast in casts:
|
||||||
|
patch_list[i] = patch_list[i].to(cast)
|
||||||
|
if "patches_replace" in to_load_options:
|
||||||
|
patches = to_load_options["patches_replace"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for k in patch_list:
|
||||||
|
if hasattr(patch_list[k], "to"):
|
||||||
|
for cast in casts:
|
||||||
|
patch_list[k] = patch_list[k].to(cast)
|
||||||
|
# try to call .to on any wrappers/callbacks
|
||||||
|
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||||
|
for wc_name in wrappers_and_callbacks:
|
||||||
|
if wc_name in to_load_options:
|
||||||
|
wc: dict[str, list] = to_load_options[wc_name]
|
||||||
|
for wc_dict in wc.values():
|
||||||
|
for wc_list in wc_dict.values():
|
||||||
|
for i in range(len(wc_list)):
|
||||||
|
if hasattr(wc_list[i], "to"):
|
||||||
|
for cast in casts:
|
||||||
|
wc_list[i] = wc_list[i].to(cast)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher):
|
def __init__(self, model_patcher: ModelPatcher):
|
||||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
self.model_patcher = model_patcher
|
||||||
self.model_options = model_patcher.model_options
|
self.model_options = model_patcher.model_options
|
||||||
self.original_conds = {}
|
self.original_conds = {}
|
||||||
self.cfg = 1.0
|
self.cfg = 1.0
|
||||||
@ -861,7 +937,7 @@ class CFGGuider:
|
|||||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
@ -870,6 +946,7 @@ class CFGGuider:
|
|||||||
noise = noise.to(device)
|
noise = noise.to(device)
|
||||||
latent_image = latent_image.to(device)
|
latent_image = latent_image.to(device)
|
||||||
sigmas = sigmas.to(device)
|
sigmas = sigmas.to(device)
|
||||||
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
@ -899,6 +976,7 @@ class CFGGuider:
|
|||||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||||
|
filter_registered_hooks_on_conds(self.conds, self.model_options)
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self.outer_sample,
|
self.outer_sample,
|
||||||
self,
|
self,
|
||||||
@ -906,6 +984,7 @@ class CFGGuider:
|
|||||||
)
|
)
|
||||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
finally:
|
finally:
|
||||||
|
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||||
self.model_options = orig_model_options
|
self.model_options = orig_model_options
|
||||||
self.model_patcher.hook_mode = orig_hook_mode
|
self.model_patcher.hook_mode = orig_hook_mode
|
||||||
self.model_patcher.restore_hook_patches()
|
self.model_patcher.restore_hook_patches()
|
||||||
|
@ -111,7 +111,7 @@ class CLIP:
|
|||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
self.use_clip_schedule = False
|
self.use_clip_schedule = False
|
||||||
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@ -898,7 +898,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
if output_model:
|
if output_model:
|
||||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
logging.info("loaded straight to GPU")
|
logging.info("loaded diffusion model directly to GPU")
|
||||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||||
|
|
||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
@ -246,7 +246,7 @@ class SetClipHooks:
|
|||||||
CATEGORY = "advanced/hooks/clip"
|
CATEGORY = "advanced/hooks/clip"
|
||||||
FUNCTION = "apply_hooks"
|
FUNCTION = "apply_hooks"
|
||||||
|
|
||||||
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
clip = clip.clone()
|
clip = clip.clone()
|
||||||
if apply_to_conds:
|
if apply_to_conds:
|
||||||
@ -255,7 +255,7 @@ class SetClipHooks:
|
|||||||
clip.use_clip_schedule = schedule_clip
|
clip.use_clip_schedule = schedule_clip
|
||||||
if not clip.use_clip_schedule:
|
if not clip.use_clip_schedule:
|
||||||
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
||||||
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
|
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
class ConditioningTimestepsRange:
|
class ConditioningTimestepsRange:
|
||||||
|
@ -4,7 +4,8 @@ lint.ignore = ["ALL"]
|
|||||||
# Enable specific rules
|
# Enable specific rules
|
||||||
lint.select = [
|
lint.select = [
|
||||||
"S307", # suspicious-eval-usage
|
"S307", # suspicious-eval-usage
|
||||||
"T201", # print-usage
|
"S102", # exec
|
||||||
|
"T", # print-usage
|
||||||
"W",
|
"W",
|
||||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||||
|
Loading…
Reference in New Issue
Block a user