2025-06-13 04:05:23 -07:00
|
|
|
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from einops import rearrange
|
|
|
|
from einops.layers.torch import Rearrange
|
|
|
|
import logging
|
|
|
|
from typing import Callable, Optional, Tuple
|
|
|
|
import math
|
|
|
|
|
|
|
|
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """Rotate ``t`` with the coefficients in ``freqs`` (rotary position embedding).

    The last dimension of ``t`` is split into a first and a second half; element
    ``i`` of each half forms a 2-vector. Every such 2-vector is multiplied by the
    2x2 matrix stored in the last two dimensions of ``freqs``.

    Args:
        t: Tensor whose last dimension has even length ``D``.
        freqs: Coefficient tensor whose trailing dimensions are ``(D // 2, 2, 2)``;
            leading dimensions must broadcast against those of ``t``.

    Returns:
        Tensor with the same shape and dtype as ``t``.
    """
    original_shape = t.shape

    # (..., D) -> (..., 2, D/2) -> (..., D/2, 2) -> (..., D/2, 1, 2); computed in float32.
    halves = t.reshape(*original_shape[:-1], 2, -1)
    pairs = halves.movedim(-2, -1).unsqueeze(-2).float()
    first_half = pairs[..., 0]
    second_half = pairs[..., 1]

    # Matrix-vector product written out column by column.
    rotated = freqs[..., 0] * first_half + freqs[..., 1] * second_half

    # Undo the pairing layout and return to the caller's dtype.
    return rotated.movedim(-1, -2).reshape(*original_shape).type_as(t)
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------- Feed Forward Network -----------------------
|
|
|
|
class GPT2FeedForward(nn.Module):
    """Two-layer bias-free MLP with a GELU in between.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
        device, dtype: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace providing a ``Linear`` constructor. It must be
            supplied; the ``None`` default cannot be used.
    """

    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
        super().__init__()
        linear_kwargs = dict(bias=False, device=device, dtype=dtype)

        self.activation = nn.GELU()
        self.layer1 = operations.Linear(d_model, d_ff, **linear_kwargs)
        self.layer2 = operations.Linear(d_ff, d_model, **linear_kwargs)

        # Bookkeeping attributes; nothing in this class reads them.
        self._layer_id = None
        self._dim = d_model
        self._hidden_dim = d_ff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``layer2(GELU(layer1(x)))``."""
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
    """Run multi-head attention through ``optimized_attention``.

    The inputs are moved to a heads-first layout, any dimensions between the
    batch and head axes are collapsed into a single sequence axis, and the
    result is handed to ``optimized_attention``.

    Dimension naming used by the arguments:

    - B: batch size
    - S: sequence length (may be several axes; they are flattened here)
    - H: number of attention heads
    - D: head dimension

    Args:
        q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
        k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
        v_B_S_H_D: Value tensor; it is reshaped using the key tensor's sizes,
            so it is expected to have the same shape as the key.

    Returns:
        Whatever ``optimized_attention`` returns for heads-first inputs with
        ``skip_reshape=True`` -- presumably (batch, seq_len, n_heads * head_dim),
        which is what the caller's output projection consumes.
    """
    q_shape = q_B_S_H_D.shape
    kv_shape = k_B_S_H_D.shape  # shared by key and value
    n_heads = q_shape[-2]

    # Bring the head axis next to the batch axis.
    q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k")
    k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v")
    v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v")

    # Collapse everything between heads and head_dim into one sequence axis.
    q_B_H_S_D = q_B_H_S_D.view(q_shape[0], n_heads, -1, q_shape[-1])
    k_B_H_S_D = k_B_H_S_D.view(kv_shape[0], kv_shape[-2], -1, kv_shape[-1])
    v_B_H_S_D = v_B_H_S_D.view(kv_shape[0], kv_shape[-2], -1, kv_shape[-1])

    return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, n_heads, skip_reshape=True)
|
2025-06-13 04:05:23 -07:00
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """
    Multi-head attention layer usable for self-attention or cross-attention.

    The mode is fixed at construction time: when ``context_dim`` is ``None`` the
    layer is in self-attention mode (keys/values are projected with ``query_dim``
    inputs and rotary embeddings are applied when given); otherwise it is in
    cross-attention mode and rotary embeddings are ignored. Queries and keys are
    RMS-normalised per head; values pass through unchanged.

    Args:
        query_dim (int): Feature size of the query input and of the output.
        context_dim (int, optional): Feature size of the key/value input.
            If None, ``query_dim`` is used and self-attention mode is enabled. Default: None
        n_heads (int, optional): Number of attention heads. Default: 8
        head_dim (int, optional): Feature size of each head. Default: 64
        dropout (float, optional): Dropout probability on the projected output;
            values of 1e-4 or less leave dropout disabled. Default: 0.0
        device, dtype: Forwarded to the ``operations`` layer constructors.
        operations: Namespace providing ``Linear`` and ``RMSNorm`` constructors.
            It must be supplied; the ``None`` default cannot be used.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        logging.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{n_heads} heads with a dimension of {head_dim}."
        )
        self.is_selfattn = context_dim is None  # self attention

        if context_dim is None:
            context_dim = query_dim
        inner_dim = head_dim * n_heads
        factory = dict(device=device, dtype=dtype)

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, **factory)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, **factory)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, **factory)
        self.v_norm = nn.Identity()

        self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, **factory)
        if dropout > 1e-4:
            self.output_dropout = nn.Dropout(dropout)
        else:
            self.output_dropout = nn.Identity()

        self.attn_op = torch_attention_op

        self._query_dim = query_dim
        self._context_dim = context_dim
        self._inner_dim = inner_dim

    def compute_qkv(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project, split into heads, normalise and (optionally) rotate q/k/v.

        Args:
            x: Query input; also the key/value input when ``context`` is None.
            context: Optional key/value input.
            rope_emb: Optional rotary coefficients; used only in self-attention mode.

        Returns:
            ``(q, k, v)`` with the last dimension split into ``(n_heads, head_dim)``.
        """
        kv_source = x if context is None else context
        split_heads = "b ... (h d) -> b ... h d"

        q = rearrange(self.q_proj(x), split_heads, h=self.n_heads, d=self.head_dim)
        k = rearrange(self.k_proj(kv_source), split_heads, h=self.n_heads, d=self.head_dim)
        v = rearrange(self.v_proj(kv_source), split_heads, h=self.n_heads, d=self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)
        v = self.v_norm(v)

        # Rotary embeddings only make sense when q and k index the same sequence.
        if self.is_selfattn and rope_emb is not None:
            q = apply_rotary_pos_emb(q, rope_emb)
            k = apply_rotary_pos_emb(k, rope_emb)

        return q, k, v

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Run the attention op and apply the output projection and dropout."""
        attended = self.attn_op(q, k, v)  # [B, S, H, D]
        projected = self.output_proj(attended)
        return self.output_dropout(projected)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
            rope_emb (Optional[Tensor]): Rotary coefficients passed on to ``compute_qkv``
        """
        qkv = self.compute_qkv(x, context, rope_emb=rope_emb)
        return self.compute_attention(*qkv)
|
|
|
|
|
|
|
|
|
|
|
|
class Timesteps(nn.Module):
    """Sinusoidal embedding of a 2-D grid of timesteps.

    Each timestep is multiplied by ``num_channels // 2`` geometrically spaced
    frequencies ``10000 ** (-i / half_dim)``; the cosines of the products are
    concatenated with their sines along the last axis.

    Args:
        num_channels: Requested embedding width. The output width is
            ``2 * (num_channels // 2)``.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
        """Embed ``timesteps_B_T``.

        Args:
            timesteps_B_T: 2-D tensor of timesteps, shape ``(B, T)``.

        Returns:
            float32 tensor of shape ``(B, T, 2 * (num_channels // 2))`` laid out
            as ``[cos..., sin...]`` along the last axis.

        Raises:
            AssertionError: If the input is not 2-D.
        """
        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
        batch, frames = timesteps_B_T.shape

        timesteps = timesteps_B_T.flatten().float()
        half_dim = self.num_channels // 2

        # Frequencies 10000 ** (-i / half_dim) for i in [0, half_dim).
        positions = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        freqs = torch.exp(-math.log(10000) * positions / half_dim)

        angles = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

        # Restore the (B, T) leading axes that were flattened above.
        return emb.reshape(batch, frames, emb.shape[-1])
|
|
|
|
|
|
|
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """Linear -> SiLU -> Linear projection of a timestep embedding.

    Args:
        in_features: Feature size of the input.
        out_features: Feature size after the first linear layer.
        use_adaln_lora: When True the first layer has no bias, the second layer
            outputs ``3 * out_features`` features, and ``forward`` returns the
            untouched input next to that projection.
        device, dtype: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace providing a ``Linear`` constructor. It must be
            supplied; the ``None`` default cannot be used.
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora

        # The AdaLN-LoRA variant emits three chunks per feature.
        second_width = 3 * out_features if use_adaln_lora else out_features
        self.linear_2 = operations.Linear(out_features, second_width, bias=False, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project ``sample``.

        Returns:
            ``(emb_B_T_D, adaln_lora_B_T_3D)``. With AdaLN-LoRA enabled this is
            ``(sample, projection)``; otherwise it is ``(projection, None)``.
        """
        projected = self.linear_2(self.activation(self.linear_1(sample)))
        if self.use_adaln_lora:
            return sample, projected
        return projected, None
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Cuts a (B, C, T, H, W) tensor into non-overlapping spatio-temporal patches
    and projects every flattened patch to ``out_channels`` features.

    Patching is an einops ``Rearrange``; the projection is a bias-free linear
    layer taken from ``operations``.

    Parameters:
    - spatial_patch_size (int): The size of each spatial patch.
    - temporal_patch_size (int): The size of each temporal patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    - device, dtype: Forwarded to the ``operations.Linear`` constructor.
    - operations: Namespace providing a ``Linear`` constructor. It must be
      supplied; the ``None`` default cannot be used.
    """

    def __init__(
        self,
        spatial_patch_size: int,
        temporal_patch_size: int,
        in_channels: int = 3,
        out_channels: int = 768,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        # Number of values in one flattened patch.
        patch_dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size

        to_patches = Rearrange(
            "b c (t r) (h m) (w n) -> b t h w (c r m n)",
            r=temporal_patch_size,
            m=spatial_patch_size,
            n=spatial_patch_size,
        )
        to_embedding = operations.Linear(patch_dim, out_channels, bias=False, device=device, dtype=dtype)

        self.proj = nn.Sequential(to_patches, to_embedding)
        self.dim = patch_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed the patches of ``x``.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W); H and W
          must be divisible by the spatial patch size and T by the temporal one.

        Returns:
        - torch.Tensor: The embedded patches, laid out as b t h w c.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        spatial_ok = H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert spatial_ok, f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
        assert T % self.temporal_patch_size == 0
        return self.proj(x)
|
|
|
|
|
|
|
|
|
|
|
|
class FinalLayer(nn.Module):
    """
    The final layer of video DiT: an AdaLN-modulated layer norm followed by a
    bias-free linear projection to
    ``spatial_patch_size**2 * temporal_patch_size * out_channels`` features.

    Args:
        hidden_size: Feature size of the incoming tokens.
        spatial_patch_size: Spatial patch size used to size the projection.
        temporal_patch_size: Temporal patch size used to size the projection.
        out_channels: Output channels per patch element.
        use_adaln_lora: Use the low-rank modulation path and add the supplied
            ``adaln_lora_B_T_3D`` term in ``forward``.
        adaln_lora_dim: Bottleneck width of the low-rank modulation path.
        device, dtype: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace providing a ``Linear`` constructor. It must be
            supplied; the ``None`` default cannot be used.
    """

    def __init__(
        self,
        hidden_size: int,
        spatial_patch_size: int,
        temporal_patch_size: int,
        out_channels: int,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        factory = dict(device=device, dtype=dtype)

        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        projection_width = spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels
        self.linear = operations.Linear(hidden_size, projection_width, bias=False, **factory)

        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2  # shift and scale
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim

        modulation_width = self.n_adaln_chunks * hidden_size
        if use_adaln_lora:
            layers = [
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, **factory),
                operations.Linear(adaln_lora_dim, modulation_width, bias=False, **factory),
            ]
        else:
            layers = [
                nn.SiLU(),
                operations.Linear(hidden_size, modulation_width, bias=False, **factory),
            ]
        self.adaln_modulation = nn.Sequential(*layers)

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
    ):
        """Modulate, normalise and project ``x_B_T_H_W_D``.

        Args:
            x_B_T_H_W_D: Tokens to project.
            emb_B_T_D: Conditioning embedding fed to the modulation MLP.
            adaln_lora_B_T_3D: Extra modulation term; required when
                ``use_adaln_lora`` is set. Only its first ``2 * hidden_size``
                features along the last axis are used.

        Returns:
            The projected tensor.
        """
        if self.use_adaln_lora:
            assert adaln_lora_B_T_3D is not None
            lora_term = adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
            modulation = self.adaln_modulation(emb_B_T_D) + lora_term
        else:
            modulation = self.adaln_modulation(emb_B_T_D)
        shift_B_T_D, scale_B_T_D = modulation.chunk(2, dim=-1)

        # Insert singleton H and W axes so the terms broadcast over space.
        shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d")
        scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d")

        modulated = self.layer_norm(x_B_T_H_W_D) * (1 + scale_B_T_1_1_D) + shift_B_T_1_1_D
        return self.linear(modulated)
|
|
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """
    Transformer block made of self-attention, cross-attention and an MLP.

    Every sub-layer is preceded by a non-affine layer norm whose output is
    scaled and shifted by an AdaLN modulation of the conditioning embedding,
    and its result is added back to the input through a learned gate.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
        device, dtype: Forwarded to the layer constructors.
        operations: Namespace providing the layer constructors used here and by
            the sub-modules (``Linear``, ``LayerNorm``, ``RMSNorm``). It must be
            supplied; the ``None`` default cannot be used.

    The block applies the following sequence:
        1. Self-attention with AdaLN modulation
        2. Cross-attention with AdaLN modulation
        3. MLP with AdaLN modulation
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.x_dim = x_dim
        factory = dict(device=device, dtype=dtype)
        head_dim = x_dim // num_heads

        def _layer_norm():
            return operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, **factory)

        def _modulation():
            # Produces shift, scale and gate, hence the factor of three.
            if use_adaln_lora:
                return nn.Sequential(
                    nn.SiLU(),
                    operations.Linear(x_dim, adaln_lora_dim, bias=False, **factory),
                    operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, **factory),
                )
            return nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, **factory))

        self.layer_norm_self_attn = _layer_norm()
        self.self_attn = Attention(x_dim, None, num_heads, head_dim, operations=operations, **factory)

        self.layer_norm_cross_attn = _layer_norm()
        self.cross_attn = Attention(x_dim, context_dim, num_heads, head_dim, operations=operations, **factory)

        self.layer_norm_mlp = _layer_norm()
        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), operations=operations, **factory)

        self.use_adaln_lora = use_adaln_lora
        self.adaln_modulation_self_attn = _modulation()
        self.adaln_modulation_cross_attn = _modulation()
        self.adaln_modulation_mlp = _modulation()

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the three gated sub-layers to ``x_B_T_H_W_D``.

        Args:
            x_B_T_H_W_D: Input tokens on a (T, H, W) grid.
            emb_B_T_D: Conditioning embedding fed to the modulation MLPs.
            crossattn_emb: Context sequence for the cross-attention layer.
            rope_emb_L_1_1_D: Rotary coefficients handed to both attention
                layers (``Attention`` uses them only in self-attention mode).
            adaln_lora_B_T_3D: Term added to every modulation output when
                ``use_adaln_lora`` is set.
            extra_per_block_pos_emb: Optional tensor added to the input first.

        Returns:
            Tensor with the same layout as ``x_B_T_H_W_D``.
        """
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

        def _shift_scale_gate(modulation_layer):
            # (B, T, 3D) -> three (B, T, 1, 1, D) tensors that broadcast over H and W.
            modulation = modulation_layer(emb_B_T_D)
            if self.use_adaln_lora:
                modulation = modulation + adaln_lora_B_T_3D
            return tuple(rearrange(chunk, "b t d -> b t 1 1 d") for chunk in modulation.chunk(3, dim=-1))

        shift_self_attn, scale_self_attn, gate_self_attn = _shift_scale_gate(self.adaln_modulation_self_attn)
        shift_cross_attn, scale_cross_attn, gate_cross_attn = _shift_scale_gate(self.adaln_modulation_cross_attn)
        shift_mlp, scale_mlp, gate_mlp = _shift_scale_gate(self.adaln_modulation_mlp)

        _, T, H, W, _ = x_B_T_H_W_D.shape

        def _modulated_norm(tokens, norm_layer, scale, shift):
            return norm_layer(tokens) * (1 + scale) + shift

        def _flatten_grid(tokens):
            return rearrange(tokens, "b t h w d -> b (t h w) d")

        def _restore_grid(tokens):
            return rearrange(tokens, "b (t h w) d -> b t h w d", t=T, h=H, w=W)

        # 1. Self-attention
        attn_input = _modulated_norm(x_B_T_H_W_D, self.layer_norm_self_attn, scale_self_attn, shift_self_attn)
        attn_output = self.self_attn(_flatten_grid(attn_input), None, rope_emb=rope_emb_L_1_1_D)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn * _restore_grid(attn_output)

        # 2. Cross-attention
        cross_input = _modulated_norm(x_B_T_H_W_D, self.layer_norm_cross_attn, scale_cross_attn, shift_cross_attn)
        cross_output = self.cross_attn(_flatten_grid(cross_input), crossattn_emb, rope_emb=rope_emb_L_1_1_D)
        x_B_T_H_W_D = _restore_grid(cross_output) * gate_cross_attn + x_B_T_H_W_D

        # 3. MLP (operates on the grid layout directly)
        mlp_input = _modulated_norm(x_B_T_H_W_D, self.layer_norm_mlp, scale_mlp, shift_mlp)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp * self.mlp(mlp_input)
        return x_B_T_H_W_D
|
|
|
|
|
|
|
|
|
|
|
|
class MiniTrainDIT(nn.Module):
|
|
|
|
"""
|
|
|
|
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
|
|
|
|
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
max_img_h (int): Maximum height of the input images.
|
|
|
|
max_img_w (int): Maximum width of the input images.
|
|
|
|
max_frames (int): Maximum number of frames in the video sequence.
|
|
|
|
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
|
|
|
out_channels (int): Number of output channels.
|
|
|
|
patch_spatial (int): Spatial patch size (applied to both height and width) for input processing.
|
|
|
|
patch_temporal (int): Temporal resolution of patches for input processing.
|
|
|
|
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
|
|
|
model_channels (int): Base number of channels used throughout the model.
|
|
|
|
num_blocks (int): Number of transformer blocks.
|
|
|
|
num_heads (int): Number of heads in the multi-head attention layers.
|
|
|
|
mlp_ratio (float): Expansion ratio for MLP blocks.
|
|
|
|
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
|
|
|
pos_emb_cls (str): Type of positional embeddings.
|
|
|
|
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
|
|
|
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
|
|
|
min_fps (int): Minimum frames per second.
|
|
|
|
max_fps (int): Maximum frames per second.
|
|
|
|
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
|
|
|
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
|
|
|
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
|
|
|
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
|
|
|
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
|
|
|
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
|
|
|
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
|
|
|
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
|
|
|
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
    self,
    max_img_h: int,
    max_img_w: int,
    max_frames: int,
    in_channels: int,
    out_channels: int,
    patch_spatial: int,  # tuple,
    patch_temporal: int,
    concat_padding_mask: bool = True,
    # attention settings
    model_channels: int = 768,
    num_blocks: int = 10,
    num_heads: int = 16,
    mlp_ratio: float = 4.0,
    # cross attention settings
    crossattn_emb_channels: int = 1024,
    # positional embedding settings
    pos_emb_cls: str = "sincos",
    pos_emb_learnable: bool = False,
    pos_emb_interpolation: str = "crop",
    min_fps: int = 1,
    max_fps: int = 30,
    use_adaln_lora: bool = False,
    adaln_lora_dim: int = 256,
    rope_h_extrapolation_ratio: float = 1.0,
    rope_w_extrapolation_ratio: float = 1.0,
    rope_t_extrapolation_ratio: float = 1.0,
    extra_per_block_abs_pos_emb: bool = False,
    extra_h_extrapolation_ratio: float = 1.0,
    extra_w_extrapolation_ratio: float = 1.0,
    extra_t_extrapolation_ratio: float = 1.0,
    rope_enable_fps_modulation: bool = True,
    image_model=None,
    device=None,
    dtype=None,
    operations=None,
) -> None:
    """Store the configuration and build every sub-module of the model.

    Parameter meanings are listed in the class docstring. ``image_model`` is
    accepted but not used in this method. ``operations`` supplies the layer
    constructors and must be provided. ``build_pos_embed`` is called from here
    and raises ``ValueError`` unless ``pos_emb_cls`` is ``"rope3d"``.
    """
    super().__init__()
    factory = dict(device=device, dtype=dtype, operations=operations)

    # Input / output geometry
    self.dtype = dtype
    self.max_img_h = max_img_h
    self.max_img_w = max_img_w
    self.max_frames = max_frames
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.patch_spatial = patch_spatial
    self.patch_temporal = patch_temporal

    # Transformer sizing
    self.num_heads = num_heads
    self.num_blocks = num_blocks
    self.model_channels = model_channels
    self.concat_padding_mask = concat_padding_mask

    # positional embedding settings
    self.pos_emb_cls = pos_emb_cls
    self.pos_emb_learnable = pos_emb_learnable
    self.pos_emb_interpolation = pos_emb_interpolation
    self.min_fps = min_fps
    self.max_fps = max_fps
    self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
    self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
    self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
    self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
    self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
    self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
    self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
    self.rope_enable_fps_modulation = rope_enable_fps_modulation

    # Needs the attributes assigned above.
    self.build_pos_embed(device=device, dtype=dtype)

    self.use_adaln_lora = use_adaln_lora
    self.adaln_lora_dim = adaln_lora_dim
    self.t_embedder = nn.Sequential(
        Timesteps(model_channels),
        TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, **factory),
    )

    # The padding mask, when concatenated, occupies one extra input channel.
    embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
    self.x_embedder = PatchEmbed(
        spatial_patch_size=patch_spatial,
        temporal_patch_size=patch_temporal,
        in_channels=embed_in_channels,
        out_channels=model_channels,
        **factory,
    )

    blocks = []
    for _ in range(num_blocks):
        blocks.append(
            Block(
                x_dim=model_channels,
                context_dim=crossattn_emb_channels,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
                **factory,
            )
        )
    self.blocks = nn.ModuleList(blocks)

    self.final_layer = FinalLayer(
        hidden_size=self.model_channels,
        spatial_patch_size=self.patch_spatial,
        temporal_patch_size=self.patch_temporal,
        out_channels=self.out_channels,
        use_adaln_lora=self.use_adaln_lora,
        adaln_lora_dim=self.adaln_lora_dim,
        **factory,
    )

    self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
