From 92a630e7370a9563de6537712a5c13b435a9aca2 Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Sat, 17 Aug 2024 23:29:07 +0800
Subject: [PATCH 1/8] Add SA-Solver

---
 comfy/k_diffusion/sa_solver.py       | 189 +++++++++++++++++++++++++++
 comfy/k_diffusion/sampling.py        | 112 ++++++++++++++++
 comfy/samplers.py                    |   2 +-
 comfy_extras/nodes_custom_sampler.py |  33 +++++
 4 files changed, 335 insertions(+), 1 deletion(-)
 create mode 100644 comfy/k_diffusion/sa_solver.py

diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py
new file mode 100644
index 00000000..4f226c38
--- /dev/null
+++ b/comfy/k_diffusion/sa_solver.py
@@ -0,0 +1,189 @@
+# Modified from: https://github.com/scxue/SA-Solver
+# MIT license
+
+import torch
+
+def get_coefficients_exponential_positive(order, interval_start, interval_end, tau):
+    """
+    Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end.
+    Used to compute the coefficients of the gradient terms after Lagrange interpolation;
+    see Eq.(15) and Eq.(18) in the SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
+    (data_prediction formula).
+    """
+    assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
+
+    # after change of variable (cov)
+    interval_end_cov = (1 + tau ** 2) * interval_end
+    interval_start_cov = (1 + tau ** 2) * interval_start
+
+    if order == 0:
+        return (torch.exp(interval_end_cov)
+                * (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
+                / ((1 + tau ** 2))
+                )
+    elif order == 1:
+        return (torch.exp(interval_end_cov)
+                * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)))
+                / ((1 + tau ** 2) ** 2)
+                )
+    elif order == 2:
+        return (torch.exp(interval_end_cov)
+                * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2)
+                   - (interval_start_cov ** 2 - 2 * interval_start_cov + 2)
+                   * torch.exp(-(interval_end_cov - interval_start_cov))
+                   )
+                / ((1 + tau ** 2) ** 3)
+                )
+    elif order == 3:
+        return (torch.exp(interval_end_cov)
+                * ((interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6)
+                   - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6)
+                   * torch.exp(-(interval_end_cov - interval_start_cov))
+                   )
+                / ((1 + tau ** 2) ** 4)
+                )
+
+def lagrange_polynomial_coefficient(order, lambda_list):
+    """
+    Calculate the coefficients of the Lagrange polynomials
+    used for Lagrange interpolation.
+    """
+    assert order in [0, 1, 2, 3]
+    assert order == len(lambda_list) - 1
+    if order == 0:
+        return [[1.0]]
+    elif order == 1:
+        return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
+                [1.0 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
+    elif order == 2:
+        denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
+        denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
+        denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
+        return [[1.0 / denominator1, (-lambda_list[1] - lambda_list[2]) / denominator1, lambda_list[1] * lambda_list[2] / denominator1],
+                [1.0 / denominator2, (-lambda_list[0] - lambda_list[2]) / denominator2, lambda_list[0] * lambda_list[2] / denominator2],
+                [1.0 / denominator3, (-lambda_list[0] - lambda_list[1]) / denominator3, lambda_list[0] * lambda_list[1] / denominator3]
+                ]
+    elif order == 3:
+        denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
* (lambda_list[0] - lambda_list[3])
+        denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (lambda_list[1] - lambda_list[3])
+        denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (lambda_list[2] - lambda_list[3])
+        denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (lambda_list[3] - lambda_list[2])
+        return [[1.0 / denominator1,
+                 (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
+                 (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator1,
+                 (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],
+
+                [1.0 / denominator2,
+                 (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
+                 (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator2,
+                 (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],
+
+                [1.0 / denominator3,
+                 (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
+                 (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3,
+                 (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],
+
+                [1.0 / denominator4,
+                 (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
+                 (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4,
+                 (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
+                ]
+
+def get_coefficients_fn(order, interval_start, interval_end, lambda_list, tau):
+    """
+    Calculate the coefficients of the gradient terms.
+    """
+    assert order in [1, 2, 3, 4]
+    assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
+    lagrange_coefficient = lagrange_polynomial_coefficient(order - 1, lambda_list)
+    coefficients = [sum(lagrange_coefficient[i][j] * get_coefficients_exponential_positive(order - 1 - j, interval_start, interval_end, tau)
+                        for j in range(order))
+                    for i in range(order)]
+    assert len(coefficients) == order, 'the length of coefficients does not match the order'
+    return coefficients
+
+def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
+    """
+    SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
+    """
+
+    assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
+    t_fn = lambda sigma: sigma.log().neg()
+    sigma_prev = sigma_prev_list[-1]
+    gradient_part = torch.zeros_like(x)
+    lambda_list = [t_fn(sigma_prev_list[-(i + 1)]) for i in range(order)]
+    lambda_t = t_fn(sigma)
+    lambda_prev = lambda_list[0]
+    h = lambda_t - lambda_prev
+    gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)
+
+    if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
+        # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
+        # ODE case
+        # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
+        # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
+        gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
+                                     * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
+                                     / (lambda_prev - lambda_list[1])
+                                     )
+        gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
+                                     * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
+                                     / (lambda_prev - lambda_list[1])
+                                     )
+
+    for i in range(order):
+        gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
+    gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
+    noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
+    x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
+    return x_t
+
+def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
+    """
+    SA-Corrector, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
+    """
+
+    assert order in [1, 2, 3, 4], "order of stochastic adams moulton method is only supported for 1, 2, 3 and 4"
+    t_fn = lambda sigma: sigma.log().neg()
+    sigma_prev = sigma_prev_list[-1]
+    gradient_part = torch.zeros_like(x)
+    sigma_list = sigma_prev_list + [sigma]
+    lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
+    lambda_t = lambda_list[0]
+    lambda_prev = lambda_list[1]
+    h = lambda_t - lambda_prev
+    gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)
+
+    if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
+        # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
+        # ODE case
+        # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
+        # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
+        gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
+                                     * (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
+                                        / ((1 + tau ** 2) ** 2 * h))
+                                     )
+        gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
+                                     * (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
+                                        / ((1 + tau ** 2) ** 2 * h))
+                                     )
+
+    for i in range(order):
+        gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
+    gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
+    noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
+    x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
+    return x_t
+
+def device_noise_sampler(x, noise_device='gpu'):
+    if noise_device == "gpu":
+        return torch.randn_like(x)
+    else:
+        return torch.randn(x.shape, device='cpu').to(x.device)
+
+# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
+def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
+    if eta == 0:
+        # Pure ODE
+        return 0
+    return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0
\ No newline at end of file
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index e9e4edcc..1bc3a27d 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1,4 +1,5 @@
 import math
+from functools import partial
 
 from scipy import integrate
 import torch
@@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
 
 from . import utils
 from . import deis
+from . import sa_solver
 import comfy.model_patcher
 import comfy.model_sampling
 
@@ -1061,6 +1063,116 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
     return x_next
 
 
+
+# Modified from: https://github.com/scxue/SA-Solver
+# MIT license
+@torch.no_grad()
+def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, pc_mode="PEC", tau_func=None, noise_sampler=None):
+    if len(sigmas) <= 1:
+        return x
+
+    if sigmas[-1] == 0:
+        sigmas = sigmas.clone()
+        sigmas[-1] = 0.001
+
+    extra_args = {} if extra_args is None else extra_args
+    if tau_func is None:
+        model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+        start_sigma = model_sampling.percent_to_sigma(0.2)
+        end_sigma = model_sampling.percent_to_sigma(0.8)
+        tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
+    tau = tau_func
+    noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    sigma_prev_list = []
+    model_prev_list = []
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        sigma = sigmas[i]
+        if i == 0:
+            # Init the initial values.
+ denoised = model(x, sigma * s_in, **extra_args) + model_prev_list.append(denoised) + sigma_prev_list.append(sigma) + else: + # Lower order final + predictor_order_used = min(predictor_order, i, len(sigmas) - i) + corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1) + + tau_val = tau(sigma) + noise = None if tau_val == 0 else noise_sampler() + + # Predictor step + x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=noise, sigma=sigma) + + # Evaluation step + denoised = model(x_p, sigma * s_in, **extra_args) + + # Update model_list + model_prev_list.append(denoised) + + # Corrector step + if corrector_order_used > 0: + x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=noise, sigma=sigma) + + else: + x = x_p + + del noise, x_p + + # Evaluation step if mode = pece and step != steps + if corrector_order_used > 0 and pc_mode == 'PECE': + del model_prev_list[-1] + denoised = model(x, sigma * s_in, **extra_args) + model_prev_list.append(denoised) + + sigma_prev_list.append(sigma) + if len(model_prev_list) > max(predictor_order, corrector_order): + del model_prev_list[0] + del sigma_prev_list[0] + + if callback is not None: + callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) + + # Extra final step + x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=0, sigma=sigmas[-1]) + return x + +@torch.no_grad() +def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): + if len(sigmas) <= 1: + return x + return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, + predictor_order=predictor_order, corrector_order=corrector_order, + pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, + ) + +@torch.no_grad() +def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): + if len(sigmas) <= 1: + return x + noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler + return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, + predictor_order=predictor_order, corrector_order=corrector_order, + pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler, + ) + +@torch.no_grad() +def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): + if len(sigmas) <= 1: + return x + noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler + return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, + predictor_order=predictor_order, corrector_order=corrector_order, + pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, + ) + @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): extra_args = {} if extra_args is None else extra_args diff --git a/comfy/samplers.py b/comfy/samplers.py index a20f65d4..0787c17c 100644 --- 
a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -572,7 +572,7 @@ class Sampler:
 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
-                  "ipndm", "ipndm_v", "deis"]
+                  "ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index c7ff9a4d..9c8eb4b9 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -1,6 +1,8 @@
+from functools import partial
 import comfy.samplers
 import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
+from comfy.k_diffusion import sa_solver
 import latent_preview
 import torch
 import comfy.utils
@@ -411,6 +413,36 @@ class SamplerDPMAdaptative:
                       "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
                       "s_noise":s_noise })
         return (sampler, )
+
+class SamplerSASolver:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"model": ("MODEL",),
+                     "pc_mode": (["PEC", "PECE"],),
+                     "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
+                     "eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
+                     "noise_device": (["gpu", "cpu"],),
+                     }
+                }
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent, noise_device):
+        model_sampling = model.get_model_object('model_sampling')
+        start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
+        end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
+        tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
+
+        if pc_mode == 'PEC':
+            sampler_name = "sa_solver" if noise_device == "cpu" else "sa_solver_gpu"
+        else:
+            sampler_name = "sa_solver_pece" if noise_device == "cpu" else "sa_solver_pece_gpu"
+        sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
+        return (sampler, )
 
 class Noise_EmptyNoise:
     def __init__(self):
@@ -707,6 +739,7 @@ NODE_CLASS_MAPPINGS = {
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
     "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
     "SamplerDPMAdaptative": SamplerDPMAdaptative,
+    "SamplerSASolver": SamplerSASolver,
     "SplitSigmas": SplitSigmas,
     "SplitSigmasDenoise": SplitSigmasDenoise,
     "FlipSigmas": FlipSigmas,

From 178289650237d5e77b9919b9547c18d799b4cbb5 Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Tue, 17 Sep 2024 23:21:52 +0800
Subject: [PATCH 2/8] Fix lambda_prev in corrector

---
 comfy/k_diffusion/sa_solver.py |  2 +-
 comfy/k_diffusion/sampling.py  | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py
index 4f226c38..a475bc90 100644
--- a/comfy/k_diffusion/sa_solver.py
+++ b/comfy/k_diffusion/sa_solver.py
@@ -150,7 +150,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
sigma_list = sigma_prev_list + [sigma] lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)] lambda_t = lambda_list[0] - lambda_prev = lambda_list[1] + lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev) h = lambda_t - lambda_prev gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1bc3a27d..da2c0e0b 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1104,8 +1104,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Predictor step x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=noise, sigma=sigma) # Evaluation step denoised = model(x_p, sigma * s_in, **extra_args) @@ -1116,8 +1116,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Corrector step if corrector_order_used > 0: x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=noise, sigma=sigma) else: x = x_p @@ -1140,8 +1140,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Extra final step x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=0, sigma=sigmas[-1]) + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=0, sigma=sigmas[-1]) return x @torch.no_grad() From 587c93ebff781632217895db1ef5f5e59e3cdc5e Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 22 Sep 2024 18:27:13 +0800 Subject: [PATCH 3/8] Start lower_final_step in predictor earlier to stabilize --- comfy/k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index da2c0e0b..f4e39bfb 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1096,7 +1096,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F sigma_prev_list.append(sigma) else: # Lower order final - predictor_order_used = min(predictor_order, i, len(sigmas) - i) + predictor_order_used = min(predictor_order, i, len(sigmas) - i - 1) corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1) tau_val = tau(sigma) From c176ad8f50d051159f367bdba3b107bce0a3efe1 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Wed, 30 Oct 2024 01:16:34 +0800 Subject: [PATCH 4/8] Clean up and remove modifying zero sigma --- comfy/k_diffusion/sampling.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 605b1092..06dcc0b1 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1095,10 +1095,6 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if len(sigmas) <= 1: return x - if sigmas[-1] == 0: - sigmas = sigmas.clone() - sigmas[-1] = 0.001 - extra_args = {} if extra_args is None else extra_args if 
tau_func is None: model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1115,7 +1111,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F for i in trange(len(sigmas) - 1, disable=disable): sigma = sigmas[i] if i == 0: - # Init the initial values. + # Init the initial values denoised = model(x, sigma * s_in, **extra_args) model_prev_list.append(denoised) sigma_prev_list.append(sigma) @@ -1134,22 +1130,19 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Evaluation step denoised = model(x_p, sigma * s_in, **extra_args) - - # Update model_list model_prev_list.append(denoised) # Corrector step if corrector_order_used > 0: x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) - + noise=noise, sigma=sigma) else: x = x_p del noise, x_p - # Evaluation step if mode = pece and step != steps + # Evaluation step for PECE if corrector_order_used > 0 and pc_mode == 'PECE': del model_prev_list[-1] denoised = model(x, sigma * s_in, **extra_args) @@ -1163,10 +1156,13 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if callback is not None: callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) - # Extra final step - x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=0, sigma=sigmas[-1]) + if sigmas[-1] == 0: + # Denoising step + x = model_prev_list[-1] + else: + x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=0, sigma=sigmas[-1]) return x @torch.no_grad() From 6b68b6164458f0011c14c404451c3a685c84814e Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Mon, 6 Jan 2025 18:55:29 +0800 Subject: [PATCH 5/8] Use default_noise_sampler instead --- comfy/k_diffusion/sa_solver.py | 6 ------ comfy/k_diffusion/sampling.py | 25 +++---------------------- comfy/samplers.py | 2 +- comfy_extras/nodes_custom_sampler.py | 7 +++---- 4 files changed, 7 insertions(+), 33 deletions(-) diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py index a475bc90..cb4bece5 100644 --- a/comfy/k_diffusion/sa_solver.py +++ b/comfy/k_diffusion/sa_solver.py @@ -175,12 +175,6 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part return x_t -def device_noise_sampler(x, noise_device='gpu'): - if noise_device == "gpu": - return torch.randn_like(x) - else: - return torch.randn(x.shape, device='cpu').to(x.device) - # Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma): if eta == 0: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 9fbc62bd..4adc828f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1154,7 +1154,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F end_sigma = model_sampling.percent_to_sigma(0.8) tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma) tau = tau_func - noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if 
noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_prev_list = [] @@ -1173,7 +1174,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1) tau_val = tau(sigma) - noise = None if tau_val == 0 else noise_sampler() + noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1]) # Predictor step x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, @@ -1226,26 +1227,6 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, ) -@torch.no_grad() -def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): - if len(sigmas) <= 1: - return x - noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler - return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, - predictor_order=predictor_order, corrector_order=corrector_order, - pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler, - ) - -@torch.no_grad() -def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): - if len(sigmas) <= 1: - return x - noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler - return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, - predictor_order=predictor_order, corrector_order=corrector_order, - pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, - ) - @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): extra_args = {} if extra_args is None else extra_args diff --git a/comfy/samplers.py b/comfy/samplers.py index cd011e29..bb54a5ca 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -687,7 +687,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"] + "ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_pece"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 9c8eb4b9..03461d6c 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -423,7 +423,6 @@ class SamplerSASolver: "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), "eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}), - "noise_device": (["gpu", "cpu"],), } } RETURN_TYPES = ("SAMPLER",) @@ 
-431,16 +430,16 @@ class SamplerSASolver:
 
     FUNCTION = "get_sampler"
 
-    def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent, noise_device):
+    def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent):
         model_sampling = model.get_model_object('model_sampling')
         start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
         end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
         tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
 
         if pc_mode == 'PEC':
-            sampler_name = "sa_solver" if noise_device == "cpu" else "sa_solver_gpu"
+            sampler_name = "sa_solver"
         else:
-            sampler_name = "sa_solver_pece" if noise_device == "cpu" else "sa_solver_pece_gpu"
+            sampler_name = "sa_solver_pece"
         sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
         return (sampler, )
 

From 8a8327fa73b2856acf59ffc97a6837b6cb6dd5ef Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Mon, 6 Jan 2025 18:58:27 +0800
Subject: [PATCH 6/8] Change to direct return

---
 comfy/k_diffusion/sa_solver.py | 6 ++----
 comfy/k_diffusion/sampling.py  | 6 ++----
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py
index cb4bece5..9cc81d1f 100644
--- a/comfy/k_diffusion/sa_solver.py
+++ b/comfy/k_diffusion/sa_solver.py
@@ -135,8 +135,7 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_
         gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
     gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
     noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
-    x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
-    return x_t
+    return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
 
 def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
     """
@@ -172,8 +171,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_li
         gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
     gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
     noise_part = 0 if tau == 0 else sigma * torch.sqrt(1.
- torch.exp(-2 * tau ** 2 * h)) * noise - x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part - return x_t + return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part # Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma): diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 4adc828f..85617b68 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1211,12 +1211,10 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if sigmas[-1] == 0: # Denoising step - x = model_prev_list[-1] - else: - x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, + return model_prev_list[-1] + return sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, noise=0, sigma=sigmas[-1]) - return x @torch.no_grad() def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): From eb40f9377b29dd2532dbbef33601f6392fd05098 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Mon, 6 Jan 2025 19:31:19 +0800 Subject: [PATCH 7/8] Remove space --- comfy/k_diffusion/sa_solver.py | 34 ++++++++++++++-------------- comfy/k_diffusion/sampling.py | 10 ++++---- comfy_extras/nodes_custom_sampler.py | 2 +- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py index 9cc81d1f..b784ecbf 100644 --- a/comfy/k_diffusion/sa_solver.py +++ b/comfy/k_diffusion/sa_solver.py @@ -17,27 +17,27 @@ def get_coefficients_exponential_positive(order, interval_start, interval_end, t interval_start_cov = (1 + tau ** 2) * interval_start if order == 0: - return (torch.exp(interval_end_cov) + return (torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2)) ) elif order == 1: - return (torch.exp(interval_end_cov) + return (torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2) ) elif order == 2: - return (torch.exp(interval_end_cov) - * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - - (interval_start_cov ** 2 - 2 * interval_start_cov + 2) + return (torch.exp(interval_end_cov) + * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) + - (interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp(-(interval_end_cov - interval_start_cov)) ) / ((1 + tau ** 2) ** 3) ) elif order == 3: - return (torch.exp(interval_end_cov) + return (torch.exp(interval_end_cov) * ((interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) + - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp(-(interval_end_cov - interval_start_cov)) ) / ((1 + tau ** 2) ** 4) @@ -53,7 +53,7 @@ def lagrange_polynomial_coefficient(order, lambda_list): if order == 0: return [[1.0]] elif order == 1: - return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], + return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], [1.0 / (lambda_list[1] - 
lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]] elif order == 2: denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) @@ -79,12 +79,12 @@ def lagrange_polynomial_coefficient(order, lambda_list): (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2], [1.0 / denominator3, - (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, - (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3, (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3], [1.0 / denominator4, - (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4, (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4] ] @@ -122,11 +122,11 @@ def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_ # ODE case # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) - gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t) - * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2)) + gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t) + * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2)) / (lambda_prev - lambda_list[1]) ) - gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t) + gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2)) / (lambda_prev - lambda_list[1]) ) @@ -152,7 +152,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev) h = lambda_t - lambda_prev gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau) - + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. # The added term is O(h^3). Empirically we find it will slightly improve the image quality. 
# ODE case @@ -166,7 +166,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li * (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2 * h)) ) - + for i in range(order): gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)] gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t) @@ -178,4 +178,4 @@ def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma): if eta == 0: # Pure ODE return 0 - return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0 \ No newline at end of file + return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0 diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 85617b68..4c8d18f2 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1146,7 +1146,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, pc_mode="PEC", tau_func=None, noise_sampler=None): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args if tau_func is None: model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1172,7 +1172,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Lower order final predictor_order_used = min(predictor_order, i, len(sigmas) - i - 1) corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1) - + tau_val = tau(sigma) noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1]) @@ -1183,13 +1183,13 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Evaluation step denoised = model(x_p, sigma * s_in, **extra_args) - model_prev_list.append(denoised) + model_prev_list.append(denoised) # Corrector step if corrector_order_used > 0: x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) + noise=noise, sigma=sigma) else: x = x_p @@ -1205,7 +1205,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if len(model_prev_list) > max(predictor_order, corrector_order): del model_prev_list[0] del sigma_prev_list[0] - + if callback is not None: callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 03461d6c..3cffdf8a 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -435,7 +435,7 @@ class SamplerSASolver: start_sigma = model_sampling.percent_to_sigma(eta_start_percent) end_sigma = model_sampling.percent_to_sigma(eta_end_percent) tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma) - + if pc_mode == 'PEC': sampler_name = "sa_solver" else: From 8d9dc98fbab0b7bee7100063db6fd9515ce9285c Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Mon, 6 Jan 2025 19:35:31 +0800 Subject: [PATCH 8/8] Remove more space --- comfy/k_diffusion/sampling.py | 2 +- comfy_extras/nodes_custom_sampler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 4c8d18f2..446c9f7b 100644 --- a/comfy/k_diffusion/sampling.py +++ 
b/comfy/k_diffusion/sampling.py @@ -1217,7 +1217,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F noise=0, sigma=sigmas[-1]) @torch.no_grad() -def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): +def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None): if len(sigmas) <= 1: return x return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 3cffdf8a..d8094e12 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -413,7 +413,7 @@ class SamplerDPMAdaptative: "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, "s_noise":s_noise }) return (sampler, ) - + class SamplerSASolver: @classmethod def INPUT_TYPES(s):
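
The pieces this series adds can be exercised in isolation; the sketches below are illustrations, not part of the patches. First, the closed-form antiderivatives in get_coefficients_exponential_positive (patch 1) can be sanity-checked against plain numerical quadrature. A minimal sketch, assuming PyTorch >= 1.10 (for torch.trapezoid) and the module path this series creates; the interval endpoints and tau are arbitrary test values:

    import torch
    from comfy.k_diffusion.sa_solver import get_coefficients_exponential_positive

    tau = 0.5
    a = torch.tensor(-1.2, dtype=torch.float64)  # interval_start (a lambda = -log(sigma) value)
    b = torch.tensor(-0.3, dtype=torch.float64)  # interval_end
    xs = torch.linspace(a.item(), b.item(), 200001, dtype=torch.float64)
    for order in range(4):
        closed_form = get_coefficients_exponential_positive(order, a, b, tau)
        # Integrand from the docstring: exp(x * (1 + tau^2)) * x^order
        quadrature = torch.trapezoid(torch.exp((1 + tau ** 2) * xs) * xs ** order, xs)
        assert torch.allclose(closed_form, quadrature, atol=1e-8), (order, closed_form, quadrature)
    print("orders 0-3 match quadrature")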
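
The role of tau_func is also easy to see directly: default_tau_func returns eta only while sigma lies inside the closed band [eta_end_sigma, eta_start_sigma], so the solver takes stochastic steps in the middle of the schedule and deterministic (ODE) steps at both ends. A small illustration; the band edges 7.0 and 0.3 are made-up stand-ins, not values computed by the node:

    from comfy.k_diffusion.sa_solver import default_tau_func

    for sigma in (14.6, 7.0, 1.0, 0.3, 0.05):
        tau = default_tau_func(sigma, eta=1.0, eta_start_sigma=7.0, eta_end_sigma=0.3)
        print(f"sigma={sigma:5}: tau={tau}")
    # sigma= 14.6: tau=0    (above the band: deterministic)
    # sigma=  7.0: tau=1.0  (band edges are inclusive)
    # sigma=  1.0: tau=1.0
    # sigma=  0.3: tau=1.0
    # sigma= 0.05: tau=0    (below the band: deterministic)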
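
The order bookkeeping in sample_sa_solver warms up while denoised-model history accumulates and winds down toward the end of the schedule (patch 3 moves the predictor wind-down one step earlier). Evaluating the final form of the two formulas for an 8-step run, purely as an illustration:

    n = 9  # len(sigmas) for 8 sampling steps
    predictor_order, corrector_order = 3, 4
    for i in range(1, n - 1):
        predictor_order_used = min(predictor_order, i, n - i - 1)
        corrector_order_used = min(corrector_order, i + 1, n - i + 1)
        print(i, predictor_order_used, corrector_order_used)
    # i = 1..7 -> predictor orders 1, 2, 3, 3, 3, 2, 1
    #             corrector orders 2, 3, 4, 4, 4, 4, 3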
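
For reference, the sampler selection that SamplerSASolver.get_sampler performs after patch 5, condensed into a hypothetical helper (the helper name and calling it outside the node are my framing; the individual calls are exactly the ones the node makes):

    from functools import partial
    import comfy.samplers
    from comfy.k_diffusion import sa_solver

    def build_sa_solver(model, pc_mode="PEC", eta=1.0, start_percent=0.2, end_percent=0.8):
        model_sampling = model.get_model_object('model_sampling')
        tau_func = partial(sa_solver.default_tau_func, eta=eta,
                           eta_start_sigma=model_sampling.percent_to_sigma(start_percent),
                           eta_end_sigma=model_sampling.percent_to_sigma(end_percent))
        sampler_name = "sa_solver" if pc_mode == "PEC" else "sa_solver_pece"
        return comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})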