diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f9fd16d8..199a3641 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -216,3 +216,7 @@ class Mochi(LatentFormat): latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + +class LTXV(LatentFormat): + latent_channels = 128 + diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py new file mode 100644 index 00000000..2792384d --- /dev/null +++ b/comfy/ldm/lightricks/model.py @@ -0,0 +1,502 @@ +import torch +from torch import nn +import comfy.ldm.modules.attention +from comfy.ldm.genmo.joint_model.layers import RMSNorm +import comfy.ldm.common_dit +from einops import rearrange +import math +from typing import Dict, Optional, Tuple + +from .symmetric_patchifier import SymmetricPatchifier + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a timestep projection.

    The network is ``linear_1 -> SiLU -> linear_2``, optionally followed by a
    parameter-free activation selected with ``post_act_fn``.

    Args:
        in_channels: Width of the incoming sample.
        time_embed_dim: Output width of ``linear_1``.
        act_fn: Accepted for signature compatibility only; the hidden
            activation is always SiLU.
        out_dim: Output width of ``linear_2``; ``None`` means ``time_embed_dim``.
        post_act_fn: Key of ``_POST_ACTIVATIONS`` naming an activation to apply
            after ``linear_2``, or ``None`` for no activation.
        cond_proj_dim: If given, width of an extra conditioning input that is
            projected (without bias) to ``in_channels`` and added to the sample.
        sample_proj_bias: Whether ``linear_1`` and ``linear_2`` carry a bias.
        dtype, device: Forwarded to every ``operations.Linear``.
        operations: Namespace providing the ``Linear`` layer class.

    Raises:
        ValueError: If ``post_act_fn`` is not a known activation name.
    """

    # Names accepted for ``post_act_fn``. None of these modules has parameters,
    # so enabling one does not add state-dict entries.
    _POST_ACTIVATIONS = {
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
    }

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)

        # Always define ``post_act`` so ``forward`` can test it; previously the
        # attribute was missing whenever ``post_act_fn`` was supplied.
        if post_act_fn is None:
            self.post_act = None
        else:
            try:
                post_act_cls = self._POST_ACTIVATIONS[post_act_fn]
            except KeyError:
                raise ValueError(f"Unsupported post_act_fn: {post_act_fn!r}") from None
            self.post_act = post_act_cls()

    def forward(self, sample, condition=None):
        """Embed ``sample``, optionally adding a projected ``condition`` first.

        Raises:
            ValueError: If ``condition`` is given but the module was built
                without ``cond_proj_dim``.
        """
        if condition is not None:
            if self.cond_proj is None:
                raise ValueError("condition was given but the module was built without cond_proj_dim")
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
class PixArtAlphaTextProjection(nn.Module):
    """Caption-embedding projector: ``linear_1 -> activation -> linear_2``.

    ``act_fn`` selects the activation and must be ``"gelu_tanh"`` (tanh
    approximated GELU) or ``"silu"``; any other value raises ``ValueError``.
    ``out_features`` defaults to ``hidden_size``.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        final_width = hidden_size if out_features is None else out_features

        # Sub-modules are registered in the order linear_1, act_1, linear_2.
        self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, **linear_kwargs)

        if act_fn not in ("gelu_tanh", "silu"):
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.act_1 = nn.SiLU() if act_fn == "silu" else nn.GELU(approximate="tanh")

        self.linear_2 = operations.Linear(in_features=hidden_size, out_features=final_width, **linear_kwargs)

    def forward(self, caption):
        """Project ``caption`` through both linear layers with the activation in between."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
class CrossAttention(nn.Module):
    """Multi-head attention with RMS-normalised queries and keys.

    Keys and values are projected from ``context``; when ``context`` is not
    passed to ``forward`` the input ``x`` is used instead (self-attention).
    If ``pe`` is given it is applied to both queries and keys through
    ``apply_rotary_emb``. A non-``None`` ``mask`` selects the masked attention
    kernel.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        if context_dim is None:
            context_dim = query_dim
        inner_dim = dim_head * heads
        factory = {"dtype": dtype, "device": device}

        self.attn_precision = attn_precision
        self.heads = heads
        self.dim_head = dim_head

        # The norms span the full projected width (heads * dim_head).
        self.q_norm = RMSNorm(inner_dim, **factory)
        self.k_norm = RMSNorm(inner_dim, **factory)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, **factory)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, **factory)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, **factory)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, **factory), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None):
        """Attend from ``x`` to ``context`` and project back to ``query_dim``."""
        if context is None:
            context = x

        q = self.q_norm(self.to_q(x))
        k = self.k_norm(self.to_k(context))
        v = self.to_v(context)

        if pe is not None:
            q, k = apply_rotary_emb(q, pe), apply_rotary_emb(k, pe)

        attention = comfy.ldm.modules.attention
        if mask is None:
            out = attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=(20, 2048, 2048)):
    """Build the cosine/sine tables consumed by ``apply_rotary_emb``.

    Args:
        indices_grid: Position grid indexed as ``indices_grid[:, i]`` for the
            three axes ``i = 0, 1, 2`` (the layout returned by
            ``Patchifier.get_grid``).
        dim: Width of the last axis of the returned tables.
        out_dtype: dtype of the returned tensors; the computation itself runs
            in float32.
        theta: Base of the geometric frequency progression.
        max_pos: Per-axis divisor that turns grid indices into fractional
            positions.

    Returns:
        ``(cos_freq, sin_freq)``, each with a last axis of size ``dim``. When
        ``dim`` is not a multiple of 6, the first ``dim % 6`` entries are
        padding (cos = 1, sin = 0).
    """
    dtype = torch.float32

    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    device = fractional_positions.device

    # dim // 6 frequencies per axis, spaced geometrically from 1 to theta,
    # then scaled by pi / 2.
    exponents = torch.linspace(
        math.log(1, theta),
        math.log(theta, theta),
        dim // 6,
        device=device,
        dtype=dtype,
    )
    indices = theta ** exponents
    indices = indices * math.pi / 2

    # Fractional positions are mapped through ``* 2 - 1`` before multiplying
    # with every frequency; axes and frequencies are then interleaved.
    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )

    # Each frequency serves a pair of adjacent channels.
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)

    pad = dim % 6
    if pad != 0:
        # Build the padding with an explicit width instead of slicing it out of
        # cos_freq, so the result is ``dim`` wide even when ``dim < 6``.
        pad_shape = (*cos_freq.shape[:-1], pad)
        cos_freq = torch.cat([cos_freq.new_ones(pad_shape), cos_freq], dim=-1)
        sin_freq = torch.cat([sin_freq.new_zeros(pad_shape), sin_freq], dim=-1)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
