Fix lambda_prev in corrector

This commit is contained in:
chaObserv 2024-09-17 23:21:52 +08:00
parent 92a630e737
commit 1782896502
2 changed files with 7 additions and 7 deletions

View File

@ -150,7 +150,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
sigma_list = sigma_prev_list + [sigma] sigma_list = sigma_prev_list + [sigma]
lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)] lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
lambda_t = lambda_list[0] lambda_t = lambda_list[0]
lambda_prev = lambda_list[1] lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev)
h = lambda_t - lambda_prev h = lambda_t - lambda_prev
gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau) gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

View File

@ -1104,8 +1104,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Predictor step # Predictor step
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val, x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma) noise=noise, sigma=sigma)
# Evaluation step # Evaluation step
denoised = model(x_p, sigma * s_in, **extra_args) denoised = model(x_p, sigma * s_in, **extra_args)
@ -1116,8 +1116,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Corrector step # Corrector step
if corrector_order_used > 0: if corrector_order_used > 0:
x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val, x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma) noise=noise, sigma=sigma)
else: else:
x = x_p x = x_p
@ -1140,8 +1140,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Extra final step # Extra final step
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0, x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list, model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1]) noise=0, sigma=sigmas[-1])
return x return x
@torch.no_grad() @torch.no_grad()