From 4209edf48dcb72ff73580b7fb19b72b9b8ebba01 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Jan 2025 02:12:12 -0500 Subject: [PATCH] Make a few more samplers deterministic. --- comfy/k_diffusion/sampling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index d37d7dd9..413ad5aa 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -475,7 +475,7 @@ class DPMSolver(nn.Module): return x_3, eps_cache def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler if not t_end > t_start and eta: raise ValueError('eta must be 0 for reverse sampling') @@ -514,7 +514,7 @@ class DPMSolver(nn.Module): return x def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler if order not in {2, 3}: raise ValueError('order should be 2 or 3') forward = t_end > t_start @@ -894,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): @@ -914,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): extra_args = {} if extra_args is None else extra_args - noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args)