Use default_noise_sampler instead

This commit is contained in:
chaObserv 2025-01-06 18:55:29 +08:00
parent 812dc34f46
commit 6b68b61644
4 changed files with 7 additions and 33 deletions

View File

@ -175,12 +175,6 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
return x_t return x_t
def device_noise_sampler(x, noise_device='gpu'):
if noise_device == "gpu":
return torch.randn_like(x)
else:
return torch.randn(x.shape, device='cpu').to(x.device)
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract # Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma): def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
if eta == 0: if eta == 0:

View File

@ -1154,7 +1154,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
end_sigma = model_sampling.percent_to_sigma(0.8) end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma) tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
tau = tau_func tau = tau_func
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if noise_sampler is None else noise_sampler seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_prev_list = [] sigma_prev_list = []
@ -1173,7 +1174,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1) corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
tau_val = tau(sigma) tau_val = tau(sigma)
noise = None if tau_val == 0 else noise_sampler() noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1])
# Predictor step # Predictor step
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
@ -1226,26 +1227,6 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler, pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
) )
@torch.no_grad()
def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
predictor_order=predictor_order, corrector_order=corrector_order,
pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler,
)
@torch.no_grad()
def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
predictor_order=predictor_order, corrector_order=corrector_order,
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
)
@torch.no_grad() @torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args

View File

@ -687,7 +687,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"] "ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_pece"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -423,7 +423,6 @@ class SamplerSASolver:
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}), "eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
"noise_device": (["gpu", "cpu"],),
} }
} }
RETURN_TYPES = ("SAMPLER",) RETURN_TYPES = ("SAMPLER",)
@ -431,16 +430,16 @@ class SamplerSASolver:
FUNCTION = "get_sampler" FUNCTION = "get_sampler"
def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent, noise_device): def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent):
model_sampling = model.get_model_object('model_sampling') model_sampling = model.get_model_object('model_sampling')
start_sigma = model_sampling.percent_to_sigma(eta_start_percent) start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
end_sigma = model_sampling.percent_to_sigma(eta_end_percent) end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma) tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
if pc_mode == 'PEC': if pc_mode == 'PEC':
sampler_name = "sa_solver" if noise_device == "cpu" else "sa_solver_gpu" sampler_name = "sa_solver"
else: else:
sampler_name = "sa_solver_pece" if noise_device == "cpu" else "sa_solver_pece_gpu" sampler_name = "sa_solver_pece"
sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func}) sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
return (sampler, ) return (sampler, )