mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Make a few more samplers deterministic.
This commit is contained in:
parent
d055325783
commit
4209edf48d
@ -475,7 +475,7 @@ class DPMSolver(nn.Module):
|
||||
return x_3, eps_cache
|
||||
|
||||
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||
if not t_end > t_start and eta:
|
||||
raise ValueError('eta must be 0 for reverse sampling')
|
||||
|
||||
@ -514,7 +514,7 @@ class DPMSolver(nn.Module):
|
||||
return x
|
||||
|
||||
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
|
||||
if order not in {2, 3}:
|
||||
raise ValueError('order should be 2 or 3')
|
||||
forward = t_end > t_start
|
||||
@ -894,7 +894,8 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
@ -914,7 +915,8 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
@torch.no_grad()
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
|
Loading…
Reference in New Issue
Block a user