diff --git a/comfy/k_diffusion/res.py b/comfy/k_diffusion/res.py
deleted file mode 100644
index 6caedec3..00000000
--- a/comfy/k_diffusion/res.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Copied from Nvidia Cosmos code.
-
-import torch
-from torch import Tensor
-from typing import Callable, List, Tuple, Optional, Any
-import math
-from tqdm.auto import trange
-
-
-def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
-    ndims1 = x.ndim
-    ndims2 = y.ndim
-
-    if ndims1 < ndims2:
-        x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
-    elif ndims2 < ndims1:
-        y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))
-
-    return x, y
-
-
-def batch_mul(x: Tensor, y: Tensor) -> Tensor:
-    x, y = common_broadcast(x, y)
-    return x * y
-
-
-def phi1(t: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the first order phi function: (exp(t) - 1) / t.
-
-    Args:
-        t: Input tensor.
-
-    Returns:
-        Tensor: Result of phi1 function.
-    """
-    input_dtype = t.dtype
-    t = t.to(dtype=torch.float32)
-    return (torch.expm1(t) / t).to(dtype=input_dtype)
-
-
-def phi2(t: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the second order phi function: (phi1(t) - 1) / t.
-
-    Args:
-        t: Input tensor.
-
-    Returns:
-        Tensor: Result of phi2 function.
-    """
-    input_dtype = t.dtype
-    t = t.to(dtype=torch.float32)
-    return ((phi1(t) - 1.0) / t).to(dtype=input_dtype)
-
-
-def res_x0_rk2_step(
-    x_s: torch.Tensor,
-    t: torch.Tensor,
-    s: torch.Tensor,
-    x0_s: torch.Tensor,
-    s1: torch.Tensor,
-    x0_s1: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Perform a residual-based 2nd order Runge-Kutta step.
-
-    Args:
-        x_s: Current state tensor.
-        t: Target time tensor.
-        s: Current time tensor.
-        x0_s: Prediction at current time.
-        s1: Intermediate time tensor.
-        x0_s1: Prediction at intermediate time.
-
-    Returns:
-        Tensor: Updated state tensor.
-
-    Raises:
-        AssertionError: If step size is too small.
-    """
-    s = -torch.log(s)
-    t = -torch.log(t)
-    m = -torch.log(s1)
-
-    dt = t - s
-    assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
-    assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
-
-    c2 = (m - s) / dt
-    phi1_val, phi2_val = phi1(-dt), phi2(-dt)
-
-    # Handle edge case where t = s = m
-    b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
-    b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
-
-    return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1))
-
-
-def reg_x0_euler_step(
-    x_s: torch.Tensor,
-    s: torch.Tensor,
-    t: torch.Tensor,
-    x0_s: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Perform a regularized Euler step based on x0 prediction.
-
-    Args:
-        x_s: Current state tensor.
-        s: Current time tensor.
-        t: Target time tensor.
-        x0_s: Prediction at current time.
-
-    Returns:
-        Tuple[Tensor, Tensor]: Updated state tensor and current prediction.
-    """
-    coef_x0 = (s - t) / s
-    coef_xs = t / s
-    return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s
-
-
-def order2_fn(
-    x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-    """
-    impl the second order multistep method in https://arxiv.org/pdf/2308.02157
-    Adams Bashforth approach!
-    """
-    if x0_preds:
-        x0_s1, s1 = x0_preds[0]
-        x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1)
-    else:
-        x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0]
-    return x_t, [(x0_s, s)]
-
-
-class SolverConfig:
-    is_multi: bool = True
-    rk: str = "2mid"
-    multistep: str = "2ab"
-    s_churn: float = 0.0
-    s_t_max: float = float("inf")
-    s_t_min: float = 0.0
-    s_noise: float = 1.0
-
-
-def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any, disable=None) -> Any:
-    """
-    Implements a for loop with a function.
-
-    Args:
-        lower: Lower bound of the loop (inclusive).
-        upper: Upper bound of the loop (exclusive).
-        body_fun: Function to be applied in each iteration.
-        init_val: Initial value for the loop.
-
-    Returns:
-        The final result after all iterations.
-    """
-    val = init_val
-    for i in trange(lower, upper, disable=disable):
-        val = body_fun(i, val)
-    return val
-
-
-def differential_equation_solver(
-    x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-    sigmas_L: torch.Tensor,
-    solver_cfg: SolverConfig,
-    noise_sampler,
-    callback=None,
-    disable=None,
-) -> Callable[[torch.Tensor], torch.Tensor]:
-    """
-    Creates a differential equation solver function.
-
-    Args:
-        x0_fn: Function to compute x0 prediction.
-        sigmas_L: Tensor of sigma values with shape [L,].
-        solver_cfg: Configuration for the solver.
-
-    Returns:
-        A function that solves the differential equation.
-    """
-    num_step = len(sigmas_L) - 1
-
-    # if solver_cfg.is_multi:
-    #     update_step_fn = get_multi_step_fn(solver_cfg.multistep)
-    # else:
-    #     update_step_fn = get_runge_kutta_fn(solver_cfg.rk)
-    update_step_fn = order2_fn
-
-    eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1)
-
-    def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor:
-        """
-        Samples from the differential equation.
-
-        Args:
-            input_xT_B_StateShape: Input tensor with shape [B, StateShape].
-
-        Returns:
-            Output tensor with shape [B, StateShape].
-        """
-        ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float32)
-
-        def step_fn(
-            i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
-        ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
-            input_x_B_StateShape, x0_preds = state
-            sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
-
-            if sigma_next_0 == 0:
-                output_x_B_StateShape = x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
-            else:
-                # algorithm 2: line 4-6
-                if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max and eta > 0:
-                    hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0
-                    input_x_B_StateShape = input_x_B_StateShape + (
-                        hat_sigma_cur_0**2 - sigma_cur_0**2
-                    ).sqrt() * solver_cfg.s_noise * noise_sampler(sigma_cur_0, sigma_next_0) # torch.randn_like(input_x_B_StateShape)
-                    sigma_cur_0 = hat_sigma_cur_0
-
-                if solver_cfg.is_multi:
-                    x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
-                    output_x_B_StateShape, x0_preds = update_step_fn(
-                        input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds
-                    )
-                else:
-                    output_x_B_StateShape, x0_preds = update_step_fn(
-                        input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn
-                    )
-
-            if callback is not None:
-                callback({'x': input_x_B_StateShape, 'i': i_th, 'sigma': sigma_cur_0, 'sigma_hat': sigma_cur_0, 'denoised': x0_pred_B_StateShape})
-
-            return output_x_B_StateShape, x0_preds
-
-        x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None], disable=disable)
-        return x_at_eps
-
-    return sample_fn
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 3a98e6a7..13ae272f 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -8,7 +8,6 @@ from tqdm.auto import trange, tqdm
 
 from . import utils
 from . import deis
-from . import res
 import comfy.model_patcher
 import comfy.model_sampling
 
@@ -1268,18 +1267,72 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
     return x
 
 @torch.no_grad()
-def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
+def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+    phi1_fn = lambda t: torch.expm1(t) / t
+    phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
 
-    x0_func = lambda x, sigma: model(x, sigma, **extra_args)
+    old_denoised = None
+    uncond_denoised = None
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
 
-    solver_cfg = res.SolverConfig()
-    solver_cfg.s_churn = s_churn
-    solver_cfg.s_t_max = s_tmax
-    solver_cfg.s_t_min = s_tmin
-    solver_cfg.s_noise = s_noise
+    if cfg_pp:
+        model_options = extra_args.get("model_options", {}).copy()
+        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
 
-    x = res.differential_equation_solver(x0_func, sigmas, solver_cfg, noise_sampler, callback=callback, disable=disable)(x)
+    for i in trange(len(sigmas) - 1, disable=disable):
+        if s_churn > 0:
+            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
+            sigma_hat = sigmas[i] * (gamma + 1)
+        else:
+            gamma = 0
+            sigma_hat = sigmas[i]
+
+        if gamma > 0:
+            eps = torch.randn_like(x) * s_noise
+            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
+        denoised = model(x, sigma_hat * s_in, **extra_args)
+        if callback is not None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
+        if sigmas[i + 1] == 0 or old_denoised is None:
+            # Euler method
+            if cfg_pp:
+                d = to_d(x, sigma_hat, uncond_denoised)
+                x = denoised + d * sigmas[i + 1]
+            else:
+                d = to_d(x, sigma_hat, denoised)
+                dt = sigmas[i + 1] - sigma_hat
+                x = x + d * dt
+        else:
+            # Second order multistep method in https://arxiv.org/pdf/2308.02157
+            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
+            h = t_next - t
+            c2 = (t_prev - t) / h
+
+            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
+            b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
+            b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
+
+            if cfg_pp:
+                x = x + (denoised - uncond_denoised)
+
+            x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
+
+        old_denoised = denoised
     return x
+
+@torch.no_grad()
+def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
+
+@torch.no_grad()
+def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 8f25935d..c508a3a4 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -687,7 +687,7 @@ class Sampler:
 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
-                  "ipndm", "ipndm_v", "deis", "res_multistep"]
+                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
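
Reviewer note (not part of the patch): the deleted res.py solver and the new inlined loop compute the same second-order multistep update from https://arxiv.org/pdf/2308.02157, stepping in t = -log(sigma) and weighting the current and previous x0 predictions with the phi1/phi2 exponential-integrator coefficients. Below is a minimal standalone sketch of that update for reference. It is illustrative only: the toy one-point denoiser and the sigma schedule are invented for the demo, the churn (s_churn) and cfg_pp paths are omitted, and only the update rule itself mirrors the committed res_multistep code.

    import torch

    def phi1(h):
        # phi1(h) = (exp(h) - 1) / h, first-order exponential-integrator weight
        return torch.expm1(h) / h

    def phi2(h):
        # phi2(h) = (phi1(h) - 1) / h, second-order weight
        return (phi1(h) - 1.0) / h

    def res_multistep_sketch(denoiser, x, sigmas):
        # denoiser(x, sigma) returns the x0 prediction, k-diffusion style.
        old_denoised = None
        for i in range(len(sigmas) - 1):
            denoised = denoiser(x, sigmas[i])
            if sigmas[i + 1] == 0 or old_denoised is None:
                # First step (nothing to reuse yet) and the final step to
                # sigma = 0 fall back to a plain Euler step in sigma.
                d = (x - denoised) / sigmas[i]
                x = x + d * (sigmas[i + 1] - sigmas[i])
            else:
                # Step in t = -log(sigma); h > 0 since sigma decreases.
                t, t_next, t_prev = -sigmas[i].log(), -sigmas[i + 1].log(), -sigmas[i - 1].log()
                h = t_next - t
                c2 = (t_prev - t) / h  # offset of the previous x0 prediction, in units of h
                b1 = torch.nan_to_num(phi1(-h) - phi2(-h) / c2, nan=0.0)
                b2 = torch.nan_to_num(phi2(-h) / c2, nan=0.0)
                # exp(-h) equals sigmas[i + 1] / sigmas[i]: decay x toward the
                # next noise level, then add the phi-weighted x0 predictions.
                x = torch.exp(-h) * x + h * (b1 * denoised + b2 * old_denoised)
            old_denoised = denoised
        return x

    # Sanity check: with the ideal denoiser of a one-point dataset (it always
    # returns the data point), the sampler must land exactly on `target`.
    target = torch.tensor([1.5, -0.5])
    sigmas = torch.tensor([10.0, 5.0, 2.0, 1.0, 0.5, 0.0])
    x = torch.randn(2) * sigmas[0]
    print(res_multistep_sketch(lambda x, sigma: target.clone(), x, sigmas))

The Euler fallback on the first step is what seeds old_denoised; every later step reuses the previous x0 prediction instead of making a second model call per step, which is what makes this a multistep (Adams-Bashforth style) variant of RES rather than the two-evaluation Runge-Kutta form in the deleted res_x0_rk2_step.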