Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-11 02:15:17 +00:00)
Refactored 'registered' to be a HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with its should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, and properly supported patches in TransformerOptionsHook by casting any patches/wrappers/hooks to the proper device at sample time.
Parent: db2d7ad9ba
Commit: 8270ff312f
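The central change is the type of the `registered` argument threaded through hook registration: it is now a HookGroup rather than a plain list of Hooks, so re-registering the same hook is a no-op and callers can ask how many hooks actually registered. A minimal, self-contained sketch of that idea (a simplified stand-in, not the real comfy.hooks.HookGroup):

# Simplified stand-in illustrating why HookGroup replaces list[Hook]:
# add() deduplicates, and __len__() lets callers check whether anything registered.
class MiniHookGroup:
    def __init__(self):
        self.hooks = []

    def __len__(self):
        return len(self.hooks)

    def add(self, hook):
        # unlike list.append, adding the same hook twice is a no-op
        if hook not in self.hooks:
            self.hooks.append(hook)

group = MiniHookGroup()
hook = object()
group.add(hook)
group.add(hook)   # ignored; the hook is already in the group
assert len(group) == 1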
@@ -65,7 +65,7 @@ class _HookRef:
     pass
 
 
-def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
+def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
     '''Example for how should_register function should look like.'''
     return True
 
@@ -114,10 +114,10 @@ class Hook:
         c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
         return c
 
-    def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
+    def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         return self.custom_should_register(self, model, model_options, target_dict, registered)
 
-    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
 
     def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
@@ -154,7 +154,7 @@ class WeightHook(Hook):
     def strength_clip(self):
         return self._strength_clip * self.strength
 
-    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         if not self.should_register(model, model_options, target_dict, registered):
             return False
         weights = None
@@ -178,7 +178,7 @@ class WeightHook(Hook):
         else:
             weights = self.weights
         model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
-        registered.append(self)
+        registered.add(self)
         return True
         # TODO: add logs about any keys that were not applied
 
@@ -212,11 +212,12 @@ class AddModelsHook(Hook):
 
     Note, value of hook_scope is ignored and is treated as AllConditioning.
     '''
-    def __init__(self, key: str=None, models: list[ModelPatcher]=None):
+    def __init__(self, models: list[ModelPatcher]=None, key: str=None):
         super().__init__(hook_type=EnumHookType.AddModels)
-        self.key = key
         self.models = models
+        self.key = key
         self.append_when_same = True
+        '''Curently does nothing.'''
 
     def clone(self, subtype: Callable=None):
         if subtype is None:
@@ -227,9 +228,10 @@ class AddModelsHook(Hook):
         c.append_when_same = self.append_when_same
         return c
 
-    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         if not self.should_register(model, model_options, target_dict, registered):
             return False
+        registered.add(self)
         return True
 
 class TransformerOptionsHook(Hook):
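With AddModelsHook now adding itself to `registered`, the models it carries can be picked up later at sampling time via the registered-hooks lookup added in comfy.sampler_helpers (see the hunks further down). A hedged usage sketch, assuming a ComfyUI environment; `extra_patcher` is a hypothetical placeholder for whatever ModelPatcher should be loaded alongside the main model:

import comfy.hooks

# extra_patcher is a placeholder ModelPatcher (e.g. an auxiliary model), not defined here
hook = comfy.hooks.AddModelsHook(models=[extra_patcher])
group = comfy.hooks.HookGroup()
group.add(hook)
# once registered during sampling setup, get_additional_models_from_model_options()
# gathers hook.models from model_options["registered_hooks"]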
@@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook):
         c.transformers_dict = self.transformers_dict
         return c
 
-    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
+    def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
         if not self.should_register(model, model_options, target_dict, registered):
             return False
-        add_model_options = {"transformer_options": self.transformers_dict}
-        # TODO: call .to on patches/anything else in transformer_options that is expected to do something
+        # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
         if self.hook_scope == EnumHookScope.AllConditioning:
+            add_model_options = {"transformer_options": self.transformers_dict,
+                                 "to_load_options": self.transformers_dict}
+        else:
+            add_model_options = {"to_load_options": self.transformers_dict}
         comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
-        registered.append(self)
+        registered.add(self)
         return True
 
     def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
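The reworked TransformerOptionsHook.add_hook_patches stages its transformers_dict under "to_load_options" (and additionally under "transformer_options" when the scope is AllConditioning), which is what later lets its contents be cast to the right device. A plain-dict sketch of the two resulting shapes (illustrative only; the real code merges these into model_options with comfy.patcher_extension.merge_nested_dicts, and the inner nesting holds real wrapper/callback entries rather than strings):

# strings stand in for real wrapper/callback entries; the exact nesting is ComfyUI-specific
transformers_dict = {"wrappers": {"unet_wrapper": ["my_wrapper"]}}

# scope == AllConditioning: applied globally AND staged for manual loading/casting
all_cond_options = {"transformer_options": transformers_dict,
                    "to_load_options": transformers_dict}

# any other scope: only staged for manual loading/casting
other_scope_options = {"to_load_options": transformers_dict}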
@@ -295,6 +300,9 @@ class HookGroup:
         self.hooks: list[Hook] = []
         self._hook_dict: dict[EnumHookType, list[Hook]] = {}
 
+    def __len__(self):
+        return len(self.hooks)
+
     def add(self, hook: Hook):
         if hook not in self.hooks:
             self.hooks.append(hook)
@@ -940,13 +940,11 @@ class ModelPatcher:
         if reset_current_hooks:
             self.patch_hooks(None)
 
-    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None):
+    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
+                                  registered: comfy.hooks.HookGroup = None):
         self.restore_hook_patches()
-        registered_hooks: list[comfy.hooks.Hook] = []
-        # handle TransformerOptionsHooks, if model_options provided
-        if model_options is not None:
-            for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
-                hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
+        if registered is None:
+            registered = comfy.hooks.HookGroup()
         # handle WeightHooks
         weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
         for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
@@ -956,9 +954,10 @@ class ModelPatcher:
             # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
             self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
             for hook in weight_hooks_to_register:
-                hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
+                hook.add_hook_patches(self, model_options, target_dict, registered)
         for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
-            callback(self, hooks, target_dict)
+            callback(self, hooks, target_dict, model_options, registered)
+        return registered
 
     def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
         with self.use_ejected():
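ModelPatcher.register_all_hook_patches no longer walks TransformerOptions hooks itself; the caller can instead pass in (and get back) a shared `registered` HookGroup that keeps accumulating across hook types. A hedged call sketch, assuming a ComfyUI environment with a ModelPatcher `patcher` and a HookGroup `hooks` already in hand (both placeholders here):

import comfy.hooks

registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
# WeightHooks get registered on the patcher; the same group keeps accumulating hooks
registered = patcher.register_all_hook_patches(hooks, target_dict, model_options={}, registered=registered)
print(f"{len(registered)} hook(s) registered")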
@@ -70,13 +70,11 @@ def get_additional_models(conds, dtype):
     cnets: list[ControlBase] = []
     gligen = []
     add_models = []
-    hooks = comfy.hooks.HookGroup()
 
     for k in conds:
         cnets += get_models_from_cond(conds[k], "control")
         gligen += get_models_from_cond(conds[k], "gligen")
         add_models += get_models_from_cond(conds[k], "additional_models")
-        get_hooks_from_cond(conds[k], hooks)
 
     control_nets = set(cnets)
 
@@ -87,14 +85,20 @@ def get_additional_models(conds, dtype):
         inference_memory += m.inference_memory_requirements(dtype)
 
     gligen = [x[1] for x in gligen]
-    hook_models = []
-    for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
-        x: comfy.hooks.AddModelsHook
-        hook_models.extend(x.models)
-    models = control_models + gligen + add_models + hook_models
+    models = control_models + gligen + add_models
 
     return models, inference_memory
 
+def get_additional_models_from_model_options(model_options: dict[str]=None):
+    """loads additional models from registered AddModels hooks"""
+    models = []
+    if model_options is not None and "registered_hooks" in model_options:
+        registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
+        for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels):
+            hook: comfy.hooks.AddModelsHook
+            models.extend(hook.models)
+    return models
+
 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
     for m in models:
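The new helper only inspects model_options["registered_hooks"] and pulls .models off any registered AddModels hooks. A runnable stand-in showing the collection pattern (fake classes instead of the real comfy.hooks types, and a plain list instead of HookGroup.get_type):

class FakeAddModelsHook:
    def __init__(self, models):
        self.models = models

def collect_additional_models(model_options):
    # mirrors the logic: no options or no registered hooks means no extra models
    models = []
    if model_options is not None and "registered_hooks" in model_options:
        for hook in model_options["registered_hooks"]:
            models.extend(hook.models)
    return models

opts = {"registered_hooks": [FakeAddModelsHook(["aux_model_a", "aux_model_b"])]}
assert collect_additional_models(opts) == ["aux_model_a", "aux_model_b"]
assert collect_additional_models(None) == []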
@@ -102,9 +106,10 @@ def cleanup_additional_models(models):
         m.cleanup()
 
 
-def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
-    real_model: 'BaseModel' = None
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+    real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
+    models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
     memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
     minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@@ -130,5 +135,26 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
     # add wrappers and callbacks from ModelPatcher to transformer_options
     model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
     model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
-    # register hooks on model/model_options
-    model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options)
+    # begin registering hooks
+    registered = comfy.hooks.HookGroup()
+    target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
+    # handle all TransformerOptionsHooks
+    for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
+        hook: comfy.hooks.TransformerOptionsHook
+        hook.add_hook_patches(model, model_options, target_dict, registered)
+    # handle all AddModelsHooks
+    for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
+        hook: comfy.hooks.AddModelsHook
+        hook.add_hook_patches(model, model_options, target_dict, registered)
+    # handle all WeightHooks by registering on ModelPatcher
+    model.register_all_hook_patches(hooks, target_dict, model_options, registered)
+    # add registered_hooks onto model_options for further reference
+    if len(registered) > 0:
+        model_options["registered_hooks"] = registered
+    # merge original wrappers and callbacks with hooked wrappers and callbacks
+    to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
+    for wc_name in ["wrappers", "callbacks"]:
+        comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
+                                                   copy_dict1=False)
+    return to_load_options
+
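For orientation, this is roughly the shape model_options ends up with after prepare_model_patcher, inferred from the hunk above (values below are placeholders; the real entries are nested dicts of wrappers/callbacks/patches and a HookGroup):

model_options = {
    "transformer_options": {
        "wrappers": {},    # copied from the ModelPatcher, then extended by hooks
        "callbacks": {},
    },
    "to_load_options": {
        "wrappers": {},    # merged copy used for manual loading / device casting
        "callbacks": {},
        # TransformerOptionsHooks may also contribute entries such as "patches"/"patches_replace"
    },
    "registered_hooks": None,  # a comfy.hooks.HookGroup, set only when at least one hook registered
}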
@@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
     return len(hooks_set)
 
 
+def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
+    '''
+    If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
+    '''
+    if model_options is None:
+        return
+    to_load_options = model_options.get("to_load_options", None)
+    if to_load_options is None:
+        return
+
+    casts = []
+    if device is not None:
+        casts.append(device)
+    if dtype is not None:
+        casts.append(dtype)
+    # if nothing to apply, do nothing
+    if len(casts) == 0:
+        return
+
+    # Try to call .to on patches
+    if "patches" in to_load_options:
+        patches = to_load_options["patches"]
+        for name in patches:
+            patch_list = patches[name]
+            for i in range(len(patch_list)):
+                if hasattr(patch_list[i], "to"):
+                    for cast in casts:
+                        patch_list[i] = patch_list[i].to(cast)
+    if "patches_replace" in to_load_options:
+        patches = to_load_options["patches_replace"]
+        for name in patches:
+            patch_list = patches[name]
+            for k in patch_list:
+                if hasattr(patch_list[k], "to"):
+                    for cast in casts:
+                        patch_list[k] = patch_list[k].to(cast)
+    # Try to call .to on any wrappers/callbacks
+    wrappers_and_callbacks = ["wrappers", "callbacks"]
+    for wc_name in wrappers_and_callbacks:
+        if wc_name in to_load_options:
+            wc: dict[str, list] = to_load_options[wc_name]
+            for wc_dict in wc.values():
+                for wc_list in wc_dict.values():
+                    for i in range(len(wc_list)):
+                        if hasattr(wc_list[i], "to"):
+                            for cast in casts:
+                                wc_list[i] = wc_list[i].to(cast)
+
+
 class CFGGuider:
-    def __init__(self, model_patcher):
-        self.model_patcher: 'ModelPatcher' = model_patcher
+    def __init__(self, model_patcher: ModelPatcher):
+        self.model_patcher = model_patcher
         self.model_options = model_patcher.model_options
         self.original_conds = {}
         self.cfg = 1.0
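The new cast_to_load_options relies only on duck typing: anything in to_load_options that exposes a torch-style .to gets moved. A small runnable toy of that walk (FakeTensor stands in for a tensor, wrapper, or callback object with a .to method):

class FakeTensor:
    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # returns a new object on the target device, like torch.Tensor.to
        return FakeTensor(device)

to_load_options = {
    "patches": {"attn_patch": [FakeTensor()]},            # name -> list
    "wrappers": {"unet_wrapper": {"k": [FakeTensor()]}},  # name -> dict -> list
}

device = "cuda:0"

# patches: one level of nesting
for patch_list in to_load_options["patches"].values():
    for i, p in enumerate(patch_list):
        if hasattr(p, "to"):
            patch_list[i] = p.to(device)

# wrappers/callbacks: two levels of nesting
for wc_dict in to_load_options["wrappers"].values():
    for wc_list in wc_dict.values():
        for i, w in enumerate(wc_list):
            if hasattr(w, "to"):
                wc_list[i] = w.to(device)

assert to_load_options["patches"]["attn_patch"][0].device == "cuda:0"
assert to_load_options["wrappers"]["unet_wrapper"]["k"][0].device == "cuda:0"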
@@ -861,7 +910,7 @@ class CFGGuider:
         return self.inner_model.process_latent_out(samples.to(torch.float32))
 
     def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
+        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device
 
         if denoise_mask is not None:
@@ -870,6 +919,7 @@ class CFGGuider:
         noise = noise.to(device)
         latent_image = latent_image.to(device)
         sigmas = sigmas.to(device)
+        cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
 
         try:
             self.model_patcher.pre_run()
@@ -906,6 +956,7 @@ class CFGGuider:
             )
             output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
         finally:
+            cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
             self.model_options = orig_model_options
             self.model_patcher.hook_mode = orig_hook_mode
             self.model_patcher.restore_hook_patches()