Show the right amount of steps in the progress bar for uni_pc.

The extra step doesn't actually call the unet so it doesn't belong in
the progress bar.
This commit is contained in:
comfyanonymous 2023-02-11 14:59:42 -05:00
parent f10b8948c3
commit f542f248f1

View File

@ -712,7 +712,7 @@ class UniPC:
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps + 1):
for step_index in trange(steps):
if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
@ -728,8 +728,10 @@ class UniPC:
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
else:
step = step_index
# for step in range(order, steps + 1):
extra_final_step = 0
if step_index == (steps - 1):
extra_final_step = 1
for step in range(step_index, step_index + 1 + extra_final_step):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)