mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
Add res_multistep sampler from the cosmos code.
This sampler should work with all models.
This commit is contained in:
parent
b9d9bcba14
commit
90f349f93d
258
comfy/k_diffusion/res.py
Normal file
258
comfy/k_diffusion/res.py
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Copied from Nvidia Cosmos code.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Callable, List, Tuple, Optional, Any
|
||||||
|
import math
|
||||||
|
from tqdm.auto import trange
|
||||||
|
|
||||||
|
|
||||||
|
def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
ndims1 = x.ndim
|
||||||
|
ndims2 = y.ndim
|
||||||
|
|
||||||
|
if ndims1 < ndims2:
|
||||||
|
x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
|
||||||
|
elif ndims2 < ndims1:
|
||||||
|
y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))
|
||||||
|
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
def batch_mul(x: Tensor, y: Tensor) -> Tensor:
|
||||||
|
x, y = common_broadcast(x, y)
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
|
def phi1(t: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute the first order phi function: (exp(t) - 1) / t.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: Input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Result of phi1 function.
|
||||||
|
"""
|
||||||
|
input_dtype = t.dtype
|
||||||
|
t = t.to(dtype=torch.float32)
|
||||||
|
return (torch.expm1(t) / t).to(dtype=input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def phi2(t: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute the second order phi function: (phi1(t) - 1) / t.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: Input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Result of phi2 function.
|
||||||
|
"""
|
||||||
|
input_dtype = t.dtype
|
||||||
|
t = t.to(dtype=torch.float32)
|
||||||
|
return ((phi1(t) - 1.0) / t).to(dtype=input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def res_x0_rk2_step(
|
||||||
|
x_s: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
s: torch.Tensor,
|
||||||
|
x0_s: torch.Tensor,
|
||||||
|
s1: torch.Tensor,
|
||||||
|
x0_s1: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform a residual-based 2nd order Runge-Kutta step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_s: Current state tensor.
|
||||||
|
t: Target time tensor.
|
||||||
|
s: Current time tensor.
|
||||||
|
x0_s: Prediction at current time.
|
||||||
|
s1: Intermediate time tensor.
|
||||||
|
x0_s1: Prediction at intermediate time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Updated state tensor.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If step size is too small.
|
||||||
|
"""
|
||||||
|
s = -torch.log(s)
|
||||||
|
t = -torch.log(t)
|
||||||
|
m = -torch.log(s1)
|
||||||
|
|
||||||
|
dt = t - s
|
||||||
|
assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
|
||||||
|
assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
|
||||||
|
|
||||||
|
c2 = (m - s) / dt
|
||||||
|
phi1_val, phi2_val = phi1(-dt), phi2(-dt)
|
||||||
|
|
||||||
|
# Handle edge case where t = s = m
|
||||||
|
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
|
||||||
|
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
|
||||||
|
|
||||||
|
return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1))
|
||||||
|
|
||||||
|
|
||||||
|
def reg_x0_euler_step(
|
||||||
|
x_s: torch.Tensor,
|
||||||
|
s: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
x0_s: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Perform a regularized Euler step based on x0 prediction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_s: Current state tensor.
|
||||||
|
s: Current time tensor.
|
||||||
|
t: Target time tensor.
|
||||||
|
x0_s: Prediction at current time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Tensor, Tensor]: Updated state tensor and current prediction.
|
||||||
|
"""
|
||||||
|
coef_x0 = (s - t) / s
|
||||||
|
coef_xs = t / s
|
||||||
|
return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s
|
||||||
|
|
||||||
|
|
||||||
|
def order2_fn(
|
||||||
|
x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
impl the second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||||
|
Adams Bashforth approach!
|
||||||
|
"""
|
||||||
|
if x0_preds:
|
||||||
|
x0_s1, s1 = x0_preds[0]
|
||||||
|
x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1)
|
||||||
|
else:
|
||||||
|
x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0]
|
||||||
|
return x_t, [(x0_s, s)]
|
||||||
|
|
||||||
|
|
||||||
|
class SolverConfig:
|
||||||
|
is_multi: bool = True
|
||||||
|
rk: str = "2mid"
|
||||||
|
multistep: str = "2ab"
|
||||||
|
s_churn: float = 0.0
|
||||||
|
s_t_max: float = float("inf")
|
||||||
|
s_t_min: float = 0.0
|
||||||
|
s_noise: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any, disable=None) -> Any:
|
||||||
|
"""
|
||||||
|
Implements a for loop with a function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lower: Lower bound of the loop (inclusive).
|
||||||
|
upper: Upper bound of the loop (exclusive).
|
||||||
|
body_fun: Function to be applied in each iteration.
|
||||||
|
init_val: Initial value for the loop.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The final result after all iterations.
|
||||||
|
"""
|
||||||
|
val = init_val
|
||||||
|
for i in trange(lower, upper, disable=disable):
|
||||||
|
val = body_fun(i, val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def differential_equation_solver(
|
||||||
|
x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||||
|
sigmas_L: torch.Tensor,
|
||||||
|
solver_cfg: SolverConfig,
|
||||||
|
noise_sampler,
|
||||||
|
callback=None,
|
||||||
|
disable=None,
|
||||||
|
) -> Callable[[torch.Tensor], torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Creates a differential equation solver function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x0_fn: Function to compute x0 prediction.
|
||||||
|
sigmas_L: Tensor of sigma values with shape [L,].
|
||||||
|
solver_cfg: Configuration for the solver.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function that solves the differential equation.
|
||||||
|
"""
|
||||||
|
num_step = len(sigmas_L) - 1
|
||||||
|
|
||||||
|
# if solver_cfg.is_multi:
|
||||||
|
# update_step_fn = get_multi_step_fn(solver_cfg.multistep)
|
||||||
|
# else:
|
||||||
|
# update_step_fn = get_runge_kutta_fn(solver_cfg.rk)
|
||||||
|
update_step_fn = order2_fn
|
||||||
|
|
||||||
|
eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1)
|
||||||
|
|
||||||
|
def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Samples from the differential equation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_xT_B_StateShape: Input tensor with shape [B, StateShape].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor with shape [B, StateShape].
|
||||||
|
"""
|
||||||
|
ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
def step_fn(
|
||||||
|
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
||||||
|
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
||||||
|
input_x_B_StateShape, x0_preds = state
|
||||||
|
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
||||||
|
|
||||||
|
if sigma_next_0 == 0:
|
||||||
|
output_x_B_StateShape = x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
|
||||||
|
else:
|
||||||
|
# algorithm 2: line 4-6
|
||||||
|
if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max and eta > 0:
|
||||||
|
hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0
|
||||||
|
input_x_B_StateShape = input_x_B_StateShape + (
|
||||||
|
hat_sigma_cur_0**2 - sigma_cur_0**2
|
||||||
|
).sqrt() * solver_cfg.s_noise * noise_sampler(sigma_cur_0, sigma_next_0) # torch.randn_like(input_x_B_StateShape)
|
||||||
|
sigma_cur_0 = hat_sigma_cur_0
|
||||||
|
|
||||||
|
if solver_cfg.is_multi:
|
||||||
|
x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
|
||||||
|
output_x_B_StateShape, x0_preds = update_step_fn(
|
||||||
|
input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_x_B_StateShape, x0_preds = update_step_fn(
|
||||||
|
input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': input_x_B_StateShape, 'i': i_th, 'sigma': sigma_cur_0, 'sigma_hat': sigma_cur_0, 'denoised': x0_pred_B_StateShape})
|
||||||
|
|
||||||
|
return output_x_B_StateShape, x0_preds
|
||||||
|
|
||||||
|
x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None], disable=disable)
|
||||||
|
return x_at_eps
|
||||||
|
|
||||||
|
return sample_fn
|
@ -8,6 +8,7 @@ from tqdm.auto import trange, tqdm
|
|||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import deis
|
from . import deis
|
||||||
|
from . import res
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
|
||||||
@ -1265,3 +1266,20 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||||
old_uncond_denoised = uncond_denoised
|
old_uncond_denoised = uncond_denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
|
x0_func = lambda x, sigma: model(x, sigma, **extra_args)
|
||||||
|
|
||||||
|
solver_cfg = res.SolverConfig()
|
||||||
|
solver_cfg.s_churn = s_churn
|
||||||
|
solver_cfg.s_t_max = s_tmax
|
||||||
|
solver_cfg.s_t_min = s_tmin
|
||||||
|
solver_cfg.s_noise = s_noise
|
||||||
|
|
||||||
|
x = res.differential_equation_solver(x0_func, sigmas, solver_cfg, noise_sampler, callback=callback, disable=disable)(x)
|
||||||
|
return x
|
||||||
|
@ -687,7 +687,7 @@ class Sampler:
|
|||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis"]
|
"ipndm", "ipndm_v", "deis", "res_multistep"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
Loading…
Reference in New Issue
Block a user