Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' (#6273)

This commit is contained in:
Jedrzej Kosinski 2024-12-30 02:42:49 -06:00 committed by GitHub
parent 82ecb02c1e
commit 3507870535
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 10 deletions

View File

@ -366,6 +366,12 @@ class HookKeyframe:
self.start_t = 999999999.9 self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps self.guarantee_steps = guarantee_steps
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
if self.start_t > max_sigma:
return 0
return self.guarantee_steps
def clone(self): def clone(self):
c = HookKeyframe(strength=self.strength, c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
@ -408,6 +414,12 @@ class HookKeyframeGroup:
else: else:
self._current_keyframe = None self._current_keyframe = None
def has_guarantee_steps(self):
for kf in self.keyframes:
if kf.guarantee_steps > 0:
return True
return False
def has_index(self, index: int): def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes) return index >= 0 and index < len(self.keyframes)
@ -425,15 +437,16 @@ class HookKeyframeGroup:
for keyframe in self.keyframes: for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool: def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
if self.is_empty(): if self.is_empty():
return False return False
if curr_t == self._curr_t: if curr_t == self._curr_t:
return False return False
max_sigma = torch.max(transformer_options["sigmas"])
prev_index = self._current_index prev_index = self._current_index
prev_strength = self._current_strength prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch # if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps: if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
# if has next index, loop through and see if need to switch # if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1): if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)): for i in range(self._current_index+1, len(self.keyframes)):
@ -446,7 +459,7 @@ class HookKeyframeGroup:
self._current_keyframe = eval_c self._current_keyframe = eval_c
self._current_used_steps = 0 self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes # if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0: if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break break
# if eval_c is outside the percent range, stop looking further # if eval_c is outside the percent range, stop looking further
else: break else: break

View File

@ -919,11 +919,12 @@ class ModelPatcher:
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode self.hook_mode = hook_mode
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0] curr_t = t[0]
reset_current_hooks = False reset_current_hooks = False
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks: for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling # this will cause the weights to be recalculated when sampling
if changed: if changed:

View File

@ -144,7 +144,7 @@ def cond_cat(c_list):
return out return out
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
# need to figure out remaining unmasked area for conds # need to figure out remaining unmasked area for conds
default_mults = [] default_mults = []
for _ in default_conds: for _ in default_conds:
@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
# replace p's mult with calculated mult # replace p's mult with calculated mult
p = p._replace(mult=mult) p = p._replace(mult=mult)
if p.hooks is not None: if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list()) hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)] hooked_to_run[p.hooks] += [(p, i)]
@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
if p is None: if p is None:
continue continue
if p.hooks is not None: if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list()) hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)] hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c) default_conds.append(default_c)
@ -840,7 +840,9 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas
extra_args = {"model_options": extra_model_options, "seed": seed}
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
sampler.sample, sampler.sample,