mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas'
This commit is contained in:
parent
a618f768e0
commit
72bbf49349
@ -366,9 +366,15 @@ class HookKeyframe:
|
|||||||
self.start_t = 999999999.9
|
self.start_t = 999999999.9
|
||||||
self.guarantee_steps = guarantee_steps
|
self.guarantee_steps = guarantee_steps
|
||||||
|
|
||||||
|
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
|
||||||
|
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
|
||||||
|
if self.start_t > max_sigma:
|
||||||
|
return 0
|
||||||
|
return self.guarantee_steps
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookKeyframe(strength=self.strength,
|
c = HookKeyframe(strength=self.strength,
|
||||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||||
c.start_t = self.start_t
|
c.start_t = self.start_t
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@ -408,6 +414,12 @@ class HookKeyframeGroup:
|
|||||||
else:
|
else:
|
||||||
self._current_keyframe = None
|
self._current_keyframe = None
|
||||||
|
|
||||||
|
def has_guarantee_steps(self):
|
||||||
|
for kf in self.keyframes:
|
||||||
|
if kf.guarantee_steps > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def has_index(self, index: int):
|
def has_index(self, index: int):
|
||||||
return index >= 0 and index < len(self.keyframes)
|
return index >= 0 and index < len(self.keyframes)
|
||||||
|
|
||||||
@ -425,15 +437,16 @@ class HookKeyframeGroup:
|
|||||||
for keyframe in self.keyframes:
|
for keyframe in self.keyframes:
|
||||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||||
|
|
||||||
def prepare_current_keyframe(self, curr_t: float) -> bool:
|
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
|
||||||
if self.is_empty():
|
if self.is_empty():
|
||||||
return False
|
return False
|
||||||
if curr_t == self._curr_t:
|
if curr_t == self._curr_t:
|
||||||
return False
|
return False
|
||||||
|
max_sigma = torch.max(transformer_options["sigmas"])
|
||||||
prev_index = self._current_index
|
prev_index = self._current_index
|
||||||
prev_strength = self._current_strength
|
prev_strength = self._current_strength
|
||||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
# if met guaranteed steps, look for next keyframe in case need to switch
|
||||||
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
|
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
|
||||||
# if has next index, loop through and see if need to switch
|
# if has next index, loop through and see if need to switch
|
||||||
if self.has_index(self._current_index+1):
|
if self.has_index(self._current_index+1):
|
||||||
for i in range(self._current_index+1, len(self.keyframes)):
|
for i in range(self._current_index+1, len(self.keyframes)):
|
||||||
@ -446,7 +459,7 @@ class HookKeyframeGroup:
|
|||||||
self._current_keyframe = eval_c
|
self._current_keyframe = eval_c
|
||||||
self._current_used_steps = 0
|
self._current_used_steps = 0
|
||||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
# if guarantee_steps greater than zero, stop searching for other keyframes
|
||||||
if self._current_keyframe.guarantee_steps > 0:
|
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||||
break
|
break
|
||||||
# if eval_c is outside the percent range, stop looking further
|
# if eval_c is outside the percent range, stop looking further
|
||||||
else: break
|
else: break
|
||||||
|
@ -919,11 +919,12 @@ class ModelPatcher:
|
|||||||
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||||
self.hook_mode = hook_mode
|
self.hook_mode = hook_mode
|
||||||
|
|
||||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||||
curr_t = t[0]
|
curr_t = t[0]
|
||||||
reset_current_hooks = False
|
reset_current_hooks = False
|
||||||
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
for hook in hook_group.hooks:
|
for hook in hook_group.hooks:
|
||||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||||
# this will cause the weights to be recalculated when sampling
|
# this will cause the weights to be recalculated when sampling
|
||||||
if changed:
|
if changed:
|
||||||
|
@ -144,7 +144,7 @@ def cond_cat(c_list):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
|
||||||
# need to figure out remaining unmasked area for conds
|
# need to figure out remaining unmasked area for conds
|
||||||
default_mults = []
|
default_mults = []
|
||||||
for _ in default_conds:
|
for _ in default_conds:
|
||||||
@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
|||||||
# replace p's mult with calculated mult
|
# replace p's mult with calculated mult
|
||||||
p = p._replace(mult=mult)
|
p = p._replace(mult=mult)
|
||||||
if p.hooks is not None:
|
if p.hooks is not None:
|
||||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||||
hooked_to_run.setdefault(p.hooks, list())
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
hooked_to_run[p.hooks] += [(p, i)]
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
|
||||||
@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
|||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
if p.hooks is not None:
|
if p.hooks is not None:
|
||||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||||
hooked_to_run.setdefault(p.hooks, list())
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
hooked_to_run[p.hooks] += [(p, i)]
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
default_conds.append(default_c)
|
default_conds.append(default_c)
|
||||||
@ -840,7 +840,9 @@ class CFGGuider:
|
|||||||
|
|
||||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||||
|
|
||||||
extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
|
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||||
|
extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas
|
||||||
|
extra_args = {"model_options": extra_model_options, "seed": seed}
|
||||||
|
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
sampler.sample,
|
sampler.sample,
|
||||||
|
Loading…
Reference in New Issue
Block a user