From aac10ad23a8f65243fcfca7afa7a72c1740312d6 Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Wed, 9 Jul 2025 04:17:06 +0800
Subject: [PATCH] Add SA-Solver sampler (#8834)

---
 comfy/k_diffusion/sa_solver.py       | 121 +++++++++++++++++++++++++++
 comfy/k_diffusion/sampling.py        | 111 ++++++++++++++++++++++++
 comfy/samplers.py                    |   2 +-
 comfy_extras/nodes_custom_sampler.py |  45 ++++++++++
 4 files changed, 278 insertions(+), 1 deletion(-)
 create mode 100644 comfy/k_diffusion/sa_solver.py

diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py
new file mode 100644
index 000000000..0c6821b60
--- /dev/null
+++ b/comfy/k_diffusion/sa_solver.py
@@ -0,0 +1,121 @@
+# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
+# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
+# Codebase ref: https://github.com/scxue/SA-Solver
+
+import math
+from typing import Union, Callable
+import torch
+
+
+def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
+    """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t,
+    with exp((1 + tau^2) * t) factored out, using integration by parts.
+
+    Integral of exp((1 + tau^2) * x) * x^p dx
+        = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p - 1) dx, with the base case p = 0 equal to product_terms[0],
+
+    where
+        product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
+
+    Construct a recursive coefficient matrix following this recursion to compute all integral terms up to p = solver_order - 1.
+    Return the coefficients used by SA-Solver in data prediction mode.
+
+    Args:
+        s: Start time s.
+        t: End time t.
+        solver_order: Current order of the solver.
+        tau_t: Stochastic strength parameter in the SDE.
+
+    Returns:
+        Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p = 0 to p = solver_order - 1, shape (solver_order,).
+    """
+    tau_mul = 1 + tau_t ** 2
+    h = t - s
+    p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
+
+    # product_terms after factoring out exp((1 + tau^2) * t)
+    # Includes the (1 + tau^2) factor from outside the integral
+    product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
+
+    # Lower triangular recursive coefficient matrix
+    # Accumulates the recursive coefficients based on p / (1 + tau^2)
+    recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
+    log_factorial = (p + 1).lgamma()
+    recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
+    if tau_t > 0:
+        recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
+    signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
+    recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
+
+    return recursive_coeff_mat @ product_terms_factored
+
+
+def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
+    """Compute the simple order-2 b coefficients from the SA-Solver paper (Appendix D: Implementation Details)."""
+    tau_mul = 1 + tau_t ** 2
+    h = lambda_t - lambda_s
+    alpha_t = sigma_next * lambda_t.exp()
+    if is_corrector_step:
+        # Simplified 1-step (order-2) corrector
+        b_1 = alpha_t * (0.5 * tau_mul * h)
+        b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
+    else:
+        # Simplified 2-step predictor
+        b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
+        b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
+    return torch.stack([b_2, b_1])
+
+
+def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
+    """Compute the b_i coefficients for SA-Solver (see eqs. 15 and 18 in the paper).
+
+    The solver order corresponds to the number of input lambdas (half-logSNR points).
+
+    Args:
+        sigma_next: Sigma at end time t.
+        curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
+        lambda_s: Lambda at start time s.
+        lambda_t: Lambda at end time t.
+        tau_t: Stochastic strength parameter in the SDE.
+        simple_order_2: Whether to enable the simple order-2 scheme.
+        is_corrector_step: Flag for the corrector step in simple order-2 mode.
+
+    Returns:
+        b_i coefficients for SA-Solver, shape (N,), where N is the solver order.
+    """
+    num_timesteps = curr_lambdas.shape[0]
+
+    if simple_order_2 and num_timesteps == 2:
+        return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
+
+    # Compute coefficients by solving a linear system from Lagrange basis interpolation
+    exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
+    vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
+    lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
+
+    # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
+    #     = sigma_t * exp(lambda_t) = alpha_t
+    # exp((1 + tau^2) * lambda_t) is the factor extracted from the integral
+    alpha_t = sigma_next * lambda_t.exp()
+    return alpha_t * lagrange_integrals
+
+
+def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
+    """Return a function that controls the stochasticity of SA-Solver.
+
+    When eta = 0, SA-Solver runs as an ODE solver. The official implementation
+    uses time t to determine the SDE interval; here we use sigma instead.
+
+    See:
+        https://github.com/scxue/SA-Solver/blob/main/README.md
+    """
+
+    def tau_func(sigma: Union[torch.Tensor, float]) -> float:
+        if eta <= 0:
+            return 0.0  # ODE
+
+        if isinstance(sigma, torch.Tensor):
+            sigma = sigma.item()
+        return eta if start_sigma >= sigma >= end_sigma else 0.0
+
+    return tau_func
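The recursion implemented by compute_exponential_coeffs is easy to sanity-check against brute-force quadrature. A minimal sketch, assuming a ComfyUI checkout (and its dependencies) is on the import path; the test values below are arbitrary choices for illustration, not taken from the patch:

import torch
from comfy.k_diffusion import sa_solver

s = torch.tensor(-1.2, dtype=torch.float64)
t = torch.tensor(-0.4, dtype=torch.float64)
tau_t, order = 0.7, 3
tau_mul = 1 + tau_t ** 2

coeffs = sa_solver.compute_exponential_coeffs(s, t, order, tau_t)

# Brute-force reference: (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx
# from s to t, with the factored-out exp((1 + tau^2) * t) folded into the integrand
x = torch.linspace(s.item(), t.item(), 100001, dtype=torch.float64)
for p in range(order):
    ref = torch.trapezoid(tau_mul * (tau_mul * (x - t)).exp() * x ** p, x)
    print(p, coeffs[p].item(), ref.item())  # the two columns should agree closely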
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 34218337a..2ed415b1f 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
 from . import utils
 from . import deis
+from . import sa_solver
 import comfy.model_patcher
 import comfy.model_sampling
@@ -1648,3 +1649,113 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
         if inject_noise:
             x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
     return x
+
+
+@torch.no_grad()
+def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
+    """Stochastic Adams Solver with the predictor-corrector method (NeurIPS 2023)."""
+    if len(sigmas) <= 1:
+        return x
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
+
+    if tau_func is None:
+        # Use the default interval for stochastic sampling
+        start_sigma = model_sampling.percent_to_sigma(0.2)
+        end_sigma = model_sampling.percent_to_sigma(0.8)
+        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
+
+    max_used_order = max(predictor_order, corrector_order)
+    x_pred = x  # x: current state, x_pred: state after the predictor step
+
+    h = 0.0
+    tau_t = 0.0
+    noise = 0.0
+    pred_list = []
+
+    # Lower the order near the end to improve stability
+    lower_order_to_end = sigmas[-1].item() == 0
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        # Evaluation
+        denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+        pred_list.append(denoised)
+        pred_list = pred_list[-max_used_order:]
+
+        predictor_order_used = min(predictor_order, len(pred_list))
+        if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
+            corrector_order_used = 0
+        else:
+            corrector_order_used = min(corrector_order, len(pred_list))
+
+        if lower_order_to_end:
+            predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
+            corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
+
+        # Corrector
+        if corrector_order_used == 0:
+            # Update by the predicted state
+            x = x_pred
+        else:
+            curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
+            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
+                sigmas[i],
+                curr_lambdas,
+                lambdas[i - 1],
+                lambdas[i],
+                tau_t,
+                simple_order_2,
+                is_corrector_step=True,
+            )
+            pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1)  # (B, K, ...)
+            corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
+            x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
+
+            if tau_t > 0 and s_noise > 0:
+                # Apply the noise from the previous predictor step
+                x = x + noise
+
+            if use_pece:
+                # Evaluate the corrected state
+                denoised = model(x, sigmas[i] * s_in, **extra_args)
+                pred_list[-1] = denoised
+
+        # Predictor
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            tau_t = tau_func(sigmas[i + 1])
+            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
+            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
+                sigmas[i + 1],
+                curr_lambdas,
+                lambdas[i],
+                lambdas[i + 1],
+                tau_t,
+                simple_order_2,
+                is_corrector_step=False,
+            )
+            pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1)  # (B, K, ...)
+            pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
+            h = lambdas[i + 1] - lambdas[i]
+            x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
+
+            if tau_t > 0 and s_noise > 0:
+                noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
+                x_pred = x_pred + noise
+    return x
+
+
+@torch.no_grad()
+def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
+    """Stochastic Adams Solver in PECE (Predict-Evaluate-Correct-Evaluate) mode (NeurIPS 2023)."""
+    return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
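For intuition about what one predictor step computes: with a single lambda point and tau_t = 0, the update x_pred = sigma_next / sigma * x + b @ preds collapses to the deterministic DDIM / DPM-Solver-1 step in data-prediction form. A minimal sketch under an alpha = 1 (EDM-style) convention where lambda = -log(sigma); everything here besides sa_solver itself is illustrative:

import torch
from comfy.k_diffusion import sa_solver

torch.manual_seed(0)
sigma, sigma_next = torch.tensor(2.0), torch.tensor(1.0)
lam, lam_next = -sigma.log(), -sigma_next.log()  # half-logSNR with alpha = 1

x = torch.randn(4)
denoised = torch.randn(4)

b = sa_solver.compute_stochastic_adams_b_coeffs(sigma_next, lam.unsqueeze(0), lam, lam_next, tau_t=0.0)
x_sa = sigma_next / sigma * x + b[0] * denoised

# Reference: DDIM step in data-prediction form
x_ddim = sigma_next / sigma * x + (1 - sigma_next / sigma) * denoised
print(torch.allclose(x_sa, x_ddim))  # expected: True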
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 25ccaf39f..e93d2a315 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -720,7 +720,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
+                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
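Adding the two names to KSAMPLER_NAMES is the only wiring comfy.samplers needs, assuming the usual getattr-based fallback in ksampler() that resolves a name to sample_<name> in k_diffusion's sampling module. A minimal sketch of exercising that lookup (the sampler_function attribute is assumed to be the one stored by KSAMPLER.__init__ above):

import comfy.samplers

sampler = comfy.samplers.ksampler("sa_solver_pece")
print(sampler.sampler_function.__name__)  # -> sample_sa_solver_pece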
{"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},), + "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},), + "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}), + "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}), + "use_pece": (IO.BOOLEAN, {}), + "simple_order_2": (IO.BOOLEAN, {}), + } + } + + RETURN_TYPES = (IO.SAMPLER,) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + model_sampling = model.get_model_object("model_sampling") + start_sigma = model_sampling.percent_to_sigma(sde_start_percent) + end_sigma = model_sampling.percent_to_sigma(sde_end_percent) + tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta) + + sampler_name = "sa_solver" + sampler = comfy.samplers.ksampler( + sampler_name, + { + "tau_func": tau_func, + "s_noise": s_noise, + "predictor_order": predictor_order, + "corrector_order": corrector_order, + "use_pece": use_pece, + "simple_order_2": simple_order_2, + }, + ) + return (sampler,) + + class Noise_EmptyNoise: def __init__(self): self.seed = 0 @@ -829,6 +873,7 @@ NODE_CLASS_MAPPINGS = { "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, "SamplerDPMAdaptative": SamplerDPMAdaptative, "SamplerER_SDE": SamplerER_SDE, + "SamplerSASolver": SamplerSASolver, "SplitSigmas": SplitSigmas, "SplitSigmasDenoise": SplitSigmasDenoise, "FlipSigmas": FlipSigmas,