diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 42ee1c93..ae3a544e 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -821,7 +821,13 @@ def expand_dims(v, dims): def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None): - timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] + to_zero = False + if sigmas[-1] == 0: + timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] + to_zero = True + else: + timesteps = sigmas.clone() + for s in range(timesteps.shape[0]): timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas)) @@ -832,7 +838,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None else: img = noise - if sigmas[-1] == 0: + if to_zero: timesteps[-1] = (1 / len(model.sigmas)) device = noise.device @@ -853,4 +859,6 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) + if not to_zero: + x /= ns.marginal_alpha(timesteps[-1]) return x