mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use default_noise_sampler instead
This commit is contained in:
parent
812dc34f46
commit
6b68b61644
@ -175,12 +175,6 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
|
|||||||
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def device_noise_sampler(x, noise_device='gpu'):
|
|
||||||
if noise_device == "gpu":
|
|
||||||
return torch.randn_like(x)
|
|
||||||
else:
|
|
||||||
return torch.randn(x.shape, device='cpu').to(x.device)
|
|
||||||
|
|
||||||
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
|
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
|
||||||
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
|
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
|
||||||
if eta == 0:
|
if eta == 0:
|
||||||
|
@ -1154,7 +1154,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
end_sigma = model_sampling.percent_to_sigma(0.8)
|
end_sigma = model_sampling.percent_to_sigma(0.8)
|
||||||
tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
|
tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
|
||||||
tau = tau_func
|
tau = tau_func
|
||||||
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if noise_sampler is None else noise_sampler
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
sigma_prev_list = []
|
sigma_prev_list = []
|
||||||
@ -1173,7 +1174,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
|
corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)
|
||||||
|
|
||||||
tau_val = tau(sigma)
|
tau_val = tau(sigma)
|
||||||
noise = None if tau_val == 0 else noise_sampler()
|
noise = None if tau_val == 0 else noise_sampler(sigma, sigmas[i + 1])
|
||||||
|
|
||||||
# Predictor step
|
# Predictor step
|
||||||
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
|
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
|
||||||
@ -1226,26 +1227,6 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa
|
|||||||
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
|
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
|
||||||
if len(sigmas) <= 1:
|
|
||||||
return x
|
|
||||||
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
|
|
||||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
|
|
||||||
predictor_order=predictor_order, corrector_order=corrector_order,
|
|
||||||
pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler,
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
|
|
||||||
if len(sigmas) <= 1:
|
|
||||||
return x
|
|
||||||
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
|
|
||||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
|
|
||||||
predictor_order=predictor_order, corrector_order=corrector_order,
|
|
||||||
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
@ -687,7 +687,7 @@ class Sampler:
|
|||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"]
|
"ipndm", "ipndm_v", "deis", "sa_solver", "sa_solver_pece"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
@ -423,7 +423,6 @@ class SamplerSASolver:
|
|||||||
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||||
"eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"eta_start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"eta_end_percent": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"noise_device": (["gpu", "cpu"],),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SAMPLER",)
|
RETURN_TYPES = ("SAMPLER",)
|
||||||
@ -431,16 +430,16 @@ class SamplerSASolver:
|
|||||||
|
|
||||||
FUNCTION = "get_sampler"
|
FUNCTION = "get_sampler"
|
||||||
|
|
||||||
def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent, noise_device):
|
def get_sampler(self, model, pc_mode, eta, eta_start_percent, eta_end_percent):
|
||||||
model_sampling = model.get_model_object('model_sampling')
|
model_sampling = model.get_model_object('model_sampling')
|
||||||
start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
|
start_sigma = model_sampling.percent_to_sigma(eta_start_percent)
|
||||||
end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
|
end_sigma = model_sampling.percent_to_sigma(eta_end_percent)
|
||||||
tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
|
tau_func = partial(sa_solver.default_tau_func, eta=eta, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
|
||||||
|
|
||||||
if pc_mode == 'PEC':
|
if pc_mode == 'PEC':
|
||||||
sampler_name = "sa_solver" if noise_device == "cpu" else "sa_solver_gpu"
|
sampler_name = "sa_solver"
|
||||||
else:
|
else:
|
||||||
sampler_name = "sa_solver_pece" if noise_device == "cpu" else "sa_solver_pece_gpu"
|
sampler_name = "sa_solver_pece"
|
||||||
sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
|
sampler = comfy.samplers.ksampler(sampler_name, {"tau_func": tau_func})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user