mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-13 19:07:01 +08:00
Add SA-Solver sampler (#8834)
This commit is contained in:
parent
974254218a
commit
aac10ad23a
121
comfy/k_diffusion/sa_solver.py
Normal file
121
comfy/k_diffusion/sa_solver.py
Normal file
@ -0,0 +1,121 @@
|
||||
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
|
||||
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
|
||||
# Codebase ref: https://github.com/scxue/SA-Solver
|
||||
|
||||
import math
|
||||
from typing import Union, Callable
|
||||
import torch
|
||||
|
||||
|
||||
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
|
||||
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
|
||||
|
||||
Integral of exp((1 + tau^2) * x) * x^p dx
|
||||
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
|
||||
with base case p=0 where integral equals product_terms[0].
|
||||
|
||||
where
|
||||
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
|
||||
|
||||
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
|
||||
Return coefficients used by the SA-Solver in data prediction mode.
|
||||
|
||||
Args:
|
||||
s: Start time s.
|
||||
t: End time t.
|
||||
solver_order: Current order of the solver.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
|
||||
Returns:
|
||||
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
|
||||
"""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = t - s
|
||||
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
|
||||
|
||||
# product_terms after factoring out exp((1 + tau^2) * t)
|
||||
# Includes (1 + tau^2) factor from outside the integral
|
||||
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
|
||||
|
||||
# Lower triangular recursive coefficient matrix
|
||||
# Accumulates recursive coefficients based on p / (1 + tau^2)
|
||||
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
|
||||
log_factorial = (p + 1).lgamma()
|
||||
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
|
||||
if tau_t > 0:
|
||||
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
|
||||
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
|
||||
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
|
||||
|
||||
return recursive_coeff_mat @ product_terms_factored
|
||||
|
||||
|
||||
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = lambda_t - lambda_s
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
if is_corrector_step:
|
||||
# Simplified 1-step (order-2) corrector
|
||||
b_1 = alpha_t * (0.5 * tau_mul * h)
|
||||
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
|
||||
else:
|
||||
# Simplified 2-step predictor
|
||||
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
|
||||
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
|
||||
return torch.stack([b_2, b_1])
|
||||
|
||||
|
||||
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
|
||||
|
||||
The solver order corresponds to the number of input lambdas (half-logSNR points).
|
||||
|
||||
Args:
|
||||
sigma_next: Sigma at end time t.
|
||||
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
|
||||
lambda_s: Lambda at start time s.
|
||||
lambda_t: Lambda at end time t.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
simple_order_2: Whether to enable the simple order-2 scheme.
|
||||
is_corrector_step: Flag for corrector step in simple order-2 mode.
|
||||
|
||||
Returns:
|
||||
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
|
||||
"""
|
||||
num_timesteps = curr_lambdas.shape[0]
|
||||
|
||||
if simple_order_2 and num_timesteps == 2:
|
||||
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
|
||||
|
||||
# Compute coefficients by solving a linear system from Lagrange basis interpolation
|
||||
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
|
||||
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
|
||||
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
|
||||
|
||||
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
|
||||
# = sigma_t * exp(lambda_t) = alpha_t
|
||||
# exp((1 + tau^2) * lambda_t) is extracted from the integral
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
return alpha_t * lagrange_integrals
|
||||
|
||||
|
||||
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
|
||||
"""Return a function that controls the stochasticity of SA-Solver.
|
||||
|
||||
When eta = 0, SA-Solver runs as ODE. The official approach uses
|
||||
time t to determine the SDE interval, while here we use sigma instead.
|
||||
|
||||
See:
|
||||
https://github.com/scxue/SA-Solver/blob/main/README.md
|
||||
"""
|
||||
|
||||
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
|
||||
if eta <= 0:
|
||||
return 0.0 # ODE
|
||||
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.item()
|
||||
return eta if start_sigma >= sigma >= end_sigma else 0.0
|
||||
|
||||
return tau_func
|
@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
@ -1648,3 +1649,113 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
|
||||
|
||||
if tau_func is None:
|
||||
# Use default interval for stochastic sampling
|
||||
start_sigma = model_sampling.percent_to_sigma(0.2)
|
||||
end_sigma = model_sampling.percent_to_sigma(0.8)
|
||||
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
|
||||
|
||||
max_used_order = max(predictor_order, corrector_order)
|
||||
x_pred = x # x: current state, x_pred: predicted next state
|
||||
|
||||
h = 0.0
|
||||
tau_t = 0.0
|
||||
noise = 0.0
|
||||
pred_list = []
|
||||
|
||||
# Lower order near the end to improve stability
|
||||
lower_order_to_end = sigmas[-1].item() == 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
# Evaluation
|
||||
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
pred_list.append(denoised)
|
||||
pred_list = pred_list[-max_used_order:]
|
||||
|
||||
predictor_order_used = min(predictor_order, len(pred_list))
|
||||
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
|
||||
corrector_order_used = 0
|
||||
else:
|
||||
corrector_order_used = min(corrector_order, len(pred_list))
|
||||
|
||||
if lower_order_to_end:
|
||||
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
|
||||
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
|
||||
|
||||
# Corrector
|
||||
if corrector_order_used == 0:
|
||||
# Update by the predicted state
|
||||
x = x_pred
|
||||
else:
|
||||
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i],
|
||||
curr_lambdas,
|
||||
lambdas[i - 1],
|
||||
lambdas[i],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=True,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
|
||||
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
# The noise from the previous predictor step
|
||||
x = x + noise
|
||||
|
||||
if use_pece:
|
||||
# Evaluate the corrected state
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
pred_list[-1] = denoised
|
||||
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i + 1],
|
||||
curr_lambdas,
|
||||
lambdas[i],
|
||||
lambdas[i + 1],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=False,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
|
||||
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
h = lambdas[i + 1] - lambdas[i]
|
||||
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
@ -720,7 +720,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
@ -2,6 +2,7 @@ import math
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.k_diffusion import sa_solver
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import latent_preview
|
||||
import torch
|
||||
@ -521,6 +522,49 @@ class SamplerER_SDE(ComfyNodeABC):
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class SamplerSASolver(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
"eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},),
|
||||
"sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},),
|
||||
"sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},),
|
||||
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},),
|
||||
"predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}),
|
||||
"corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}),
|
||||
"use_pece": (IO.BOOLEAN, {}),
|
||||
"simple_order_2": (IO.BOOLEAN, {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.SAMPLER,)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
|
||||
end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
|
||||
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta)
|
||||
|
||||
sampler_name = "sa_solver"
|
||||
sampler = comfy.samplers.ksampler(
|
||||
sampler_name,
|
||||
{
|
||||
"tau_func": tau_func,
|
||||
"s_noise": s_noise,
|
||||
"predictor_order": predictor_order,
|
||||
"corrector_order": corrector_order,
|
||||
"use_pece": use_pece,
|
||||
"simple_order_2": simple_order_2,
|
||||
},
|
||||
)
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@ -829,6 +873,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SamplerER_SDE": SamplerER_SDE,
|
||||
"SamplerSASolver": SamplerSASolver,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
|
Loading…
x
Reference in New Issue
Block a user