diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py index 4f226c38..a475bc90 100644 --- a/comfy/k_diffusion/sa_solver.py +++ b/comfy/k_diffusion/sa_solver.py @@ -150,7 +150,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li sigma_list = sigma_prev_list + [sigma] lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)] lambda_t = lambda_list[0] - lambda_prev = lambda_list[1] + lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev) h = lambda_t - lambda_prev gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1bc3a27d..da2c0e0b 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1104,8 +1104,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Predictor step x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=noise, sigma=sigma) # Evaluation step denoised = model(x_p, sigma * s_in, **extra_args) @@ -1116,8 +1116,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Corrector step if corrector_order_used > 0: x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=noise, sigma=sigma) + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=noise, sigma=sigma) else: x = x_p @@ -1140,8 +1140,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Extra final step x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, - model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, - noise=0, sigma=sigmas[-1]) + model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, + noise=0, sigma=sigmas[-1]) return x @torch.no_grad()