mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 20:42:11 +08:00
Fix some edge cases with samplers and arrays with a single sigma.
This commit is contained in:
parent
94d5a12801
commit
f81a6fade8
@ -527,6 +527,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
"""DPM-Solver++ (stochastic)."""
|
"""DPM-Solver++ (stochastic)."""
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
@ -595,6 +598,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
"""DPM-Solver++(2M) SDE."""
|
"""DPM-Solver++(2M) SDE."""
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
if solver_type not in {'heun', 'midpoint'}:
|
if solver_type not in {'heun', 'midpoint'}:
|
||||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||||
@ -642,6 +647,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""DPM-Solver++(3M) SDE."""
|
"""DPM-Solver++(3M) SDE."""
|
||||||
|
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
@ -690,18 +698,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||||
|
@ -539,6 +539,9 @@ class KSAMPLER(Sampler):
|
|||||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||||
if sampler_name == "dpm_fast":
|
if sampler_name == "dpm_fast":
|
||||||
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return noise
|
||||||
|
|
||||||
sigma_min = sigmas[-1]
|
sigma_min = sigmas[-1]
|
||||||
if sigma_min == 0:
|
if sigma_min == 0:
|
||||||
sigma_min = sigmas[-2]
|
sigma_min = sigmas[-2]
|
||||||
@ -547,6 +550,9 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
|||||||
sampler_function = dpm_fast_function
|
sampler_function = dpm_fast_function
|
||||||
elif sampler_name == "dpm_adaptive":
|
elif sampler_name == "dpm_adaptive":
|
||||||
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
|
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return noise
|
||||||
|
|
||||||
sigma_min = sigmas[-1]
|
sigma_min = sigmas[-1]
|
||||||
if sigma_min == 0:
|
if sigma_min == 0:
|
||||||
sigma_min = sigmas[-2]
|
sigma_min = sigmas[-2]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user