2024-08-17 15:29:07 +00:00
# Modified from: https://github.com/scxue/SA-Solver
# MIT license
import torch
def get_coefficients_exponential_positive(order, interval_start, interval_end, tau):
    """
    Integrate exp(x * (1 + tau^2)) * x^order over [interval_start, interval_end].

    The integral is evaluated in closed form after the change of variable
    u = (1 + tau^2) * x, using the antiderivative e^u * P_order(u) of u^order * e^u.
    The results are the coefficients of the gradient terms after Lagrange
    interpolation (see Eq. (15) and Eq. (18) of the SA-Solver paper,
    https://arxiv.org/pdf/2309.05019.pdf), in the data-prediction formulation.

    Args:
        order: Power of x in the integrand; only 0, 1, 2 and 3 are supported.
        interval_start: Lower integration limit (passed through torch.exp, so a tensor).
        interval_end: Upper integration limit (passed through torch.exp, so a tensor).
        tau: Stochasticity parameter of the solver.

    Returns:
        The value of the definite integral.
    """
    assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
    # P_k such that d/du [e^u * P_k(u)] = u^k * e^u.
    antiderivative_polys = {
        0: lambda u: 1,
        1: lambda u: u - 1,
        2: lambda u: u ** 2 - 2 * u + 2,
        3: lambda u: u ** 3 - 3 * u ** 2 + 6 * u - 6,
    }
    poly = antiderivative_polys.get(order)
    if poly is None:
        # Only reachable when assertions are disabled (python -O).
        return None

    scale = 1 + tau ** 2
    # Integration limits after the change of variable u = scale * x.
    upper = scale * interval_end
    lower = scale * interval_start
    # e^upper is factored out; the lower limit then enters through e^-(upper - lower).
    decay = torch.exp(-(upper - lower))
    # The extra 1 / scale^(order + 1) comes from x^order dx = (u / scale)^order du / scale.
    return (torch.exp(upper)
            * (poly(upper) - poly(lower) * decay)
            / (scale ** (order + 1))
            )
def lagrange_polynomial_coefficient(order, lambda_list):
    """
    Return the monomial coefficients of every Lagrange basis polynomial.

    For nodes l_0, ..., l_order taken from lambda_list, row i holds the
    coefficients of L_i(x) = prod_{j != i} (x - l_j) / (l_i - l_j), listed from
    the highest power x^order down to the constant term.

    Args:
        order: Degree of the interpolating polynomial; only 0, 1, 2 and 3 are supported.
        lambda_list: The interpolation nodes; must contain exactly order + 1 values.

    Returns:
        A list of order + 1 rows, each with order + 1 coefficients.
    """
    assert order in [0, 1, 2, 3]
    assert order == len(lambda_list) - 1
    if order == 0:
        return [[1.0]]
    if order not in (1, 2, 3):
        # Only reachable when assertions are disabled (python -O).
        return None

    nodes = [lambda_list[k] for k in range(order + 1)]
    basis_rows = []
    for i, node in enumerate(nodes):
        others = nodes[:i] + nodes[i + 1:]
        if order == 1:
            (a,) = others
            denom = node - a
            basis_rows.append([1.0 / denom, -a / denom])
        elif order == 2:
            a, b = others
            denom = (node - a) * (node - b)
            basis_rows.append([1.0 / denom,
                               (-a - b) / denom,
                               a * b / denom])
        else:
            a, b, c = others
            denom = (node - a) * (node - b) * (node - c)
            basis_rows.append([1.0 / denom,
                               (-a - b - c) / denom,
                               (a * b + a * c + b * c) / denom,
                               (-a * b * c) / denom])
    return basis_rows
def get_coefficients_fn(order, interval_start, interval_end, lambda_list, tau):
    """
    Calculate the coefficients of the gradient (model-output) terms.

    Coefficient i is the integral over [interval_start, interval_end] of
    exp(x * (1 + tau^2)) times the i-th Lagrange basis polynomial built on the
    nodes in lambda_list.

    Args:
        order: Number of interpolation nodes; only 1, 2, 3 and 4 are supported.
        interval_start: Lower integration limit.
        interval_end: Upper integration limit.
        lambda_list: The interpolation nodes; its length must equal order.
        tau: Stochasticity parameter of the solver.

    Returns:
        A list of `order` coefficients, one per node in lambda_list.
    """
    assert order in [1, 2, 3, 4]
    assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
    lagrange_coefficient = lagrange_polynomial_coefficient(order - 1, lambda_list)
    # The power integrals do not depend on which basis polynomial they multiply,
    # so evaluate each one once instead of once per row.
    # integrals[j] is the x^(order - 1 - j) integral, matching column j of each row.
    integrals = [get_coefficients_exponential_positive(order - 1 - j, interval_start, interval_end, tau)
                 for j in range(order)]
    coefficients = [sum(coef * integral for coef, integral in zip(row, integrals))
                    for row in lagrange_coefficient]
    assert len(coefficients) == order, 'the length of coefficients does not match the order'
    return coefficients
def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
    """
    SA-Predictor step, using the "rescaling" trick of Appendix D in the SA-Solver
    paper (https://arxiv.org/pdf/2309.05019.pdf).

    Interpolates the last `order` model outputs in lambda = -log(sigma) and
    integrates them analytically to move from sigma_prev_list[-1] to `sigma`.

    Args:
        order: Number of previous model outputs used; only 1, 2, 3 and 4 are supported.
        x: Current sample; it is rescaled by sigma / sigma_prev_list[-1].
        tau: Stochasticity parameter; when tau == 0 no noise is added.
        model_prev_list: Previous model outputs; model_prev_list[-(i + 1)] is
            paired with the node lambda(sigma_prev_list[-(i + 1)]).
        sigma_prev_list: Sigmas of previous steps, most recent last.
        noise: Noise tensor, scaled by sigma * sqrt(1 - exp(-2 tau^2 h)) when tau != 0.
        sigma: Target sigma.

    Returns:
        The new sample at `sigma`.
    """
    assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
    sigma_prev = sigma_prev_list[-1]
    # Interpolation nodes in lambda = -log(sigma), newest first.
    lambda_list = [sigma_prev_list[-(k + 1)].log().neg() for k in range(order)]
    lambda_t = sigma.log().neg()
    lambda_prev = lambda_list[0]
    h = lambda_t - lambda_prev
    scale = 1 + tau ** 2
    gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

    if order == 2:
        # Few-step modification similar to UniPC: the added term is O(h^3), so it
        # does not change the convergence order; empirically it slightly improves
        # image quality. Its ODE (tau = 0) form is
        #   exp(lambda_t) * (h**2 / 2 - (h - 1 + exp(-h))) / (lambda_prev - lambda_list[1])
        correction = (torch.exp(scale * lambda_t)
                      * (h ** 2 / 2 - (h * scale - 1 + torch.exp(scale * (-h))) / (scale ** 2))
                      / (lambda_prev - lambda_list[1]))
        gradient_coefficients[0] += correction
        gradient_coefficients[1] -= correction

    gradient_part = torch.zeros_like(x)
    for k, coefficient in enumerate(gradient_coefficients):
        gradient_part += coefficient * model_prev_list[-(k + 1)]
    gradient_part *= scale * sigma * torch.exp(-tau ** 2 * lambda_t)

    if tau == 0:
        noise_part = 0
    else:
        noise_part = sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
    return torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
    """
    SA-Corrector step, using the "rescaling" trick of Appendix D in the SA-Solver
    paper (https://arxiv.org/pdf/2309.05019.pdf).

    Unlike the predictor, the interpolation nodes start at the target `sigma`
    itself: node i is lambda([*sigma_prev_list, sigma][-(i + 1)]) with
    lambda = -log(sigma).

    Args:
        order: Number of interpolation nodes (including `sigma`); only 1, 2, 3 and 4 are supported.
        x: Current sample; it is rescaled by sigma / sigma_prev_list[-1].
        tau: Stochasticity parameter; when tau == 0 no noise is added.
        model_prev_list: Model outputs; model_prev_list[-(i + 1)] is paired with
            node i, so model_prev_list[-1] is presumably the output at `sigma`.
        sigma_prev_list: Sigmas of previous steps, most recent last (a list, since
            `sigma` is appended with list concatenation).
        noise: Noise tensor, scaled by sigma * sqrt(1 - exp(-2 tau^2 h)) when tau != 0.
        sigma: Target sigma.

    Returns:
        The corrected sample at `sigma`.
    """
    # The message previously said "adams bashforth" (copied from the predictor).
    assert order in [1, 2, 3, 4], "order of stochastic adams moulton method is only supported for 1, 2, 3 and 4"

    def t_fn(s):
        # lambda = -log(sigma)
        return s.log().neg()

    sigma_prev = sigma_prev_list[-1]
    gradient_part = torch.zeros_like(x)
    # Nodes run newest first: lambda(sigma), lambda(sigma_prev), ...
    sigma_list = sigma_prev_list + [sigma]
    lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
    lambda_t = lambda_list[0]
    # With order 1 the node list holds only lambda_t, so derive lambda_prev directly.
    lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev)
    h = lambda_t - lambda_prev
    scale = 1 + tau ** 2
    gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

    if order == 2:
        # Few-step modification similar to UniPC: the added term is O(h^3), so it
        # does not change the convergence order; empirically it slightly improves
        # image quality. Its ODE (tau = 0) form is
        #   exp(lambda_t) * (h / 2 - (h - 1 + exp(-h)) / h)
        correction = (torch.exp(scale * lambda_t)
                      * (h / 2 - (h * scale - 1 + torch.exp(scale * (-h)))
                         / (scale ** 2 * h)))
        gradient_coefficients[0] += correction
        gradient_coefficients[1] -= correction

    for i in range(order):
        gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
    gradient_part *= scale * sigma * torch.exp(-tau ** 2 * lambda_t)
    noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
    x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
    return x_t
# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
    """
    Return the SA-Solver stochasticity tau for the current noise level.

    tau is `eta` while sigma lies in the closed window
    [eta_end_sigma, eta_start_sigma] and 0 outside it; eta == 0 always gives 0,
    i.e. a pure ODE solve.
    """
    if eta == 0:
        # Pure ODE
        return 0
    if not (eta_end_sigma <= sigma <= eta_start_sigma):
        return 0
    return eta