Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time

This commit is contained in:
Jedrzej Kosinski 2025-01-05 21:07:02 -06:00
parent db2d7ad9ba
commit 8270ff312f
4 changed files with 119 additions and 35 deletions

View File

@ -65,7 +65,7 @@ class _HookRef:
pass
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how should_register function should look like.'''
return True
@ -114,10 +114,10 @@ class Hook:
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
@ -154,7 +154,7 @@ class WeightHook(Hook):
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
@ -178,7 +178,7 @@ class WeightHook(Hook):
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
registered.add(self)
return True
# TODO: add logs about any keys that were not applied
@ -212,11 +212,12 @@ class AddModelsHook(Hook):
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, key: str=None, models: list[ModelPatcher]=None):
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.key = key
self.append_when_same = True
'''Curently does nothing.'''
def clone(self, subtype: Callable=None):
if subtype is None:
@ -227,9 +228,10 @@ class AddModelsHook(Hook):
c.append_when_same = self.append_when_same
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
registered.add(self)
return True
class TransformerOptionsHook(Hook):
@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook):
c.transformers_dict = self.transformers_dict
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.transformers_dict}
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
else:
add_model_options = {"to_load_options": self.transformers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
registered.add(self)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
@ -295,6 +300,9 @@ class HookGroup:
self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)

View File

@ -940,13 +940,11 @@ class ModelPatcher:
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None):
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle TransformerOptionsHooks, if model_options provided
if model_options is not None:
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
if registered is None:
registered = comfy.hooks.HookGroup()
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
@ -956,9 +954,10 @@ class ModelPatcher:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
hook.add_hook_patches(self, model_options, target_dict, registered)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks, target_dict)
callback(self, hooks, target_dict, model_options, registered)
return registered
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():

View File

@ -70,13 +70,11 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks = comfy.hooks.HookGroup()
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets)
@ -87,14 +85,20 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
hook_models = []
for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
x: comfy.hooks.AddModelsHook
hook_models.extend(x.models)
models = control_models + gligen + add_models + hook_models
models = control_models + gligen + add_models
return models, inference_memory
def get_additional_models_from_model_options(model_options: dict[str]=None):
"""loads additional models from registered AddModels hooks"""
models = []
if model_options is not None and "registered_hooks" in model_options:
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels):
hook: comfy.hooks.AddModelsHook
models.extend(hook.models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
@ -102,9 +106,10 @@ def cleanup_additional_models(models):
m.cleanup()
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
real_model: 'BaseModel' = None
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@ -130,5 +135,26 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
# handle all TransformerOptionsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook: comfy.hooks.TransformerOptionsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all AddModelsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
hook: comfy.hooks.AddModelsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all WeightHooks by registering on ModelPatcher
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
# add registered_hooks onto model_options for further reference
if len(registered) > 0:
model_options["registered_hooks"] = registered
# merge original wrappers and callbacks with hooked wrappers and callbacks
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
for wc_name in ["wrappers", "callbacks"]:
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options

View File

@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
return len(hooks_set)
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
'''
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
'''
if model_options is None:
return
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
casts = []
if device is not None:
casts.append(device)
if dtype is not None:
casts.append(dtype)
# if nothing to apply, do nothing
if len(casts) == 0:
return
# Try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# Try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
if hasattr(wc_list[i], "to"):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher: 'ModelPatcher' = model_patcher
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@ -861,7 +910,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
@ -870,6 +919,7 @@ class CFGGuider:
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
self.model_patcher.pre_run()
@ -906,6 +956,7 @@ class CFGGuider:
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()