Show the right amount of steps in the progress bar for uni_pc.

The extra step doesn't actually call the unet so it doesn't belong in
the progress bar.
This commit is contained in:
comfyanonymous 2023-02-11 14:59:42 -05:00
parent f10b8948c3
commit f542f248f1

View File

@ -712,7 +712,7 @@ class UniPC:
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps + 1):
for step_index in trange(steps):
if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
@ -728,29 +728,31 @@ class UniPC:
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
else:
step = step_index
# for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
# print('this step order:', step_order)
if step == steps:
# print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
extra_final_step = 0
if step_index == (steps - 1):
extra_final_step = 1
for step in range(step_index, step_index + 1 + extra_final_step):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
# print('this step order:', step_order)
if step == steps:
# print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
else:
raise NotImplementedError()
if denoise_to_zero: