mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Show the right amount of steps in the progress bar for uni_pc.
The extra step doesn't actually call the unet so it doesn't belong in the progress bar.
This commit is contained in:
parent
f10b8948c3
commit
f542f248f1
@ -712,7 +712,7 @@ class UniPC:
|
|||||||
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||||
assert timesteps.shape[0] - 1 == steps
|
assert timesteps.shape[0] - 1 == steps
|
||||||
# with torch.no_grad():
|
# with torch.no_grad():
|
||||||
for step_index in trange(steps + 1):
|
for step_index in trange(steps):
|
||||||
if step_index == 0:
|
if step_index == 0:
|
||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
@ -728,29 +728,31 @@ class UniPC:
|
|||||||
model_prev_list.append(model_x)
|
model_prev_list.append(model_x)
|
||||||
t_prev_list.append(vec_t)
|
t_prev_list.append(vec_t)
|
||||||
else:
|
else:
|
||||||
step = step_index
|
extra_final_step = 0
|
||||||
# for step in range(order, steps + 1):
|
if step_index == (steps - 1):
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
extra_final_step = 1
|
||||||
if lower_order_final:
|
for step in range(step_index, step_index + 1 + extra_final_step):
|
||||||
step_order = min(order, steps + 1 - step)
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
else:
|
if lower_order_final:
|
||||||
step_order = order
|
step_order = min(order, steps + 1 - step)
|
||||||
# print('this step order:', step_order)
|
else:
|
||||||
if step == steps:
|
step_order = order
|
||||||
# print('do not run corrector at the last step')
|
# print('this step order:', step_order)
|
||||||
use_corrector = False
|
if step == steps:
|
||||||
else:
|
# print('do not run corrector at the last step')
|
||||||
use_corrector = True
|
use_corrector = False
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
else:
|
||||||
for i in range(order - 1):
|
use_corrector = True
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
for i in range(order - 1):
|
||||||
t_prev_list[-1] = vec_t
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
# We do not need to evaluate the final model value.
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
if step < steps:
|
t_prev_list[-1] = vec_t
|
||||||
if model_x is None:
|
# We do not need to evaluate the final model value.
|
||||||
model_x = self.model_fn(x, vec_t)
|
if step < steps:
|
||||||
model_prev_list[-1] = model_x
|
if model_x is None:
|
||||||
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
model_prev_list[-1] = model_x
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
|
Loading…
Reference in New Issue
Block a user