diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index cc629912..42ee1c93 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -712,7 +712,7 @@ class UniPC: # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): - for step_index in trange(steps + 1): + for step_index in trange(steps): if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -728,29 +728,31 @@ class UniPC: model_prev_list.append(model_x) t_prev_list.append(vec_t) else: - step = step_index - # for step in range(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final: - step_order = min(order, steps + 1 - step) - else: - step_order = order - # print('this step order:', step_order) - if step == steps: - # print('do not run corrector at the last step') - use_corrector = False - else: - use_corrector = True - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: - if model_x is None: - model_x = self.model_fn(x, vec_t) - model_prev_list[-1] = model_x + extra_final_step = 0 + if step_index == (steps - 1): + extra_final_step = 1 + for step in range(step_index, step_index + 1 + extra_final_step): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + # print('this step order:', step_order) + if step == steps: + # print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list[-1] = model_x else: raise NotImplementedError() if denoise_to_zero: