# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
from __future__ import annotations

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND


def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))
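
# A minimal shape sketch (illustrative only, not from the original source):
# with x of shape (B, L, D) and a per-sample scale of shape (B, D), modulate()
# broadcasts the scale over the sequence dimension, e.g.
#   x = torch.randn(2, 16, 8); scale = torch.zeros(2, 8)
#   torch.equal(modulate(x, scale), x)  # True: a zero scale leaves x unchanged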

#############################################################################
#                            Core NextDiT Model                             #
#############################################################################


class JointAttention(nn.Module):
    """Multi-head attention module."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
        operation_settings={},
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
            qk_norm (bool): Whether to apply RMSNorm to the query and key heads.
        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.qkv = operation_settings.get("operations").Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.out = operation_settings.get("operations").Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
        else:
            self.q_norm = self.k_norm = nn.Identity()

    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to the input tensor using the given frequency
        tensor.

        The last dimension of 'x_in' is viewed as consecutive (even, odd) pairs
        and each pair is rotated by the precomputed cosine/sine entries stored
        in 'freqs_cis', then the result is reshaped back to the input shape.

        Args:
            x_in (torch.Tensor): Query or key tensor to apply rotary embeddings to.
            freqs_cis (torch.Tensor): Precomputed rotation tensor (cos/sin pairs)
                broadcastable against the reshaped input.

        Returns:
            torch.Tensor: The input tensor with rotary embeddings applied, in
            the same shape and dtype as 'x_in'.
        """
        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return t_out.reshape(*x_in.shape).type_as(x_in)
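
    # Rotation sketch (illustrative, assuming the Flux-style EmbedND tables used
    # by this model): for one (even, odd) pair (x0, x1) at angle theta, the code
    # above computes
    #   out0 = cos(theta) * x0 - sin(theta) * x1
    #   out1 = sin(theta) * x0 + cos(theta) * x1
    # i.e. the standard RoPE rotation applied pairwise over head_dim.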

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
            x_mask (torch.Tensor): Attention mask passed to the attention
                kernel (or None).
            freqs_cis (torch.Tensor): Precomputed rotary embedding tensor.

        Returns:
            torch.Tensor: Attention output of shape (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

        n_rep = self.n_local_heads // self.n_local_kv_heads
        if n_rep >= 1:
            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)

        return self.out(output)
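
# Grouped-query attention note (descriptive sketch, not an original comment):
# the key/value heads are tiled n_rep times so every query head has a matching
# k/v head (a no-op when n_rep == 1). For example, a key of shape
# (B, L, 2, head_dim) with n_rep=4 becomes (B, L, 8, head_dim) via
# unsqueeze(3).repeat(1, 1, 1, 4, 1).flatten(2, 3).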


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        operation_settings={},
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.
        """
        super().__init__()
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w2 = operation_settings.get("operations").Linear(
            hidden_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w3 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
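
# FeedForward is a SwiGLU-style MLP: w2(silu(w1(x)) * w3(x)). The hidden width
# is rounded up to a multiple of `multiple_of`; an illustrative example (not
# from the source): hidden_dim=1000 with multiple_of=256 becomes
# 256 * ((1000 + 255) // 256) = 1024.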


class JointTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        operation_settings={},
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Rounding multiple for the feedforward hidden dimension.
            ffn_dim_multiplier (float): Custom multiplier for the feedforward hidden dimension.
            norm_eps (float): Epsilon used by the RMSNorm layers.
            qk_norm (bool): Whether to apply RMSNorm to queries and keys.
            modulation (bool): Whether the block is conditioned via adaLN modulation.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
            operation_settings=operation_settings,
        )
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Attention mask passed to the attention layer.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (Optional[torch.Tensor]): Conditioning vector for adaLN
                modulation; required when the block was built with modulation=True.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.
        """
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x
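
# Block structure note (descriptive, not an original comment): each residual
# branch is "sandwich" normalized, i.e. pre-normed before attention/FFN and
# post-normed again before the residual add, and when modulation is enabled the
# branch is further scaled by tanh(gate) with scale/gate produced by
# adaLN_modulation from the timestep embedding.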


class FinalLayer(nn.Module):
    """
    The final layer of NextDiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
        super().__init__()
        self.norm_final = operation_settings.get("operations").LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.linear = operation_settings.get("operations").Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operation_settings.get("operations").Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x
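
# FinalLayer applies a scale-only adaLN modulation (no shift or gate) to the
# normalized tokens, then projects each token to patch_size**2 * out_channels
# values that unpatchify() later folds back into an image.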


class RopeEmbedder:
    def __init__(
        self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
    ):
        super().__init__()
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)

    def __call__(self, ids: torch.Tensor):
        self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
        result = []
        for i in range(len(self.axes_dims)):
            index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
            result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1)
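
# Note: RopeEmbedder looks up per-axis complex frequency tables by position id.
# In this file NextDiT instantiates the Flux-style EmbedND instead; the
# RopeEmbedder line is left commented out in NextDiT.__init__ below.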


class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        self.x_embedder = operation_settings.get("operations").Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        # self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
        self.dim = dim
        self.n_heads = n_heads
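
    # Positional-encoding note (descriptive): the assert in __init__ ties
    # head_dim to sum(axes_dims). The three axes are filled by
    # patchify_and_embed() as (caption/sequence index, image row, image column),
    # so each head's rotary dimensions are split 16/56/56 by default across
    # those axes.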

    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: list of (C, H, W) tensors, or a stacked (N, C, H, W) tensor when
        return_tensor is True
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs
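
    # Shape sketch for unpatchify (illustrative): with patch_size=2, C=4 and a
    # 64x64 latent, the (H//2)*(W//2)=1024 image tokens of width 2*2*4=16 are
    # viewed as (32, 32, 2, 2, 4), permuted to (4, 32, 2, 32, 2) and flattened
    # back to (4, 64, 64).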

    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        dtype = x[0].dtype

        if cap_mask is not None:
            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        else:
            l_effective_cap_len = [num_tokens] * bsz

        if cap_mask is not None and not torch.is_floating_point(cap_mask):
            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max

        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)

        # build freqs_cis for cap and image individually
        cap_freqs_cis_shape = list(freqs_cis.shape)
        # cap_freqs_cis_shape[1] = max_cap_len
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # refine image
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

        padded_img_embed = self.x_embedder(padded_img_embed)
        padded_img_mask = padded_img_mask.unsqueeze(1)
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

        if cap_mask is not None:
            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
        else:
            mask = None

        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
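
    # patchify_and_embed summary (descriptive): caption tokens come first and
    # image patch tokens are appended per sample, the RoPE table is sliced per
    # segment, both streams are pre-refined (context_refiner / noise_refiner),
    # and the returned attention mask is additive: 0 keeps a position while
    # -torch.finfo(dtype).max effectively removes it.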

    # def forward(self, x, t, cap_feats, cap_mask):
    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        """
        Forward pass of NextDiT.

        x: (N, C, H, W) latent tensor (or a list of per-sample latents)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, L, cap_feat_dim) tensor of caption features
        num_tokens: effective caption length used when attention_mask is None
        """
        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
        freqs_cis = freqs_cis.to(x.device)

        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)

        return -x
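
    # Convention note (interpretive): forward() maps the caller's timesteps via
    # t = 1.0 - timesteps and negates the prediction on return (return -x),
    # which presumably adapts the checkpoint's native flow-matching convention
    # to the sampling code that calls this module; both flips are kept exactly
    # as in the source.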

    @staticmethod
    def precompute_freqs_cis(
        dim: List[int],
        end: List[int],
        theta: float = 10000.0,
    ):
        """
        Precompute the frequency tensors for complex exponentials (cis) with
        the given dimensions.

        For each axis, this function calculates a frequency tensor with complex
        exponentials using the given dimension 'dim[i]' and end index 'end[i]'.
        The 'theta' parameter scales the frequencies. The returned tensors
        contain complex values in complex64 data type.

        Args:
            dim (list): Dimension of the frequency tensor for each axis.
            end (list): End index for precomputing frequencies for each axis.
            theta (float, optional): Scaling factor for frequency computation.
                Defaults to 10000.0.

        Returns:
            List[torch.Tensor]: One precomputed complex-exponential frequency
            tensor per axis.
        """
        freqs_cis = []
        for i, (d, e) in enumerate(zip(dim, end)):
            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
            timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
            freqs = torch.outer(timestep, freqs).float()
            freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
            freqs_cis.append(freqs_cis_i)

        return freqs_cis
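
# Shape note (illustrative): with the default axes_dims=(16, 56, 56) and
# axes_lens=(1, 512, 512), precompute_freqs_cis returns three complex64 tables
# of shapes (1, 8), (512, 28) and (512, 28): one row per position, one column
# per rotated pair (d // 2). In this file they are only consumed by the
# reference RopeEmbedder above.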