device=device) + + self.patchifier = SymmetricPatchifier(1) + + def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs): + indices_grid = self.patchifier.get_grid( + orig_num_frames=x.shape[2], + orig_height=x.shape[3], + orig_width=x.shape[4], + batch_size=x.shape[0], + scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate + device=x.device, + ) + + if guiding_latent is not None: + ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype) + input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1)) + ts *= input_ts + ts[:, :, 0] = 0.0 + timestep = self.patchifier.patchify(ts) + input_x = x.clone() + x[:, :, 0] = guiding_latent[:, :, 0] + + orig_shape = list(x.shape) + + x = self.patchifier.patchify(x) + + x = self.patchify_proj(x) + timestep = timestep * 1000.0 + + attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) + attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this + # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype) + + pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype) + + batch_size = x.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=x.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. 
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Return ``x`` with trailing singleton axes added until it has ``target_dims`` axes.

    ``x`` itself is returned when it already has ``target_dims`` axes.

    Raises:
        ValueError: If ``x`` has more than ``target_dims`` axes.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    if not missing:
        return x
    # Indexing with ``None`` after an ellipsis appends size-1 axes.
    trailing = (None,) * missing
    return x[(Ellipsis, *trailing)]
Tuple[Tensor, Tensor]: + pass + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + output_num_frames: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_grid( + self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device + ): + f = orig_num_frames // self._patch_size[0] + h = orig_height // self._patch_size[1] + w = orig_width // self._patch_size[2] + grid_h = torch.arange(h, dtype=torch.float32, device=device) + grid_w = torch.arange(w, dtype=torch.float32, device=device) + grid_f = torch.arange(f, dtype=torch.float32, device=device) + grid = torch.meshgrid(grid_f, grid_h, grid_w) + grid = torch.stack(grid, dim=0) + grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + if scale_grid is not None: + for i in range(3): + if isinstance(scale_grid[i], Tensor): + scale = append_dims(scale_grid[i], grid.ndim - 1) + else: + scale = scale_grid[i] + grid[:, i, ...] = grid[:, i, ...] 
class CausalConv3d(nn.Module):
    """3-D convolution whose time axis is padded by replicating edge frames.

    Height and width receive symmetric zero padding of ``kernel_size // 2``
    inside the wrapped ``nn.Conv3d``; the time axis gets no padding there and
    is padded manually in ``forward``. ``dilation`` is applied to the time
    axis only. Extra keyword arguments are accepted and ignored.

    NOTE: the number of replicated frames depends only on ``kernel_size``,
    not on ``dilation``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_kernel_size = kernel_size

        spatial_pad = kernel_size // 2
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            (kernel_size,) * 3,
            stride=stride,
            dilation=(dilation, 1, 1),
            padding=(0, spatial_pad, spatial_pad),
            padding_mode="zeros",
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Pad ``x`` along axis 2 and convolve.

        With ``causal=True`` the first frame is repeated ``kernel_size - 1``
        times in front. Otherwise the first and last frames are each repeated
        ``(kernel_size - 1) // 2`` times on their respective sides.
        """
        head = x[:, :, :1, :, :]
        if causal:
            pieces = (head.repeat((1, 1, self.time_kernel_size - 1, 1, 1)), x)
        else:
            half = (self.time_kernel_size - 1) // 2
            tail = x[:, :, -1:, :, :]
            pieces = (
                head.repeat((1, 1, half, 1, 1)),
                x,
                tail.repeat((1, 1, half, 1, 1)),
            )
        return self.conv(torch.cat(pieces, dim=2))

    @property
    def weight(self):
        """Weight tensor of the wrapped convolution."""
        return self.conv.weight
+ norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks=[("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + ) + elif block_name == "compress_space": 
+ block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, output_channel, conv_out_channels, 3, padding=1, causal=True + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample) + sample = 
self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. 
+ """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks=[("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=dims, in_channels=input_channel, stride=(2, 1, 1) + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, in_channels=input_channel, stride=(1, 2, 2) + ) + elif block_name == "compress_all": + block = DepthToSpaceUpsample( + dims=dims, + 
class UNetMidBlock3D(nn.Module):
    """
    A stack of `num_layers` [`ResnetBlock3D`] blocks that keep the channel count unchanged.

    Args:
        dims (`int` or `Tuple[int, int]`): Forwarded to every residual block.
        in_channels (`int`): The number of input (and output) channels of each block.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The `groups` value passed to the resnet blocks. `None` selects
            `min(in_channels // 4, 32)`.
        norm_layer (`str`, *optional*, defaults to `group_norm`): Forwarded to every residual block.

    Returns:
        `torch.FloatTensor`: The output of the last residual block.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        self.res_blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

    def forward(
        self, hidden_states: torch.FloatTensor, causal: bool = True
    ) -> torch.FloatTensor:
        """Run `hidden_states` through every residual block in order."""
        out = hidden_states
        for block in self.res_blocks:
            out = block(out, causal=causal)
        return out
class LayerNorm(nn.Module):
    """
    `nn.LayerNorm` applied over the channel axis of a channels-first 5-D tensor.

    The input is laid out as `b c d h w`; it is rearranged to channels-last so that
    `nn.LayerNorm(dim)` normalizes the `c` axis, then rearranged back.
    """

    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        channels_last = rearrange(x, "b c d h w -> b d h w c")
        normalized = self.norm(channels_last)
        return rearrange(normalized, "b d h w c -> b c d h w")
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + hidden_states = self.norm2(hidden_states) 
def unpatchify(x, patch_size_hw, patch_size_t=1):
    """
    Inverse of `patchify`: move patch elements from the channel axis back to the spatial/temporal axes.

    Args:
        x: 4-D (`b c h w`) or 5-D (`b c f h w`) tensor whose channel axis holds the folded patches.
        patch_size_hw: Patch edge length used for both height and width.
        patch_size_t: Patch length along the frame axis (only used for 5-D input).

    Returns:
        `x` itself when both patch sizes are 1, otherwise the rearranged tensor.

    Raises:
        ValueError: If a rearrangement is requested for a tensor that is neither 4-D nor 5-D.
            `patchify` rejects such input the same way; previously this function returned it untouched.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        return rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    if x.dim() == 5:
        return rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    raise ValueError(f"Invalid input shape: {x.shape}")
+ "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + } + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + + self.encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + ) + + self.decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + ) + + self.per_channel_statistics = processor() + + def encode(self, x): + means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + return self.per_channel_statistics.normalize(means) + + def decode(self, x): + return self.decoder(self.per_channel_statistics.un_normalize(x)) + diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py new file mode 100644 index 00000000..389f8165 --- /dev/null +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -0,0 +1,82 @@ +from typing import Tuple, Union + +import torch + +from .dual_conv3d import DualConv3d +from .causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: 
def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """
    Build a per-position linear layer, implemented as a convolution with `kernel_size=1`.

    Args:
        dims: `2` selects `torch.nn.Conv2d`; `3` or the tuple `(2, 1)` selects `torch.nn.Conv3d`.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the convolution has a bias term.

    Returns:
        The constructed convolution module.

    Raises:
        ValueError: If `dims` is none of the supported values.
    """
    if dims == 2:
        conv_cls = torch.nn.Conv2d
    elif dims == 3 or dims == (2, 1):
        conv_cls = torch.nn.Conv3d
    else:
        raise ValueError(f"unsupported dimensions: {dims}")

    return conv_cls(
        in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
    )
Tuple[int, int, int]] = 1, + groups=1, + bias=True, + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." + ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + 
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = 
self.dilation2[0] + x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." 
class PixelNorm(nn.Module):
    """
    Parameter-free normalization: divide the input by its root-mean-square along `dim`.

    `eps` is added to the mean of squares before the square root, so an all-zero slice
    yields zeros instead of a division by zero.
    """

    def __init__(self, dim=1, eps=1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        mean_square = torch.mean(x**2, dim=self.dim, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return x / rms
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 008e4b19..c5c5dbb2 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -183,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix): dit_config["rope_theta"] = 10000.0 return dit_config + if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv + dit_config = {} + dit_config["image_model"] = "ltxv" + return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index 3dea043f..b07b5fe3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -8,6 +8,7 @@ from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model +import comfy.ldm.lightricks.vae.causal_video_autoencoder import yaml import comfy.utils @@ -27,6 +28,7 @@ import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.long_clipl import comfy.text_encoders.genmo +import comfy.text_encoders.lt import comfy.model_patcher import comfy.lora @@ -261,6 +263,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) self.working_dtypes = [torch.float16, torch.float32] + elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv + self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE() + self.latent_channels = 128 + self.latent_dim = 3 + self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + 
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.upscale_ratio = 8 + self.working_dtypes = [torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -452,6 +462,7 @@ class CLIPType(Enum): HUNYUAN_DIT = 5 FLUX = 6 MOCHI = 7 + LTXV = 8 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] @@ -530,6 +541,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 75ddaee5..caa1a184 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.genmo +import comfy.text_encoders.lt from . import supported_models_base from . 
import latent_formats @@ -702,7 +703,34 @@ class GenmoMochi(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect)) +class LTXV(supported_models_base.BASE): + unet_config = { + "image_model": "ltxv", + } -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi] + sampling_settings = { + "shift": 2.37, + } + + unet_extra_config = {} + latent_format = latent_formats.LTXV + + memory_usage_factor = 2.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXV(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV] models += [SVD_img2vid] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py new file mode 100644 index 00000000..5c2ce583 --- /dev/null +++ b/comfy/text_encoders/lt.py @@ -0,0 +1,18 @@ +from comfy 
class LTXVImgToVideo:
    """
    Node that encodes an input image with the VAE and uses the result both as a
    `guiding_latent` conditioning value and as the leading frames of a new video latent.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "image": ("IMAGE",),
                             "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        """
        Build the guided conditioning pair and the initial latent.

        Returns:
            `(positive, negative, {"samples": latent})` where both conditionings carry the encoded
            image under `"guiding_latent"` and `latent` is a zero tensor of shape
            `[batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32]` whose first
            `t.shape[2]` entries along axis 2 are overwritten with the encoded image.
        """
        # This module's header imports only comfy.model_management and comfy.model_sampling;
        # import comfy.utils explicitly instead of relying on another module having loaded it.
        # The local binding of `comfy` is the same package object, so the
        # comfy.model_management access below is unaffected.
        import comfy.utils

        # common_upscale is called channels-first; the result is moved back to channels-last.
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        # Keep only the first three channels of the last axis for encoding.
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})

        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        latent[:, :, :t.shape[2]] = t
        return (positive, negative, {"samples": latent}, )
class LTXVScheduler:
    """
    Custom-sampling scheduler node that produces the LTXV sigma schedule.

    The schedule starts from `steps + 1` evenly spaced values in `[1, 0]`, applies a shift whose
    strength is interpolated linearly from the latent's token count, and optionally stretches the
    non-zero sigmas so the last of them equals `terminal`.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
                     "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
                     "stretch": ("BOOLEAN", {
                        "default": True,
                        "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
                    }),
                     "terminal": (
                        "FLOAT",
                        {
                            "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
                            "tooltip": "The terminal value of the sigmas after stretching."
                        },
                    ),
                    },
                "optional": {"latent": ("LATENT",), }
               }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        """
        Compute the sigma schedule.

        Args:
            steps: Number of sampling steps; the result has `steps + 1` entries.
            max_shift: Shift used at 4096 tokens.
            base_shift: Shift used at 1024 tokens.
            stretch: Whether to rescale the non-zero sigmas so the last one equals `terminal`.
            terminal: Target value of the last non-zero sigma when `stretch` is set.
            latent: Optional latent dict; the product of `latent["samples"].shape[2:]` is the
                token count. Without it 4096 tokens are assumed.

        Returns:
            A one-element tuple holding the 1-D sigma tensor.
        """
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Straight line through (1024, base_shift) and (4096, max_shift), evaluated at `tokens`.
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = (tokens) * mm + b

        power = 1
        # Zero entries are passed through unchanged by the `where`.
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that its final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return (sigmas,)
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION @@ -2136,6 +2138,7 @@ def init_builtin_extra_nodes(): "nodes_torch_compile.py", "nodes_mochi.py", "nodes_slg.py", + "nodes_lt.py", ] import_failed = []