Show the right amount of steps in the progress bar for uni_pc.

The extra step doesn't actually call the unet so it doesn't belong in
the progress bar.
This commit is contained in:
comfyanonymous 2023-02-11 14:59:42 -05:00
parent f10b8948c3
commit f542f248f1

View File

@ -712,7 +712,7 @@ class UniPC:
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# with torch.no_grad(): # with torch.no_grad():
for step_index in trange(steps + 1): for step_index in trange(steps):
if step_index == 0: if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
@ -728,29 +728,31 @@ class UniPC:
model_prev_list.append(model_x) model_prev_list.append(model_x)
t_prev_list.append(vec_t) t_prev_list.append(vec_t)
else: else:
step = step_index extra_final_step = 0
# for step in range(order, steps + 1): if step_index == (steps - 1):
vec_t = timesteps[step].expand(x.shape[0]) extra_final_step = 1
if lower_order_final: for step in range(step_index, step_index + 1 + extra_final_step):
step_order = min(order, steps + 1 - step) vec_t = timesteps[step].expand(x.shape[0])
else: if lower_order_final:
step_order = order step_order = min(order, steps + 1 - step)
# print('this step order:', step_order) else:
if step == steps: step_order = order
# print('do not run corrector at the last step') # print('this step order:', step_order)
use_corrector = False if step == steps:
else: # print('do not run corrector at the last step')
use_corrector = True use_corrector = False
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) else:
for i in range(order - 1): use_corrector = True
t_prev_list[i] = t_prev_list[i + 1] x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
model_prev_list[i] = model_prev_list[i + 1] for i in range(order - 1):
t_prev_list[-1] = vec_t t_prev_list[i] = t_prev_list[i + 1]
# We do not need to evaluate the final model value. model_prev_list[i] = model_prev_list[i + 1]
if step < steps: t_prev_list[-1] = vec_t
if model_x is None: # We do not need to evaluate the final model value.
model_x = self.model_fn(x, vec_t) if step < steps:
model_prev_list[-1] = model_x if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero: