diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 25051b33..07553627 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention import model_management +from . import tomesd if model_management.xformers_enabled(): import xformers @@ -508,8 +509,18 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): - x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x - x = self.attn2(self.norm2(x), context=context) + x + n = self.norm1(x) + if "tomesd" in transformer_options: + m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) + n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + else: + n = self.attn1(n, context=context if self.disable_self_attn else None) + + x += n + n = self.norm2(x) + n = self.attn2(n, context=context) + + x += n x = self.ff(self.norm3(x)) + x return x diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py new file mode 100644 index 00000000..5bf1acec --- /dev/null +++ b/comfy/ldm/modules/tomesd.py @@ -0,0 +1,117 @@ + + +import torch +from typing import Tuple, Callable +import math + +def do_nothing(x: torch.Tensor, mode:str=None): + return x + + +def bipartite_soft_matching_random2d(metric: torch.Tensor, + w: int, h: int, sx: int, sy: int, r: int, + no_rand: bool = False) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomly in each (sx, sy) region. 
+ + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + """ + B, N, _ = metric.shape + + if r <= 0: + return do_nothing, do_nothing + + with torch.no_grad(): + + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device) + + if no_rand: + rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device) + + idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype)) + idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1) + rand_idx = idx_buffer.argsort(dim=1) + + num_dst = int((1 / (sx*sy)) * N) + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = 
src.gather(dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def get_functions(x, ratio, original_shape): + b, c, original_h, original_w = original_shape + original_tokens = original_h * original_w + downsample = int(math.sqrt(original_tokens // x.shape[1])) + stride_x = 2 + stride_y = 2 + max_downsample = 1 + + if downsample <= max_downsample: + w = original_w // downsample + h = original_h // downsample + r = int(x.shape[1] * ratio) + no_rand = True + m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) + return m, u + + nothing = lambda y: y + return nothing, nothing diff --git a/comfy/samplers.py b/comfy/samplers.py index 40d5d332..15e78bbd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con out['c_concat'] = [torch.cat(c_concat)] return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -195,7 +195,7 @@ def 
sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) return uncond + (cond - uncond) * cond_scale @@ -212,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) return out @@ -221,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}): if denoise_mask is not None: latent_mask = 1. 
- denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) if denoise_mask is not None: out *= denoise_mask @@ -333,7 +333,7 @@ class KSampler: "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] - def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): + def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model self.model_denoise = CFGNoisePredictor(self.model) if self.model.parameterization == "v": @@ -353,6 +353,7 @@ class KSampler: self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) self.denoise = denoise + self.model_options = model_options def _calculate_sigmas(self, steps): sigmas = None @@ -421,7 +422,7 @@ class KSampler: else: precision_scope = contextlib.nullcontext - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} cond_concat = None if hasattr(self.model, 'concat_keys'): diff --git a/comfy/sd.py b/comfy/sd.py index 2e1ae840..2a38ceb1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,5 +1,6 @@ import torch import contextlib +import copy import sd1_clip import sd2_clip @@ -274,12 +275,20 @@ class ModelPatcher: self.model = model self.patches = [] self.backup = {} + self.model_options = {"transformer_options":{}} def clone(self): n = ModelPatcher(self.model) n.patches = self.patches[:] + n.model_options = copy.deepcopy(self.model_options) return n + def set_model_tomesd(self, ratio): + 
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} + + def model_dtype(self): + return self.model.diffusion_model.dtype + def add_patches(self, patches, strength=1.0): p = {} model_sd = self.model.state_dict() diff --git a/nodes.py b/nodes.py index 6fb7f017..e69832c5 100644 --- a/nodes.py +++ b/nodes.py @@ -254,6 +254,22 @@ class LoraLoader: model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) +class TomePatchModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, ratio): + m = model.clone() + m.set_model_tomesd(ratio) + return (m, ) + class VAELoader: @classmethod def INPUT_TYPES(s): @@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, model_management.load_controlnet_gpu(control_net_models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) else: #other samplers pass @@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = { "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, + "TomePatchModel": TomePatchModel, } def load_custom_node(module_path):