def build_pos_embed(self, device=None, dtype=None) -> None:
    """Create ``self.pos_embedder`` and, if enabled, ``self.extra_pos_embedder``.

    Only ``pos_emb_cls == "rope3d"`` is supported. The extra embedder is built
    when ``self.extra_per_block_abs_pos_emb`` is truthy, using the same keyword
    arguments with the ``extra_*`` extrapolation ratios and ``dtype`` added.

    Args:
        device: Forwarded to both embedders.
        dtype: Forwarded to the extra embedder only.

    Raises:
        ValueError: If ``self.pos_emb_cls`` is not ``"rope3d"``.
    """
    if self.pos_emb_cls != "rope3d":
        raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
    cls_type = VideoRopePosition3DEmb

    logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
    kwargs = {
        "model_channels": self.model_channels,
        # Grid lengths are measured in patches, not pixels/frames.
        "len_h": self.max_img_h // self.patch_spatial,
        "len_w": self.max_img_w // self.patch_spatial,
        "len_t": self.max_frames // self.patch_temporal,
        "max_fps": self.max_fps,
        "min_fps": self.min_fps,
        "is_learnable": self.pos_emb_learnable,
        "interpolation": self.pos_emb_interpolation,
        "head_dim": self.model_channels // self.num_heads,
        "h_extrapolation_ratio": self.rope_h_extrapolation_ratio,
        "w_extrapolation_ratio": self.rope_w_extrapolation_ratio,
        "t_extrapolation_ratio": self.rope_t_extrapolation_ratio,
        "enable_fps_modulation": self.rope_enable_fps_modulation,
        "device": device,
    }
    self.pos_embedder = cls_type(**kwargs)  # type: ignore

    if not self.extra_per_block_abs_pos_emb:
        return

    extra_kwargs = {
        **kwargs,
        "h_extrapolation_ratio": self.extra_h_extrapolation_ratio,
        "w_extrapolation_ratio": self.extra_w_extrapolation_ratio,
        "t_extrapolation_ratio": self.extra_t_extrapolation_ratio,
        "device": device,
        "dtype": dtype,
    }
    self.extra_pos_embedder = LearnablePosEmbAxis(**extra_kwargs)  # type: ignore
|
|
|
|
|
|
|
|
def prepare_embedded_sequence(
    self,
    x_B_C_T_H_W: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Patch-embed the input video and compute its positional embeddings.

    Args:
        x_B_C_T_H_W: Input video tensor, indexed as (B, C, T, H, W).
        fps: Optional frames-per-second tensor, forwarded to the rotary and
            the extra positional embedders.
        padding_mask: Optional mask, consulted only when
            ``self.concat_padding_mask`` is truthy. When omitted, an all-zero
            mask of shape (B, 1, H, W) is used; otherwise the given mask is
            resized (nearest neighbour) to the input's last two dimensions.

    Returns:
        A 3-tuple ``(x, rope_emb, extra_pos_emb)``:
            - ``x``: output of ``self.x_embedder`` (named B, T, H, W, D in
              this file). On the non-rope path the output of
              ``self.pos_embedder`` is added to it.
            - ``rope_emb``: output of ``self.pos_embedder`` when "rope"
              occurs (case-insensitively) in ``self.pos_emb_cls``, else None.
            - ``extra_pos_emb``: output of ``self.extra_pos_embedder`` when
              ``self.extra_per_block_abs_pos_emb`` is truthy, else None.
    """
    src_device = x_B_C_T_H_W.device

    if self.concat_padding_mask:
        batch, _, frames, height, width = x_B_C_T_H_W.shape
        if padding_mask is not None:
            padding_mask = transforms.functional.resize(
                padding_mask, [height, width], interpolation=transforms.InterpolationMode.NEAREST
            )
        else:
            padding_mask = torch.zeros(
                batch, 1, height, width, dtype=x_B_C_T_H_W.dtype, device=src_device
            )
        # Add a frame axis and replicate the mask over every frame so it can
        # be appended to the input as one extra channel.
        mask_per_frame = padding_mask.unsqueeze(1).repeat(1, 1, frames, 1, 1)
        x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, mask_per_frame], dim=1)

    x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

    extra_pos_emb = (
        self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=src_device, dtype=x_B_C_T_H_W.dtype)
        if self.extra_per_block_abs_pos_emb
        else None
    )

    if "rope" in self.pos_emb_cls.lower():
        # Rotary embeddings are handed back to the caller, not added here.
        rope_emb = self.pos_embedder(x_B_T_H_W_D, fps=fps, device=src_device)
        return x_B_T_H_W_D, rope_emb, extra_pos_emb

    # Non-rotary path: fold the positional embedding into the tokens directly.
    abs_pos_emb = self.pos_embedder(x_B_T_H_W_D, device=src_device)
    x_B_T_H_W_D = x_B_T_H_W_D + abs_pos_emb  # [B, T, H, W, D]
    return x_B_T_H_W_D, None, extra_pos_emb
|
|
|
|
|
|
|
|
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
    """Fold per-patch features back into a dense (B, C, T, H, W) tensor.

    The last axis of the input is split as ``(p1 p2 t C)``, where ``p1`` and
    ``p2`` both equal ``self.patch_spatial`` and ``t`` equals
    ``self.patch_temporal``. The patch factors are then merged into the
    height, width and time axes respectively.

    Args:
        x_B_T_H_W_M: Tensor indexed as (B, T, H, W, M), with
            ``M = p1 * p2 * t * C``.

    Returns:
        Tensor indexed as (B, C, T * t, H * p1, W * p2).
    """
    pattern = "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)"
    spatial = self.patch_spatial
    return rearrange(
        x_B_T_H_W_M,
        pattern,
        p1=spatial,
        p2=spatial,
        t=self.patch_temporal,
    )
|
|
|
|
|
|
|
|
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Run the full model: embed, apply every block, project and unpatchify.

    Args:
        x: (B, C, T, H, W) tensor of spatial-temp inputs.
        timesteps: (B, ) tensor of timesteps. A 1-D tensor gets a trailing
            axis added, so it is treated as (B, 1).
        context: (B, N, D) tensor of cross-attention embeddings.
        fps: Optional frames-per-second tensor, forwarded to
            ``self.prepare_embedded_sequence``.
        padding_mask: Optional padding mask, forwarded to
            ``self.prepare_embedded_sequence``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The output of ``self.unpatchify`` applied to the final layer's result.

    Side effects:
        Stores ``self.affline_scale_log_info``, ``self.affline_emb`` and
        ``self.crossattn_emb`` on the module on every call.
    """
    x_B_C_T_H_W = x
    timesteps_B_T = timesteps
    crossattn_emb = context

    x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
        x_B_C_T_H_W,
        fps=fps,
        padding_mask=padding_mask,
    )

    if timesteps_B_T.ndim == 1:
        timesteps_B_T = timesteps_B_T.unsqueeze(1)
    # t_embedder[0] produces the raw timestep features; they are cast to the
    # token dtype before t_embedder[1] maps them to (embedding, adaLN-LoRA).
    t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
    t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)

    # for logging purpose
    affline_scale_log_info = {}
    affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
    self.affline_scale_log_info = affline_scale_log_info
    self.affline_emb = t_embedding_B_T_D
    self.crossattn_emb = crossattn_emb

    if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
        assert (
            x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
        ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"

    # prepare_embedded_sequence returns None for the rotary embedding on its
    # non-rope path; only add the two broadcast axes when a tensor is present
    # instead of failing with AttributeError on None.
    if rope_emb_L_1_1_D is not None:
        rope_emb_L_1_1_D = rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0)

    block_kwargs = {
        "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
        "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
        "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
    }
    for block in self.blocks:
        x_B_T_H_W_D = block(
            x_B_T_H_W_D,
            t_embedding_B_T_D,
            crossattn_emb,
            **block_kwargs,
        )

    x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
    x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
    return x_B_C_Tt_Hp_Wp
|