From 3507870535f9049811fec428088053d304b2319e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Dec 2024 02:42:49 -0600 Subject: [PATCH] Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' (#6273) --- comfy/hooks.py | 21 +++++++++++++++++---- comfy/model_patcher.py | 5 +++-- comfy/samplers.py | 10 ++++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cf33598a..79a7090b 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -366,9 +366,15 @@ class HookKeyframe: self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): c = HookKeyframe(strength=self.strength, - start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) c.start_t = self.start_t return c @@ -408,6 +414,12 @@ class HookKeyframeGroup: else: self._current_keyframe = None + def has_guarantee_steps(self): + for kf in self.keyframes: + if kf.guarantee_steps > 0: + return True + return False + def has_index(self, index: int): return index >= 0 and index < len(self.keyframes) @@ -425,15 +437,16 @@ class HookKeyframeGroup: for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, curr_t: float) -> bool: + def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool: if self.is_empty(): return False if curr_t == self._curr_t: return False + max_sigma = torch.max(transformer_options["sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: + if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need to switch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.keyframes)): @@ -446,7 +459,7 @@ class HookKeyframeGroup: self._current_keyframe = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: + if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: break diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d89d9a6a..4597ce11 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -919,11 +919,12 @@ class ModelPatcher: def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode - def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: - changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) + changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: diff --git a/comfy/samplers.py b/comfy/samplers.py index 27686722..6a386511 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -144,7 +144,7 @@ def cond_cat(c_list): return out -def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): +def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options): # need to figure out remaining unmasked area for conds default_mults = [] for _ in default_conds: @@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H # replace p's mult with calculated mult p = p._replace(mult=mult) if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] @@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te if p is None: continue if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] default_conds.append(default_c) @@ -840,7 +840,9 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} + extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) + extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( sampler.sample,