diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py index a475bc90..cb4bece5 100644 --- a/comfy/k_diffusion/sa_solver.py +++ b/comfy/k_diffusion/sa_solver.py @@ -175,12 +175,6 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part return x_t -def device_noise_sampler(x, noise_device='gpu'): - if noise_device == "gpu": - return torch.randn_like(x) - else: - return torch.randn(x.shape, device='cpu').to(x.device) - # Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma): if eta == 0: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 9fbc62bd..4adc828f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1154,7 +1154,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F end_sigma = model_sampling.percent_to_sigma(0.8) tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma) tau = tau_func - noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_prev_list = [] @@ -1173,7 +1174,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1) tau_val = tau(sigma) - noise = None if tau_val == 0 else noise_sampler() + noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1]) # Predictor step x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, @@ -1226,26 +1227,6 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, ) -@torch.no_grad() -def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): - if len(sigmas) <= 1: - return x - noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler - return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, - predictor_order=predictor_order, corrector_order=corrector_order, - pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler, - ) - -@torch.no_grad() -def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): - if len(sigmas) <= 1: - return x - noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler - return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, - predictor_order=predictor_order, corrector_order=corrector_order, - pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, - ) - @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): extra_args = {} if extra_args is None else extra_args diff --git a/comfy/samplers.py b/comfy/samplers.py index cd011e29..bb54a5ca 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -687,7 +687,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"] + "ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_pece"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 9c8eb4b9..03461d6c 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -423,7 +423,6 @@ class SamplerSASolver: "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), "eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}), - "noise_device": (["gpu", "cpu"],), } } RETURN_TYPES = ("SAMPLER",) @@ -431,16 +430,16 @@ class SamplerSASolver: FUNCTION = "get_sampler" - def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent, noise_device): + def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent): model_sampling = model.get_model_object('model_sampling') start_sigma = model_sampling.percent_to_sigma(eta_start_percent) end_sigma = model_sampling.percent_to_sigma(eta_end_percent) tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma) if pc_mode == 'PEC': - sampler_name = "sa_solver" if noise_device == "cpu" else "sa_solver_gpu" + sampler_name = "sa_solver" else: - sampler_name = "sa_solver_pece" if noise_device == "cpu" else "sa_solver_pece_gpu" + sampler_name = "sa_solver_pece" sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func}) return (sampler, )