Filter only registered hooks on self.conds in CFGGuider.sample

This commit is contained in:
Jedrzej Kosinski 2025-01-06 01:04:29 -06:00
parent 03a97b604a
commit 0a7e2ae787
2 changed files with 31 additions and 0 deletions

View File

@ -128,6 +128,9 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup)) cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
'''
Registers hooks from conds.
'''
# check for hooks in conds - if not registered, see if can be applied # check for hooks in conds - if not registered, see if can be applied
hooks = comfy.hooks.HookGroup() hooks = comfy.hooks.HookGroup()
for k in conds: for k in conds:

View File

@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
for cond in conds_to_modify: for cond in conds_to_modify:
cond['hooks'] = hooks cond['hooks'] = hooks
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
HookGroups that have the same reference.'''
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
# if None were registered, make sure all hooks are cleaned from conds
if registered is None:
for k in conds:
for kk in conds[k]:
kk.pop('hooks', None)
return
# find conds that contain hooks to be replaced - group by common HookGroup refs
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
for k in conds:
for kk in conds[k]:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
if hooks is not None:
if not hooks.is_subset_of(registered):
to_replace = hook_replacement.setdefault(hooks, [])
to_replace.append(kk)
# for each hook to replace, create a new proper HookGroup and assign to all common conds
for hooks, conds_to_modify in hook_replacement.items():
new_hooks = hooks.new_with_common_hooks(registered)
if len(new_hooks) == 0:
new_hooks = None
for kk in conds_to_modify:
kk['hooks'] = new_hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set() hooks_set = set()
@ -949,6 +976,7 @@ class CFGGuider:
if get_total_hook_groups_in_conds(self.conds) <= 1: if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
filter_registered_hooks_on_conds(self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample, self.outer_sample,
self, self,