mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Use default_noise_sampler instead
This commit is contained in:
parent
812dc34f46
commit
6b68b61644
@ -175,12 +175,6 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
|
||||
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||
return x_t
|
||||
|
||||
def device_noise_sampler(x, noise_device='gpu'):
|
||||
if noise_device == "gpu":
|
||||
return torch.randn_like(x)
|
||||
else:
|
||||
return torch.randn(x.shape, device='cpu').to(x.device)
|
||||
|
||||
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
|
||||
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
|
||||
if eta == 0:
|
||||
|
@ -1154,7 +1154,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
end_sigma = model_sampling.percent_to_sigma(0.8)
|
||||
tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
|
||||
tau = tau_func
|
||||
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
sigma_prev_list = []
|
||||
@ -1173,7 +1174,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
|
||||
|
||||
tau_val = tau(sigma)
|
||||
noise = None if tau_val == 0 else noise_sampler()
|
||||
noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1])
|
||||
|
||||
# Predictor step
|
||||
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
|
||||
@ -1226,26 +1227,6 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa
|
||||
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
|
||||
predictor_order=predictor_order, corrector_order=corrector_order,
|
||||
pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
|
||||
predictor_order=predictor_order, corrector_order=corrector_order,
|
||||
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
@ -687,7 +687,7 @@ class Sampler:
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"]
|
||||
"ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_pece"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
@ -423,7 +423,6 @@ class SamplerSASolver:
|
||||
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||
"eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"noise_device": (["gpu", "cpu"],),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
@ -431,16 +430,16 @@ class SamplerSASolver:
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent, noise_device):
|
||||
def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent):
|
||||
model_sampling = model.get_model_object('model_sampling')
|
||||
start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
|
||||
end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
|
||||
tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
|
||||
|
||||
if pc_mode == 'PEC':
|
||||
sampler_name = "sa_solver" if noise_device == "cpu" else "sa_solver_gpu"
|
||||
sampler_name = "sa_solver"
|
||||
else:
|
||||
sampler_name = "sa_solver_pece" if noise_device == "cpu" else "sa_solver_pece_gpu"
|
||||
sampler_name = "sa_solver_pece"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
|
||||
return (sampler, )
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user