diff --git a/comfy/ldm/genmo/vae/model.py b/comfy/ldm/genmo/vae/model.py index e44c08a4..b68d48ae 100644 --- a/comfy/ldm/genmo/vae/model.py +++ b/comfy/ldm/genmo/vae/model.py @@ -2,12 +2,16 @@ #adapted to ComfyUI from typing import Callable, List, Optional, Tuple, Union +from functools import partial +import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from comfy.ldm.modules.attention import optimized_attention + import comfy.ops ops = comfy.ops.disable_weight_init @@ -158,8 +162,10 @@ class ResBlock(nn.Module): *, affine: bool = True, attn_block: Optional[nn.Module] = None, - padding_mode: str = "replicate", causal: bool = True, + prune_bottleneck: bool = False, + padding_mode: str, + bias: bool = True, ): super().__init__() self.channels = channels @@ -170,23 +176,23 @@ class ResBlock(nn.Module): nn.SiLU(inplace=True), PConv3d( in_channels=channels, - out_channels=channels, + out_channels=channels // 2 if prune_bottleneck else channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding_mode=padding_mode, - bias=True, - # causal=causal, + bias=bias, + causal=causal, ), norm_fn(channels, affine=affine), nn.SiLU(inplace=True), PConv3d( - in_channels=channels, + in_channels=channels // 2 if prune_bottleneck else channels, out_channels=channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding_mode=padding_mode, - bias=True, - # causal=causal, + bias=bias, + causal=causal, ), ) @@ -206,6 +212,81 @@ class ResBlock(nn.Module): return self.attn_block(x) +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + out_bias: bool = True, + qk_norm: bool = True, + ) -> None: + super().__init__() + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.qk_norm = qk_norm + + self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) + self.out = nn.Linear(dim, dim, bias=out_bias) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """Compute temporal 
self-attention. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + chunk_size: Chunk size for large tensors. + + Returns: + x: Output tensor. Shape: [B, C, T, H, W]. + """ + B, _, T, H, W = x.shape + + if T == 1: + # No attention for single frame. + x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C] + qkv = self.qkv(x) + _, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys. + x = self.out(x) + return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W] + + # 1D temporal attention. + x = rearrange(x, "B C t h w -> (B h w) t C") + qkv = self.qkv(x) + + # Input: qkv with shape [B, t, 3 * num_heads * head_dim] + # Output: x with shape [B, num_heads, t, head_dim] + q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2) + + if self.qk_norm: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + + x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True) + + assert x.size(0) == q.size(0) + + x = self.out(x) + x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + **attn_kwargs, + ) -> None: + super().__init__() + self.norm = norm_fn(dim) + self.attn = Attention(dim, **attn_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.attn(self.norm(x)) + + class CausalUpsampleBlock(nn.Module): def __init__( self, @@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module): return x -def block_fn(channels, *, has_attention: bool = False, **block_kwargs): - assert has_attention is False #NOTE: if this is ever true add back the attention code. 
- - attn_block = None #AttentionBlock(channels) if has_attention else None - - return ResBlock( - channels, affine=True, attn_block=attn_block, **block_kwargs - ) +def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs): + attn_block = AttentionBlock(channels) if has_attention else None + return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs) class DownsampleBlock(nn.Module): @@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module): out_channels=out_channels, kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), stride=(temporal_reduction, spatial_reduction, spatial_reduction), + # First layer in each block always uses replicate padding padding_mode="replicate", - bias=True, + bias=block_kwargs["bias"], ) ) @@ -382,7 +459,7 @@ class Decoder(nn.Module): blocks = [] first_block = [ - nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) + ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) ] # Input layer. # First set of blocks preserve channel count. for _ in range(num_res_blocks[-1]): @@ -452,11 +529,165 @@ class Decoder(nn.Module): return self.output_proj(x).contiguous() +class LatentDistribution: + def __init__(self, mean: torch.Tensor, logvar: torch.Tensor): + """Initialize latent distribution. + + Args: + mean: Mean of the distribution. Shape: [B, C, T, H, W]. + logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W]. 
+ """ + assert mean.shape == logvar.shape + self.mean = mean + self.logvar = logvar + + def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None): + if temperature == 0.0: + return self.mean + + if noise is None: + noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator) + else: + assert noise.device == self.mean.device + noise = noise.to(self.mean.dtype) + + if temperature != 1.0: + raise NotImplementedError(f"Temperature {temperature} is not supported.") + + # Just Gaussian sample with no scaling of variance. + return noise * torch.exp(self.logvar * 0.5) + self.mean + + def mode(self): + return self.mean + +class Encoder(nn.Module): + def __init__( + self, + *, + in_channels: int, + base_channels: int, + channel_multipliers: List[int], + num_res_blocks: List[int], + latent_dim: int, + temporal_reductions: List[int], + spatial_reductions: List[int], + prune_bottlenecks: List[bool], + has_attentions: List[bool], + affine: bool = True, + bias: bool = True, + input_is_conv_1x1: bool = False, + padding_mode: str, + ): + super().__init__() + self.temporal_reductions = temporal_reductions + self.spatial_reductions = spatial_reductions + self.base_channels = base_channels + self.channel_multipliers = channel_multipliers + self.num_res_blocks = num_res_blocks + self.latent_dim = latent_dim + + self.fourier_features = FourierFeatures() + ch = [mult * base_channels for mult in channel_multipliers] + num_down_blocks = len(ch) - 1 + assert len(num_res_blocks) == num_down_blocks + 2 + + layers = ( + [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)] + if not input_is_conv_1x1 + else [Conv1x1(in_channels, ch[0])] + ) + + assert len(prune_bottlenecks) == num_down_blocks + 2 + assert len(has_attentions) == num_down_blocks + 2 + block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias) + + for _ in range(num_res_blocks[0]): + layers.append(block(ch[0], 
has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0])) + prune_bottlenecks = prune_bottlenecks[1:] + has_attentions = has_attentions[1:] + + assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1 + for i in range(num_down_blocks): + layer = DownsampleBlock( + ch[i], + ch[i + 1], + num_res_blocks=num_res_blocks[i + 1], + temporal_reduction=temporal_reductions[i], + spatial_reduction=spatial_reductions[i], + prune_bottleneck=prune_bottlenecks[i], + has_attention=has_attentions[i], + affine=affine, + bias=bias, + padding_mode=padding_mode, + ) + + layers.append(layer) + + # Additional blocks. + for _ in range(num_res_blocks[-1]): + layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1])) + + self.layers = nn.Sequential(*layers) + + # Output layers. + self.output_norm = norm_fn(ch[-1]) + self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False) + + @property + def temporal_downsample(self): + return math.prod(self.temporal_reductions) + + @property + def spatial_downsample(self): + return math.prod(self.spatial_reductions) + + def forward(self, x) -> LatentDistribution: + """Forward pass. + + Args: + x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1] + + Returns: + means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1]. + h = H // 8, w = W // 8, t - 1 = (T - 1) // 6 + logvar: Shape: [B, latent_dim, t, h, w]. 
+ """ + assert x.ndim == 5, f"Expected 5D input, got {x.shape}" + x = self.fourier_features(x) + + x = self.layers(x) + + x = self.output_norm(x) + x = F.silu(x, inplace=True) + x = self.output_proj(x) + + means, logvar = torch.chunk(x, 2, dim=1) + + assert means.ndim == 5 + assert logvar.shape == means.shape + assert means.size(1) == self.latent_dim + + return LatentDistribution(means, logvar) + class VideoVAE(nn.Module): def __init__(self): super().__init__() - self.encoder = None #TODO once the model releases + self.encoder = Encoder( + in_channels=15, + base_channels=64, + channel_multipliers=[1, 2, 4, 6], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + temporal_reductions=[1, 2, 3], + spatial_reductions=[2, 2, 2], + prune_bottlenecks=[False, False, False, False, False], + has_attentions=[False, True, True, True, True], + affine=True, + bias=True, + input_is_conv_1x1=True, + padding_mode="replicate" + ) self.decoder = Decoder( out_channels=3, base_channels=128, @@ -474,7 +705,7 @@ class VideoVAE(nn.Module): ) def encode(self, x): - return self.encoder(x) + return self.encoder(x).mode() def decode(self, x): return self.decoder(x) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 85ea406e..3f543abd 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh return out +if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue. + SDP_BATCH_LIMIT = 2**15 +else: + #TODO: other GPUs ? 
+ SDP_BATCH_LIMIT = 2**31 + + def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape @@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - out = ( - out.transpose(1, 2).reshape(b, -1, heads * dim_head) - ) + if SDP_BATCH_LIMIT >= q.shape[0]: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + else: + out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device) + for i in range(0, q.shape[0], SDP_BATCH_LIMIT): + out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out diff --git a/comfy/sd.py b/comfy/sd.py index e9a3579c..2f4c01a1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -171,6 +171,7 @@ class VAE: self.downscale_ratio = 8 self.upscale_ratio = 8 self.latent_channels = 4 + self.latent_dim = 2 self.output_channels = 3 self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) @@ -240,16 +241,22 @@ class VAE: self.output_channels = 2 self.upscale_ratio = 2048 self.downscale_ratio = 2048 + self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae + elif "blocks.2.blocks.3.stack.5.weight" in sd or 
"decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae if "blocks.2.blocks.3.stack.5.weight" in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) + if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."}) self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE() self.latent_channels = 12 + self.latent_dim = 3 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) + self.working_dtypes = [torch.float16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -361,17 +368,22 @@ class VAE: def encode(self, pixel_samples): pixel_samples = self.vae_encode_crop_pixels(pixel_samples) - pixel_samples = pixel_samples.movedim(-1,1) + pixel_samples = pixel_samples.movedim(-1, 1) + if self.latent_dim == 3: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device) + samples = None for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) - 
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) + out = self.first_stage_model.encode(pixels_in).to(self.output_device).float() + if samples is None: + samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) + samples[x:x + batch_number] = out except model_management.OOM_EXCEPTION as e: logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")