mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Compare commits
3 Commits
dd15653b2a
...
bde5fe2b99
Author | SHA1 | Date | |
---|---|---|---|
|
bde5fe2b99 | ||
|
2ff3104f70 | ||
|
129d8908f7 |
@ -382,3 +382,7 @@ class HunyuanVideo(LatentFormat):
|
|||||||
]
|
]
|
||||||
|
|
||||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||||
|
|
||||||
|
class Cosmos1CV8x8x8(LatentFormat):
|
||||||
|
latent_channels = 16
|
||||||
|
latent_dimensions = 3
|
||||||
|
804
comfy/ldm/cosmos/blocks.py
Normal file
804
comfy/ldm/cosmos/blocks.py
Normal file
@ -0,0 +1,804 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """Apply a rotary position embedding to ``t``.

    The last axis of ``t`` is split into two halves that are paired element by
    element; each pair is combined linearly using ``freqs[..., 0]`` and
    ``freqs[..., 1]`` as coefficients. The arithmetic is done in float32 and
    the result is returned with the shape and dtype of ``t``.

    Args:
        t: Tensor whose last axis has an even length.
        freqs: Coefficient tensor; must broadcast against the paired view of
            ``t`` (exact layout is decided by the caller — TODO confirm).

    Returns:
        Tensor with the same shape and dtype as ``t``.
    """
    source_shape = t.shape
    # (..., D) -> (..., 2, D/2) -> (..., D/2, 2) -> (..., D/2, 1, 2)
    halves = t.reshape(*source_shape[:-1], 2, -1)
    pairs = halves.movedim(-2, -1).unsqueeze(-2).float()
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = freqs[..., 0] * first + freqs[..., 1] * second
    # Undo the pairing so the layout matches the input again.
    return rotated.movedim(-1, -2).reshape(*source_shape).type_as(t)
|
||||||
|
|
||||||
|
|
||||||
|
def get_normalization(name: str, channels: int, weight_args=None):
    """Build the normalization layer selected by a one-letter code.

    Args:
        name: ``"I"`` for ``nn.Identity`` or ``"R"`` for ``RMSNorm``.
        channels: Size of the normalized dimension (only used by ``"R"``).
        weight_args: Optional dict of extra keyword arguments forwarded to
            ``RMSNorm``. ``None`` means no extra arguments.

    Returns:
        The constructed normalization module.

    Raises:
        ValueError: If ``name`` is not one of the supported codes.
    """
    if name == "I":
        return nn.Identity()
    if name == "R":
        # A None default avoids sharing one mutable dict between calls.
        extra = {} if weight_args is None else weight_args
        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **extra)
    raise ValueError(f"Normalization {name} not found")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAttentionOp(nn.Module):
    """Empty ``nn.Module`` subclass that defines no parameters or methods of its own."""

    def __init__(self):
        super(BaseAttentionOp, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """
    Generalized multi-head attention.

    Works as self-attention or cross-attention depending on whether a `context_dim` is provided.
    If `context_dim` is None, self-attention is assumed and keys/values are projected from the query input.

    Tensors are sequence-first: `cal_qkv` splits heads with the einops pattern
    "s b (n c) -> b n s c" and `cal_attn` merges them back with "b n s c -> s b (n c)".

    Parameters:
        query_dim (int): Dimension of each query vector.
        context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
        heads (int, optional): Number of attention heads. Defaults to 8.
        dim_head (int, optional): Dimension of each head. Defaults to 64.
        dropout (float, optional): Dropout rate applied after the output projection. Defaults to 0.0.
        attn_op (BaseAttentionOp, optional): Accepted for interface compatibility; not used by this implementation.
        qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
        out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
        qkv_norm (str, optional): Three characters selecting the normalization for query, key and value;
            each one is resolved through `get_normalization`. Defaults to "SSI".
            NOTE(review): `get_normalization` in this file only accepts "I" and "R", so the default
            raises ValueError at construction; the caller in this file passes "RRI".
        qkv_norm_mode (str, optional): Normalization mode for query, key, and value projections.
            Defaults to 'per_head'. Only 'per_head' is supported.
        backend (str, optional): Stored on the instance; not otherwise used by this class.
        qkv_format (str, optional): Stored on the instance; not otherwise used by this class.
        weight_args (dict, optional): Extra keyword arguments forwarded to every `operations.Linear`
            and to the q/k/v normalization layers. None means no extra arguments.
        operations: Namespace that provides the `Linear` layer class. Must be supplied;
            the None default is only a placeholder.

    Examples:
        >>> # `ops` is any namespace exposing `Linear`, e.g. torch.nn
        >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, qkv_norm="RRI", operations=ops)
        >>> query = torch.randn(10, 2, 128)    # (sequence, batch, query_dim)
        >>> context = torch.randn(7, 2, 256)   # (sequence, batch, context_dim)
        >>> output = attn(query, context)      # (sequence, batch, query_dim)

    Note:
        https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """

    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_op: Optional[BaseAttentionOp] = None,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
        weight_args=None,
        operations=None,
    ) -> None:
        super().__init__()
        # A None default avoids sharing one mutable dict between instances.
        weight_args = {} if weight_args is None else weight_args

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        # Each projection is (Linear, norm). The pair is kept in an nn.Sequential, but
        # cal_qkv calls the two entries separately so the norm can run per head.
        # The norm layers receive the same weight_args as the Linear layers.
        self.to_q = nn.Sequential(
            operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args),
        )
        self.to_k = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args),
        )
        self.to_v = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args),
        )

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Project the inputs to per-head query, key and value tensors.

        self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
        Before 07/24/2024, these modules normalize across all heads.
        After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
        we support to normalize per head.
        To keep the checkpoint compatibility with the previous code,
        we keep the nn.Sequential but call the projection and the normalization layers separately.
        We use a flag `self.qkv_norm_mode` to control the normalization behavior.
        The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.

        Args:
            x (Tensor): Query input, sequence-first (S, B, query_dim).
            context (Optional[Tensor]): Key/value input (S_kv, B, context_dim); `x` is used when None.
            mask: Unused here; accepted so `forward` can pass it through.
            rope_emb (Optional[Tensor]): Rotary embedding applied to q and k, self-attention only.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple (q, k, v), each laid out as (B, heads, S, dim_head).

        Raises:
            ValueError: If `self.qkv_norm_mode` is not "per_head".
        """
        del kwargs

        if self.qkv_norm_mode != "per_head":
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        context = x if context is None else context
        q = self.to_q[0](x)
        k = self.to_k[0](context)
        v = self.to_v[0](context)
        # Split the channel axis into heads before normalizing, so the norm acts per head.
        q, k, v = (
            rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head) for t in (q, k, v)
        )

        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            q = apply_rotary_pos_emb(q, rope_emb)
            k = apply_rotary_pos_emb(k, rope_emb)
        return q, k, v

    def cal_attn(self, q, k, v, mask=None):
        """Run attention on per-head q/k/v and apply the output projection.

        The inputs are passed to `optimized_attention` with both reshape steps
        skipped, the heads are merged back into a sequence-first layout, and
        `self.to_out` (Linear + Dropout) is applied.
        """
        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
        out = rearrange(out, " b n s c -> s b (n c)")
        return self.to_out(out)

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor, sequence-first, of shape [Mq, B, K]
            context (Optional[Tensor]): The key/value tensor of shape [Mk, B, K], or None to use x
                as context [self attention]
            mask: Optional attention mask handed to the attention kernel.
            rope_emb (Optional[Tensor]): Rotary embedding, used for self-attention only.

        Returns:
            Tensor of shape [Mq, B, query_dim].
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        return self.cal_attn(q, k, v, mask)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """
    Transformer FFN with optional gating

    Parameters:
        d_model (int): Dimensionality of input features.
        d_ff (int): Dimensionality of the hidden layer.
        dropout (float, optional): Dropout rate applied after the activation function (and gating,
            when enabled). Defaults to 0.1.
        activation (callable, optional): The activation function applied after the first linear layer.
            Defaults to a new nn.ReLU() per instance.
        is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
            Defaults to False.
        bias (bool, optional): If set to True, adds a bias to the two main linear layers
            (the gate projection never has one). Defaults to False.
        weight_args (dict, optional): Extra keyword arguments forwarded to every `operations.Linear`.
            None means no extra arguments.
        operations: Namespace that provides the `Linear` layer class. Must be supplied;
            the None default is only a placeholder.

    Example:
        >>> ff = FeedForward(d_model=512, d_ff=2048, operations=torch.nn)
        >>> x = torch.randn(64, 10, 512)  # Example input tensor
        >>> output = ff(x)
        >>> print(output.shape)  # Expected shape: (64, 10, 512)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=None,
        is_gated: bool = False,
        bias: bool = False,
        weight_args=None,
        operations=None,
    ) -> None:
        super().__init__()
        # None defaults avoid sharing one dict / one activation module between instances.
        weight_args = {} if weight_args is None else weight_args
        if activation is None:
            activation = nn.ReLU()

        self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
        self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)

    def forward(self, x: torch.Tensor):
        """Return ``layer2(dropout(act(layer1(x)) [* linear_gate(x)]))``."""
        hidden = self.activation(self.layer1(x))
        if self.is_gated:
            hidden = hidden * self.linear_gate(x)
        # Dropout used to be skipped behind `assert p == 0.0`, so the default
        # dropout=0.1 made every call fail. Applying the layer is identical
        # when p == 0.0 or in eval mode, and works for every other setting.
        hidden = self.dropout(hidden)
        return self.layer2(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2FeedForward(FeedForward):
    """Ungated `FeedForward` with a GELU activation: Linear -> GELU -> Dropout -> Linear.

    Parameters:
        d_model (int): Dimensionality of input and output features.
        d_ff (int): Dimensionality of the hidden layer.
        dropout (float, optional): Dropout rate applied after the activation. Defaults to 0.1.
        bias (bool, optional): If True, the two linear layers have a bias. Defaults to False.
        weight_args (dict, optional): Extra keyword arguments forwarded to the linear layers.
            None means no extra arguments.
        operations: Namespace that provides the `Linear` layer class. Must be supplied.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args=None, operations=None):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
            # Hand the parent a real dict so it can always be unpacked with **.
            weight_args={} if weight_args is None else weight_args,
            operations=operations,
        )

    def forward(self, x: torch.Tensor):
        x = self.layer1(x)
        x = self.activation(x)
        # Dropout used to be skipped behind `assert p == 0.0`, which made the
        # default dropout=0.1 unusable. Applying the layer is identical when
        # p == 0.0 or in eval mode.
        x = self.dropout(x)
        x = self.layer2(x)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
    """Scale-and-shift ``x``: returns ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a new axis inserted at position 1 so they
    broadcast over dimension 1 of ``x``.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
|
||||||
|
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of timesteps.

    The output has ``2 * (num_channels // 2)`` features per timestep: the
    cosine terms first, followed by the sine terms. Frequencies fall
    geometrically from 1 towards 1/10000.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        half_dim = self.num_channels // 2
        positions = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        log_freqs = (-math.log(10000) * positions) / (half_dim - 0.0)
        freqs = torch.exp(log_freqs)

        # Outer product via broadcasting: (N, 1) * (1, half_dim) -> (N, half_dim)
        angles = timesteps[:, None].float() * freqs[None, :]

        return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a timestep feature vector.

    Without AdaLN-LoRA both linear layers have a bias and the MLP output is the
    embedding. With AdaLN-LoRA both layers are bias-free, the second layer is
    three times as wide, the MLP output is returned as the LoRA term, and the
    input is returned unchanged as the embedding.

    Parameters:
        in_features (int): Size of the input feature vector.
        out_features (int): Hidden size of the MLP (and output size without AdaLN-LoRA).
        use_adaln_lora (bool, optional): Selects the AdaLN-LoRA variant. Defaults to False.
        weight_args (dict, optional): Extra keyword arguments forwarded to every
            `operations.Linear`. None means no extra arguments.
        operations: Namespace that provides the `Linear` layer class. Must be supplied.
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args=None, operations=None):
        super().__init__()
        # A None default avoids sharing one mutable dict between instances.
        weight_args = {} if weight_args is None else weight_args
        # Lazy %-formatting: the message is only built when DEBUG logging is on.
        logging.debug(
            "Using AdaLN LoRA Flag: %s. We enable bias if no AdaLN LoRA for backward compatibility.",
            use_adaln_lora,
        )
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
        else:
            self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)

    def forward(self, sample: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(emb_B_D, adaln_lora_B_3D)``.

        With AdaLN-LoRA this is ``(sample, mlp(sample))``; otherwise it is
        ``(mlp(sample), None)``.
        """
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            # The raw input is the embedding; the MLP output is the LoRA term.
            return sample, emb
        return emb, None
|
||||||
|
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
    """
    Implements a layer that generates Fourier features from input tensors, based on randomly sampled
    frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.

    [B] -> [B, D]

    Parameters:
        num_channels (int): The number of Fourier features to generate.
        bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
        normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
            the variance of the features. Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> x = torch.randn(10)  # Example input tensor (must be 1-D)
        >>> output = layer(x)
        >>> print(output.shape)  # Expected shape: (10, 256)
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        # Random frequencies and phases are persistent buffers, so they are part of the state_dict.
        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """
        Apply the Fourier feature transformation to the input tensor.

        Args:
            x (torch.Tensor): The 1-D input tensor of shape [B].
            gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.

        Returns:
            torch.Tensor: Tensor of shape [B, num_channels] equal to
            ``cos(outer(x, freqs) + phases) * self.gain * gain``, computed in float32 and
            cast back to the dtype of ``x``.
        """
        in_dtype = x.dtype
        # torch.outer replaces the deprecated Tensor.ger alias.
        angles = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32))
        angles = angles.add(self.phases.to(torch.float32))
        return angles.cos().mul(self.gain * gain).to(in_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
    """
    PatchEmbed is a module for embedding patches from an input tensor. The input is cut into
    non-overlapping spatio-temporal patches with an einops `Rearrange`, each patch is flattened
    (channels included), and a single linear layer maps it to a vector of size `out_channels`.

    Parameters:
    - spatial_patch_size (int): The size of each spatial patch.
    - temporal_patch_size (int): The size of each temporal patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    - bias (bool): If True, adds a learnable bias to the linear projection. Default: True.
    - weight_args (dict, optional): Extra keyword arguments forwarded to `operations.Linear`.
      None means no extra arguments.
    - operations: Namespace that provides the `Linear` layer class. Must be supplied.
    """

    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
        weight_args=None,
        operations=None,
    ):
        super().__init__()
        # A None default avoids sharing one mutable dict between instances.
        weight_args = {} if weight_args is None else weight_args
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        # Number of values in one flattened patch: channels x time x height x width.
        patch_features = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(patch_features, out_channels, bias=bias, **weight_args),
        )
        self.out = nn.Identity()

    def forward(self, x):
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.
            T must be divisible by the temporal patch size, H and W by the spatial patch size.

        Returns:
        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c, where t, h and w
          count patches and c is `out_channels`.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return self.out(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
    """
    The final layer of video DiT.

    Applies a non-affine LayerNorm, modulates the result with a shift and scale predicted from the
    conditioning embedding (adaLN), and projects every token to
    `spatial_patch_size**2 * temporal_patch_size * out_channels` values.

    Parameters:
        hidden_size (int): Feature dimension of the incoming tokens.
        spatial_patch_size (int): Spatial patch size used to size the output projection.
        temporal_patch_size (int): Temporal patch size used to size the output projection.
        out_channels (int): Number of output channels per patch element.
        use_adaln_lora (bool): If True, the modulation goes through a low-rank bottleneck and the
            externally supplied `adaln_lora_B_3D` term is added to it. Default: False
        adaln_lora_dim (int): Width of the low-rank bottleneck. Default: 256
        weight_args (dict, optional): Extra keyword arguments forwarded to the layers built from
            `operations`. None means no extra arguments.
        operations: Namespace that provides `Linear` and `LayerNorm`. Must be supplied.
    """

    def __init__(
        self,
        hidden_size,
        spatial_patch_size,
        temporal_patch_size,
        out_channels,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args=None,
        operations=None,
    ):
        super().__init__()
        # A None default avoids sharing one mutable dict between instances.
        weight_args = {} if weight_args is None else weight_args
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2  # shift and scale
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
            )

    def forward(
        self,
        x_BT_HW_D,
        emb_B_D,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x_BT_HW_D (Tensor): Tokens whose first axis is batch * T and whose last axis is
                `hidden_size` (the middle axis is presumably H * W — TODO confirm against caller).
            emb_B_D (Tensor): Conditioning embedding, one row per batch element.
            adaln_lora_B_3D (Optional[Tensor]): Required when `use_adaln_lora` is True; only its
                first `n_adaln_chunks * hidden_size` columns are used.

        Returns:
            Tensor with the leading axes of `x_BT_HW_D` and the projected size on the last axis.
        """
        modulation = self.adaLN_modulation(emb_B_D)
        if self.use_adaln_lora:
            assert adaln_lora_B_3D is not None
            # The LoRA term is sized for three chunks; only shift and scale are needed here.
            modulation = modulation + adaln_lora_B_3D[:, : self.n_adaln_chunks * self.hidden_size]
        shift_B_D, scale_B_D = modulation.chunk(self.n_adaln_chunks, dim=1)

        # Repeat each batch row T times so it lines up with the (B*T)-sized first axis of x.
        B = emb_B_D.shape[0]
        T = x_BT_HW_D.shape[0] // B
        shift_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T)
        scale_BT_D = repeat(scale_B_D, "b d -> (b t) d", t=T)
        x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)

        x_BT_HW_D = self.linear(x_BT_HW_D)
        return x_BT_HW_D
|
||||||
|
|
||||||
|
|
||||||
|
class VideoAttn(nn.Module):
    """
    Implements video attention with optional cross-attention capabilities.

    This module flattens the (T, H, W) axes of the video features into one sequence axis, runs
    `Attention` on it (self-attention, or cross-attention with external context features), and
    restores the spatio-temporal layout afterwards.

    Parameters:
        x_dim (int): Dimension of input feature vectors
        context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
        num_heads (int): Number of attention heads; each head has `x_dim // num_heads` channels
        bias (bool): Whether to include bias in attention projections. Default: False
        qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
        x_format (str): Stored on the instance; `forward` always treats the input as (T, H, W, B, D).
            Default: "BTHWD"
        weight_args (dict, optional): Extra keyword arguments forwarded to `Attention`.
            None means no extra arguments.
        operations: Layer namespace forwarded to `Attention`. Must be supplied.

    Input shape:
        - x: (T, H, W, B, D) video features
        - context (optional): (M, B, D) context features for cross-attention
        where:
            T: temporal dimension
            H: height
            W: width
            B: batch size
            D: feature dimension
            M: context sequence length
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        bias: bool = False,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        weight_args=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.x_format = x_format

        self.attn = Attention(
            query_dim=x_dim,
            context_dim=context_dim,
            heads=num_heads,
            dim_head=x_dim // num_heads,
            qkv_bias=bias,
            qkv_norm="RRI",
            out_bias=bias,
            qkv_norm_mode=qkv_norm_mode,
            qkv_format="sbhd",
            # Hand Attention a real dict so it can always be unpacked with **.
            weight_args={} if weight_args is None else weight_args,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for video attention.

        Args:
            x (Tensor): Input tensor of shape (T, H, W, B, D) representing batches of video data.
            context (Tensor): Context tensor of shape (M, B, D),
                where M is the sequence length of the context.
            crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.

        Returns:
            Tensor: The output tensor with applied attention, maintaining the input shape.
        """
        # Unpacking five names also rejects inputs that are not 5-D.
        _, H, W, _, _ = x.shape
        x_THW_B_D = rearrange(x, "t h w b d -> (t h w) b d")
        x_THW_B_D = self.attn(
            x_THW_B_D,
            context,
            crossattn_mask,
            rope_emb=rope_emb_L_1_1_D,
        )
        # T is inferred by einops from the flattened length together with H and W.
        return rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||||
|
|
||||||
|
|
||||||
|
def adaln_norm_state(norm_state, x, scale, shift):
    """Normalize ``x`` with ``norm_state`` and return ``normalized * (1 + scale) + shift``."""
    return norm_state(x) * (1 + scale) + shift
|
||||||
|
|
||||||
|
|
||||||
|
class DITBuildingBlock(nn.Module):
    """
    A building block for the DiT (Diffusion Transformer) architecture that supports different types of
    attention and MLP operations with adaptive layer normalization.

    The block computes `x + gate * block(norm(x) * (1 + scale) + shift)`, where shift, scale and gate
    are predicted from the conditioning embedding.

    Parameters:
        block_type (str): Type of block (case-insensitive) - one of:
            - "cross_attn"/"ca": Cross-attention
            - "full_attn"/"fa": Full self-attention
            - "mlp"/"ff": MLP/feedforward block
        x_dim (int): Dimension of input features
        context_dim (Optional[int]): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        bias (bool): Whether to use bias in layers. Default: False
        mlp_dropout (float): Dropout rate for MLP. Default: 0.0
        qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
        x_format (str): Input tensor format label, stored and forwarded to the attention block.
            Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
        weight_args (dict, optional): Extra keyword arguments forwarded to the layers built from
            `operations` and to the inner block. None means no extra arguments.
        operations: Namespace that provides the `Linear` layer class. Must be supplied.

    Raises:
        ValueError: If `block_type` is not one of the supported names.
    """

    def __init__(
        self,
        block_type: str,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        mlp_ratio: float = 4.0,
        bias: bool = False,
        mlp_dropout: float = 0.0,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args=None,
        operations=None
    ) -> None:
        block_type = block_type.lower()
        # A None default avoids sharing one mutable dict between instances.
        weight_args = {} if weight_args is None else weight_args

        super().__init__()
        self.x_format = x_format
        if block_type in ["cross_attn", "ca"]:
            self.block = VideoAttn(
                x_dim,
                context_dim,
                num_heads,
                bias=bias,
                qkv_norm_mode=qkv_norm_mode,
                x_format=self.x_format,
                weight_args=weight_args,
                operations=operations,
            )
        elif block_type in ["full_attn", "fa"]:
            # Self-attention: no context dimension.
            self.block = VideoAttn(
                x_dim,
                None,
                num_heads,
                bias=bias,
                qkv_norm_mode=qkv_norm_mode,
                x_format=self.x_format,
                weight_args=weight_args,
                operations=operations,
            )
        elif block_type in ["mlp", "ff"]:
            self.block = GPT2FeedForward(
                x_dim,
                int(x_dim * mlp_ratio),
                dropout=mlp_dropout,
                bias=bias,
                weight_args=weight_args,
                operations=operations,
            )
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        self.block_type = block_type
        self.use_adaln_lora = use_adaln_lora

        self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
        self.n_adaln_chunks = 3  # shift, scale and gate
        if use_adaln_lora:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args)
            )

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for dynamically configured blocks with adaptive normalization.

        Args:
            x (Tensor): Input tensor of shape (T, H, W, B, D); the modulation terms are broadcast
                as (1, 1, 1, B, D).
            emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
            crossattn_emb (Tensor): Tensor for cross-attention blocks.
            crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
            adaln_lora_B_3D (Optional[Tensor]): Term added to the predicted modulation when
                `use_adaln_lora` is True; ignored otherwise.

        Returns:
            Tensor: The output tensor after processing through the configured block and adaptive normalization.

        Raises:
            ValueError: If `self.block_type` is not a supported name.
        """
        modulation = self.adaLN_modulation(emb_B_D)
        if self.use_adaln_lora:
            modulation = modulation + adaln_lora_B_3D
        shift_B_D, scale_B_D, gate_B_D = modulation.chunk(self.n_adaln_chunks, dim=1)

        # (B, D) -> (1, 1, 1, B, D) so the terms broadcast over the T, H and W axes of x.
        shift_1_1_1_B_D = shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        scale_1_1_1_B_D = scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        gate_1_1_1_B_D = gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0)

        # Only the extra arguments differ between block types; the residual form is shared.
        if self.block_type in ["mlp", "ff"]:
            block_kwargs = {}
        elif self.block_type in ["full_attn", "fa"]:
            block_kwargs = {"context": None, "rope_emb_L_1_1_D": rope_emb_L_1_1_D}
        elif self.block_type in ["cross_attn", "ca"]:
            block_kwargs = {
                "context": crossattn_emb,
                "crossattn_mask": crossattn_mask,
                "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
            }
        else:
            raise ValueError(f"Unknown block type: {self.block_type}")

        normed = adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D)
        return x + gate_1_1_1_B_D * self.block(normed, **block_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralDITTransformerBlock(nn.Module):
    """
    A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
    Each block in the sequence is specified by a block configuration string.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention blocks
        num_heads (int): Number of attention heads
        block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
                          full-attention, then MLP)
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
        weight_args (dict, optional): Extra keyword arguments forwarded to every DITBuildingBlock.
            None means no extra arguments.
        operations: Layer namespace forwarded to every DITBuildingBlock. Must be supplied.

    The block_config string uses "-" to separate block types:
        - "ca"/"cross_attn": Cross-attention block
        - "fa"/"full_attn": Full self-attention block
        - "mlp"/"ff": MLP/feedforward block

    Example:
        block_config = "ca-fa-mlp" creates a sequence of:
        1. Cross-attention block
        2. Full self-attention block
        3. MLP block
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        block_config: str,
        mlp_ratio: float = 4.0,
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args=None,
        operations=None
    ):
        super().__init__()
        # A None default avoids sharing one mutable dict between instances.
        weight_args = {} if weight_args is None else weight_args
        self.blocks = nn.ModuleList()
        self.x_format = x_format
        for block_type in block_config.split("-"):
            self.blocks.append(
                DITBuildingBlock(
                    block_type,
                    x_dim,
                    context_dim,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    x_format=self.x_format,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    weight_args=weight_args,
                    operations=operations,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run `x` through every configured block in order.

        If `extra_per_block_pos_emb` is given it is added to `x` once, before the
        first block. All other arguments are passed unchanged to each block.
        """
        if extra_per_block_pos_emb is not None:
            x = x + extra_per_block_pos_emb
        for block in self.blocks:
            x = block(
                x,
                emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )
        return x
|
1050
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
1050
comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
Normal file
File diff suppressed because it is too large
Load Diff
355
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
355
comfy/ldm/cosmos/cosmos_tokenizer/patching.py
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""The patcher and unpatcher implementation for 2D and 3D data.
|
||||||
|
|
||||||
|
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
|
||||||
|
One on the rows and one on the columns.
|
||||||
|
For example, for a 1D signal [a, b], the low-freq component is [a + b] / 2 and the high-freq component is [a - b] / 2.
|
||||||
|
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
|
||||||
|
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
|
||||||
|
In principle, additional Haar wavelet passes are typically applied only to the LL component, but here we apply them to all components,
|
||||||
|
as we need to support downsampling for more than 2x.
|
||||||
|
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
|
||||||
|
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
# Two-tap filters, keyed by patch method name.
_WAVELETS = {
    "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
    "rearrange": torch.tensor([1.0, 1.0]),
}
# Value passed as ``persistent=`` to the register_buffer calls below.
_PERSISTENT = False


class Patcher(torch.nn.Module):
    """Turn image tensors into patches using torch operations only.

    Two patch methods are supported: ``"haar"`` (repeated 2x wavelet
    decomposition) and ``"rearrange"`` (pixel-unshuffle style reshaping via
    einops). The upstream notes describe the output as matching a separate
    ``Patching`` class, which is not part of this module.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method

        taps = _WAVELETS[patch_method]
        self.register_buffer("wavelets", taps, persistent=_PERSISTENT)

        # Number of 2x decomposition levels: floor(log2(patch_size)).
        num_levels = int(torch.log2(torch.tensor(self.patch_size)).item())
        self.range = range(num_levels)

        # Tap indices [0, 1, ...]; used to alternate the sign of the taps.
        self.register_buffer(
            "_arange", torch.arange(taps.shape[0]), persistent=_PERSISTENT
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Dispatch to the implementation selected by ``patch_method``."""
        if self.patch_method == "haar":
            return self._haar(x)
        if self.patch_method == "rearrange":
            return self._arrange(x)
        raise ValueError("Unknown patch method: " + self.patch_method)

    def _dwt(self, x, mode="reflect", rescale=False):
        """Apply one level of the 2D transform; channels grow 4x.

        The output concatenates the LL, LH, HL and HH bands along dim 1, and
        is halved when ``rescale`` is true.
        """
        dtype = x.dtype
        taps = self.wavelets.to(device=x.device)
        n = taps.shape[0]
        g = x.shape[1]

        # One depthwise kernel per input channel for each filter.
        signs = (-1) ** self._arange.to(device=x.device)
        low = taps.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        high = (taps * signs).reshape(1, 1, -1).repeat(g, 1, 1)
        high = high.to(dtype=dtype)
        low = low.to(dtype=dtype)

        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)

        # First pass filters along the last axis (stride 2 there) ...
        along_last = dict(groups=g, stride=(1, 2))
        xl = F.conv2d(x, low.unsqueeze(2), **along_last)
        xh = F.conv2d(x, high.unsqueeze(2), **along_last)

        # ... second pass filters along the axis before it.
        along_prev = dict(groups=g, stride=(2, 1))
        xll = F.conv2d(xl, low.unsqueeze(3), **along_prev)
        xlh = F.conv2d(xl, high.unsqueeze(3), **along_prev)
        xhl = F.conv2d(xh, low.unsqueeze(3), **along_prev)
        xhh = F.conv2d(xh, high.unsqueeze(3), **along_prev)

        out = torch.cat([xll, xlh, xhl, xhh], dim=1)
        return out / 2 if rescale else out

    def _haar(self, x):
        """Apply ``_dwt`` once per level in ``self.range``."""
        for _level in self.range:
            x = self._dwt(x, rescale=True)
        return x

    def _arrange(self, x):
        """Fold ``patch_size`` x ``patch_size`` spatial tiles into channels."""
        size = self.patch_size
        patched = rearrange(
            x,
            "b c (h p1) (w p2) -> b (c p1 p2) h w",
            p1=size,
            p2=size,
        )
        return patched.contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class Patcher3D(Patcher):
    """3D counterpart of ``Patcher`` for 5D input (a batch of videos, ``b c t h w``)."""

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)
        size_buffer = patch_size * torch.ones([1], dtype=torch.int32)
        self.register_buffer(
            "patch_size_buffer", size_buffer, persistent=_PERSISTENT
        )

    def _dwt(self, x, wavelet, mode="reflect", rescale=False):
        """Apply one level of the 3D transform; channels grow 8x.

        ``wavelet`` is accepted but not read: the filter always comes from
        ``self.wavelets``. The eight sub-bands are concatenated along dim 1
        in the order lll, llh, lhl, lhh, hll, hlh, hhl, hhh, where the
        letters refer to the filter used on dims 2, 3 and 4 respectively.
        """
        dtype = x.dtype
        taps = self.wavelets.to(device=x.device)
        n = taps.shape[0]
        g = x.shape[1]

        signs = (-1) ** self._arange.to(device=x.device)
        low = taps.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        high = (taps * signs).reshape(1, 1, -1).repeat(g, 1, 1)
        high = high.to(dtype=dtype)
        low = low.to(dtype=dtype)

        def split(bands, low_kernel, high_kernel, stride):
            # Each band yields its low-pass then its high-pass result.
            result = []
            for band in bands:
                result.append(F.conv3d(band, low_kernel, groups=g, stride=stride))
                result.append(F.conv3d(band, high_kernel, groups=g, stride=stride))
            return result

        padded = F.pad(
            x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
        ).to(dtype)

        # Temporal axis (dim 2).
        bands = split(
            [padded],
            low.unsqueeze(3).unsqueeze(4),
            high.unsqueeze(3).unsqueeze(4),
            (2, 1, 1),
        )
        # First spatial axis (dim 3).
        bands = split(
            bands,
            low.unsqueeze(2).unsqueeze(4),
            high.unsqueeze(2).unsqueeze(4),
            (1, 2, 1),
        )
        # Second spatial axis (dim 4).
        bands = split(
            bands,
            low.unsqueeze(2).unsqueeze(3),
            high.unsqueeze(2).unsqueeze(3),
            (1, 1, 2),
        )

        out = torch.cat(bands, dim=1)
        if rescale:
            out = out / (2 * torch.sqrt(torch.tensor(2.0)))
        return out

    def _haar(self, x):
        """Repeat the first frame ``patch_size`` times, then run every level."""
        first, rest = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([first.repeat_interleave(self.patch_size, dim=2), rest], dim=2)
        for _level in self.range:
            x = self._dwt(x, "haar", rescale=True)
        return x

    def _arrange(self, x):
        """Repeat the first frame, then fold cubic tiles into channels."""
        size = self.patch_size
        first, rest = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([first.repeat_interleave(size, dim=2), rest], dim=2)
        patched = rearrange(
            x,
            "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
            p1=size,
            p2=size,
            p3=size,
        )
        return patched.contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class UnPatcher(torch.nn.Module):
    """Turn patches back into image tensors using torch operations only.

    Supports the same two patch methods as the patcher: ``"haar"`` and
    ``"rearrange"``. The upstream notes describe the output as matching a
    separate ``Unpatching`` class, which is not part of this module.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method

        taps = _WAVELETS[patch_method]
        self.register_buffer("wavelets", taps, persistent=_PERSISTENT)

        # Number of 2x reconstruction levels: floor(log2(patch_size)).
        num_levels = int(torch.log2(torch.tensor(self.patch_size)).item())
        self.range = range(num_levels)

        # Tap indices [0, 1, ...]; used to alternate the sign of the taps.
        self.register_buffer(
            "_arange", torch.arange(taps.shape[0]), persistent=_PERSISTENT
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Dispatch to the implementation selected by ``patch_method``."""
        if self.patch_method == "haar":
            return self._ihaar(x)
        if self.patch_method == "rearrange":
            return self._iarrange(x)
        raise ValueError("Unknown patch method: " + self.patch_method)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        """Undo one level of the 2D transform; channels shrink 4x.

        ``wavelet`` and ``mode`` are accepted but not read. The input is
        split along dim 1 into LL, LH, HL and HH bands; the result is doubled
        when ``rescale`` is true.
        """
        dtype = x.dtype
        taps = self.wavelets.to(device=x.device)
        n = taps.shape[0]

        g = x.shape[1] // 4
        signs = (-1) ** self._arange.to(device=x.device)
        low = taps.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        high = (taps * signs).reshape(1, 1, -1).repeat(g, 1, 1)
        high = high.to(dtype=dtype)
        low = low.to(dtype=dtype)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        # Upsample along dim 2 first, merging each low/high pair ...
        along_prev = dict(groups=g, stride=(2, 1), padding=(n - 2, 0))
        low_prev, high_prev = low.unsqueeze(3), high.unsqueeze(3)
        yl = F.conv_transpose2d(xll, low_prev, **along_prev) + F.conv_transpose2d(
            xlh, high_prev, **along_prev
        )
        yh = F.conv_transpose2d(xhl, low_prev, **along_prev) + F.conv_transpose2d(
            xhh, high_prev, **along_prev
        )

        # ... then along dim 3.
        along_last = dict(groups=g, stride=(1, 2), padding=(0, n - 2))
        y = F.conv_transpose2d(yl, low.unsqueeze(2), **along_last) + F.conv_transpose2d(
            yh, high.unsqueeze(2), **along_last
        )

        return y * 2 if rescale else y

    def _ihaar(self, x):
        """Apply ``_idwt`` once per level in ``self.range``."""
        for _level in self.range:
            x = self._idwt(x, "haar", rescale=True)
        return x

    def _iarrange(self, x):
        """Unfold channels back into ``patch_size`` x ``patch_size`` tiles."""
        size = self.patch_size
        return rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=size,
            p2=size,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class UnPatcher3D(UnPatcher):
    """3D counterpart of ``UnPatcher`` for 5D video wavelet decompositions."""

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        """Undo one level of the 3D transform; channels shrink 8x.

        ``wavelet`` and ``mode`` are accepted but not read. The input is
        split along dim 1 into eight sub-bands ordered lll, llh, lhl, lhh,
        hll, hlh, hhl, hhh, which are merged pairwise axis by axis.
        """
        dtype = x.dtype
        taps = self.wavelets.to(device=x.device)

        g = x.shape[1] // 8  # channels per sub-band
        signs = (-1) ** self._arange.to(device=x.device)
        low = taps.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        high = (taps * signs).reshape(1, 1, -1).repeat(g, 1, 1)
        low = low.to(dtype=dtype)
        high = high.to(dtype=dtype)

        xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
        bands = [xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh]

        def merge(pairs, low_kernel, high_kernel, stride):
            # Combine neighbours (low, high) -> one band, halving the list.
            merged = []
            for i in range(0, len(pairs), 2):
                upsampled = F.conv_transpose3d(
                    pairs[i], low_kernel, groups=g, stride=stride
                )
                upsampled = upsampled + F.conv_transpose3d(
                    pairs[i + 1], high_kernel, groups=g, stride=stride
                )
                merged.append(upsampled)
            return merged

        # Last axis (dim 4): 8 bands -> 4.
        bands = merge(
            bands,
            low.unsqueeze(2).unsqueeze(3),
            high.unsqueeze(2).unsqueeze(3),
            (1, 1, 2),
        )
        # Middle spatial axis (dim 3): 4 bands -> 2.
        bands = merge(
            bands,
            low.unsqueeze(2).unsqueeze(4),
            high.unsqueeze(2).unsqueeze(4),
            (1, 2, 1),
        )
        # Time axis (dim 2): 2 bands -> 1.
        (x,) = merge(
            bands,
            low.unsqueeze(3).unsqueeze(4),
            high.unsqueeze(3).unsqueeze(4),
            (2, 1, 1),
        )

        if rescale:
            x = x * (2 * torch.sqrt(torch.tensor(2.0)))
        return x

    def _ihaar(self, x):
        """Run every level, then drop the first ``patch_size - 1`` frames."""
        for _level in self.range:
            x = self._idwt(x, "haar", rescale=True)
        return x[:, :, self.patch_size - 1 :, ...]

    def _iarrange(self, x):
        """Unfold channels into cubic tiles, then drop the first ``patch_size - 1`` frames."""
        size = self.patch_size
        restored = rearrange(
            x,
            "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
            p1=size,
            p2=size,
            p3=size,
        )
        return restored[:, :, size - 1 :, ...]
|
120
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
120
comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Shared utilities for the networks module."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import pack, rearrange, unpack
|
||||||
|
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Fold the time axis of a ``b c t h w`` tensor into the batch axis.

    Returns the ``(b t) c h w`` tensor together with the original batch size,
    which ``batch2time`` needs to undo the fold.
    """
    folded = rearrange(x, "b c t h w -> (b t) c h w")
    return folded, x.shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Unfold a ``(b t) c h w`` tensor back to ``b c t h w`` with ``b = batch_size``."""
    unfolded = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
    return unfolded
|
||||||
|
|
||||||
|
|
||||||
|
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """Fold both spatial axes of a ``b c t h w`` tensor into the batch axis.

    Returns:
        A 3-tuple ``(folded, batch_size, height)`` where ``folded`` has layout
        ``(b h w) c t``, ``batch_size`` is ``x.shape[0]`` and ``height`` is
        ``x.shape[-2]``. ``batch2space`` takes the last two values to undo
        the fold.
    """
    # The annotation previously declared a 2-tuple although three values are returned.
    batch_size, height = x.shape[0], x.shape[-2]
    return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
|
||||||
|
|
||||||
|
|
||||||
|
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
    """Unfold a ``(b h w) c t`` tensor back to ``b c t h w``.

    ``b`` and ``h`` are fixed to ``batch_size`` and ``height``; ``w`` is
    inferred by einops from the remaining size.
    """
    unfolded = rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
    return unfolded
|
||||||
|
|
||||||
|
|
||||||
|
def cast_tuple(t: Any, length: int = 1) -> Any:
    """Return ``t`` unchanged if it is a tuple, else ``t`` repeated ``length`` times as a tuple."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
|
||||||
|
|
||||||
|
|
||||||
|
def replication_pad(x):
    """Prepend a copy of the first slice along dim 2, growing that axis by one."""
    first = x[:, :, :1, ...]
    return torch.cat([first, x], dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
def divisible_by(num: int, den: int) -> bool:
    """Return True when ``num`` leaves no remainder on division by ``den``."""
    remainder = num % den
    return remainder == 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_odd(n: int) -> bool:
    """Return True when ``n`` is not divisible by 2."""
    return not (n % 2 == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def nonlinearity(x):
    """Return ``x`` multiplied elementwise by ``sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, num_groups=32):
    """Build an affine ``ops.GroupNorm`` over ``in_channels`` channels with ``eps=1e-6``."""
    layer = ops.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return layer
|
||||||
|
|
||||||
|
|
||||||
|
class CausalNormalize(torch.nn.Module):
    """GroupNorm wrapper that, for ``num_groups == 1``, normalizes each frame separately."""

    def __init__(self, in_channels, num_groups=1):
        super().__init__()
        self.norm = ops.GroupNorm(
            num_groups=num_groups,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )
        self.num_groups = num_groups

    def forward(self, x):
        # With num_groups != 1 the norm runs over the whole spatio-temporal
        # tensor; this path is kept for backward compatibility only. New
        # models should use num_groups=1, otherwise causality is not guaranteed.
        if self.num_groups != 1:
            return self.norm(x)

        # Fold time into batch so the norm sees one frame at a time.
        frames, batch_size = time2batch(x)
        return batch2time(self.norm(frames), batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def exists(v):
    """Return True for any value other than ``None``."""
    return not (v is None)
|
||||||
|
|
||||||
|
|
||||||
|
def default(*args):
    """Return the first argument that is not ``None``, or ``None`` if there is none."""
    return next((candidate for candidate in args if candidate is not None), None)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_one(t, pattern):
    """Call einops ``pack`` on the single tensor ``t`` and return its result as-is."""
    singleton = [t]
    return pack(singleton, pattern)
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_one(t, ps, pattern):
    """Call einops ``unpack`` and return only the first unpacked tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
||||||
|
|
||||||
|
|
||||||
|
def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    # The detached correction carries the rounded value forward while the
    # gradient flows through the plain ``z`` term only.
    correction = (z.round() - z).detach()
    return z + correction
|
||||||
|
|
||||||
|
|
||||||
|
def log(t, eps=1e-5):
    """Natural log of ``t`` after clamping it from below at ``eps``."""
    clamped = t.clamp(min=eps)
    return clamped.log()
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(prob):
    """Sum ``-p * log(p)`` over the last axis, with ``p`` clamped at 1e-5 inside the log."""
    log_prob = prob.clamp(min=1e-5).log()
    return (-prob * log_prob).sum(dim=-1)
|
510
comfy/ldm/cosmos/model.py
Normal file
510
comfy/ldm/cosmos/model.py
Normal file
@ -0,0 +1,510 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||||
|
|
||||||
|
from .blocks import (
|
||||||
|
FinalLayer,
|
||||||
|
GeneralDITTransformerBlock,
|
||||||
|
PatchEmbed,
|
||||||
|
TimestepEmbedding,
|
||||||
|
Timesteps,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
|
||||||
|
|
||||||
|
|
||||||
|
class DataType(Enum):
    """Closed set of input kinds, valued by their lowercase names."""

    IMAGE = "image"
    VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralDIT(nn.Module):
|
||||||
|
"""
|
||||||
|
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_img_h (int): Maximum height of the input images.
|
||||||
|
max_img_w (int): Maximum width of the input images.
|
||||||
|
max_frames (int): Maximum number of frames in the video sequence.
|
||||||
|
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||||
|
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||||
|
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||||
|
block_config (str): Configuration of the transformer block. See Notes for supported block types.
|
||||||
|
model_channels (int): Base number of channels used throughout the model.
|
||||||
|
num_blocks (int): Number of transformer blocks.
|
||||||
|
num_heads (int): Number of heads in the multi-head attention layers.
|
||||||
|
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||||
|
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
|
||||||
|
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||||
|
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
|
||||||
|
pos_emb_cls (str): Type of positional embeddings.
|
||||||
|
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||||
|
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||||
|
affline_emb_norm (bool): Whether to normalize affine embeddings.
|
||||||
|
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||||
|
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||||
|
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||||
|
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||||
|
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||||
|
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||||
|
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
|
||||||
|
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||||
|
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||||
|
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Supported block types in block_config:
|
||||||
|
* cross_attn, ca: Cross attention
|
||||||
|
* full_attn: Full attention on all flattened tokens
|
||||||
|
* mlp, ff: Feed forward block
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_img_h: int,
|
||||||
|
max_img_w: int,
|
||||||
|
max_frames: int,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
patch_spatial: tuple,
|
||||||
|
patch_temporal: int,
|
||||||
|
concat_padding_mask: bool = True,
|
||||||
|
# attention settings
|
||||||
|
block_config: str = "FA-CA-MLP",
|
||||||
|
model_channels: int = 768,
|
||||||
|
num_blocks: int = 10,
|
||||||
|
num_heads: int = 16,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
block_x_format: str = "BTHWD",
|
||||||
|
# cross attention settings
|
||||||
|
crossattn_emb_channels: int = 1024,
|
||||||
|
use_cross_attn_mask: bool = False,
|
||||||
|
# positional embedding settings
|
||||||
|
pos_emb_cls: str = "sincos",
|
||||||
|
pos_emb_learnable: bool = False,
|
||||||
|
pos_emb_interpolation: str = "crop",
|
||||||
|
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
|
||||||
|
use_adaln_lora: bool = False,
|
||||||
|
adaln_lora_dim: int = 256,
|
||||||
|
rope_h_extrapolation_ratio: float = 1.0,
|
||||||
|
rope_w_extrapolation_ratio: float = 1.0,
|
||||||
|
rope_t_extrapolation_ratio: float = 1.0,
|
||||||
|
extra_per_block_abs_pos_emb: bool = False,
|
||||||
|
extra_per_block_abs_pos_emb_type: str = "sincos",
|
||||||
|
extra_h_extrapolation_ratio: float = 1.0,
|
||||||
|
extra_w_extrapolation_ratio: float = 1.0,
|
||||||
|
extra_t_extrapolation_ratio: float = 1.0,
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.max_img_h = max_img_h
|
||||||
|
self.max_img_w = max_img_w
|
||||||
|
self.max_frames = max_frames
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.patch_spatial = patch_spatial
|
||||||
|
self.patch_temporal = patch_temporal
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.model_channels = model_channels
|
||||||
|
self.use_cross_attn_mask = use_cross_attn_mask
|
||||||
|
self.concat_padding_mask = concat_padding_mask
|
||||||
|
# positional embedding settings
|
||||||
|
self.pos_emb_cls = pos_emb_cls
|
||||||
|
self.pos_emb_learnable = pos_emb_learnable
|
||||||
|
self.pos_emb_interpolation = pos_emb_interpolation
|
||||||
|
self.affline_emb_norm = affline_emb_norm
|
||||||
|
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||||
|
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||||
|
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||||
|
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||||
|
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
|
||||||
|
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||||
|
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||||
|
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||||
|
self.dtype = dtype
|
||||||
|
weight_args = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
spatial_patch_size=patch_spatial,
|
||||||
|
temporal_patch_size=patch_temporal,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=model_channels,
|
||||||
|
bias=False,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.build_pos_embed(device=device)
|
||||||
|
self.block_x_format = block_x_format
|
||||||
|
self.use_adaln_lora = use_adaln_lora
|
||||||
|
self.adaln_lora_dim = adaln_lora_dim
|
||||||
|
self.t_embedder = nn.ModuleList(
|
||||||
|
[Timesteps(model_channels),
|
||||||
|
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleDict()
|
||||||
|
|
||||||
|
for idx in range(num_blocks):
|
||||||
|
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
|
||||||
|
x_dim=model_channels,
|
||||||
|
context_dim=crossattn_emb_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
block_config=block_config,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
x_format=self.block_x_format,
|
||||||
|
use_adaln_lora=use_adaln_lora,
|
||||||
|
adaln_lora_dim=adaln_lora_dim,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.affline_emb_norm:
|
||||||
|
logging.debug("Building affine embedding normalization layer")
|
||||||
|
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
else:
|
||||||
|
self.affline_norm = nn.Identity()
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(
|
||||||
|
hidden_size=self.model_channels,
|
||||||
|
spatial_patch_size=self.patch_spatial,
|
||||||
|
temporal_patch_size=self.patch_temporal,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
use_adaln_lora=self.use_adaln_lora,
|
||||||
|
adaln_lora_dim=self.adaln_lora_dim,
|
||||||
|
weight_args=weight_args,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_pos_embed(self, device=None):
    """Create the positional embedders used by the model.

    Always sets ``self.pos_embedder``; additionally sets
    ``self.extra_pos_embedder`` when ``self.extra_per_block_abs_pos_emb``
    is truthy.

    Args:
        device: Forwarded to the embedder constructors.

    Raises:
        ValueError: If ``self.pos_emb_cls`` is not ``"rope3d"``.
    """
    if self.pos_emb_cls != "rope3d":
        raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
    embedder_cls = VideoRopePosition3DEmb

    logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {embedder_cls}")

    # One shared argument set; each embedder picks the keywords it needs and
    # discards the rest through its own **kwargs.
    embedder_args = {
        "model_channels": self.model_channels,
        "len_h": self.max_img_h // self.patch_spatial,
        "len_w": self.max_img_w // self.patch_spatial,
        "len_t": self.max_frames // self.patch_temporal,
        "is_learnable": self.pos_emb_learnable,
        "interpolation": self.pos_emb_interpolation,
        "head_dim": self.model_channels // self.num_heads,
        "h_extrapolation_ratio": self.rope_h_extrapolation_ratio,
        "w_extrapolation_ratio": self.rope_w_extrapolation_ratio,
        "t_extrapolation_ratio": self.rope_t_extrapolation_ratio,
        "device": device,
    }
    self.pos_embedder = embedder_cls(**embedder_args)

    if not self.extra_per_block_abs_pos_emb:
        return

    assert self.extra_per_block_abs_pos_emb_type in [
        "learnable",
    ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"

    # The extra embedder reuses the arguments above with its own
    # extrapolation ratios.
    embedder_args.update(
        h_extrapolation_ratio=self.extra_h_extrapolation_ratio,
        w_extrapolation_ratio=self.extra_w_extrapolation_ratio,
        t_extrapolation_ratio=self.extra_t_extrapolation_ratio,
        device=device,
    )
    self.extra_pos_embedder = LearnablePosEmbAxis(**embedder_args)
|
||||||
|
|
||||||
|
def prepare_embedded_sequence(
    self,
    x_B_C_T_H_W: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    latent_condition: Optional[torch.Tensor] = None,
    latent_condition_sigma: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Patch-embed the input video and compute its positional embeddings.

    Args:
        x_B_C_T_H_W: Input video tensor laid out as (B, C, T, H, W).
        fps: Forwarded unchanged to the positional embedders (not passed
            in the non-rope, non-"fps_aware" branch).
        padding_mask: Only used when ``self.concat_padding_mask`` is set.
            It is resized to the input's (H, W) with nearest-neighbour
            interpolation; when None an all-zero mask is used instead.
        latent_condition: Accepted for interface compatibility; unused here.
        latent_condition_sigma: Accepted for interface compatibility;
            unused here.

    Returns:
        A 3-tuple ``(x_B_T_H_W_D, rope_emb, extra_pos_emb)``:

        - ``x_B_T_H_W_D``: output of ``self.x_embedder``. When
          ``self.pos_emb_cls`` does not contain "rope", the output of
          ``self.pos_embedder`` has already been added to it.
        - ``rope_emb``: output of ``self.pos_embedder`` when
          ``self.pos_emb_cls`` contains "rope" (case-insensitive),
          otherwise None.
        - ``extra_pos_emb``: output of ``self.extra_pos_embedder`` when
          ``self.extra_per_block_abs_pos_emb`` is set, otherwise None.
    """
    if self.concat_padding_mask:
        if padding_mask is not None:
            padding_mask = transforms.functional.resize(
                padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
            )
        else:
            padding_mask = torch.zeros(
                (x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]),
                dtype=x_B_C_T_H_W.dtype,
                device=x_B_C_T_H_W.device,
            )

        # Broadcast the per-image mask over time and append it as one more
        # input channel.
        x_B_C_T_H_W = torch.cat(
            [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
        )
    x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

    if self.extra_per_block_abs_pos_emb:
        extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
    else:
        extra_pos_emb = None

    # Rotary embeddings are returned separately instead of being added to
    # the tokens.
    if "rope" in self.pos_emb_cls.lower():
        return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb

    if "fps_aware" in self.pos_emb_cls:
        x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]
    else:
        x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

    return x_B_T_H_W_D, None, extra_pos_emb
|
||||||
|
|
||||||
|
def decoder_head(
    self,
    x_B_T_H_W_D: torch.Tensor,
    emb_B_D: torch.Tensor,
    crossattn_emb: torch.Tensor,
    origin_shape: Tuple[int, int, int, int, int],  # [B, C, T, H, W]
    crossattn_mask: Optional[torch.Tensor] = None,
    adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply the final layer and un-patchify tokens back to a video tensor.

    Args:
        x_B_T_H_W_D: Token tensor laid out as (B, T, H, W, D).
        emb_B_D: Conditioning embedding passed to ``self.final_layer``.
        crossattn_emb: Unused; accepted for interface compatibility.
        origin_shape: Shape (B, C, T, H, W) of the tensor before patchify.
        crossattn_mask: Unused; accepted for interface compatibility.
        adaln_lora_B_3D: Optional tensor passed to ``self.final_layer``.

    Returns:
        Tensor laid out as (B, C, T, H, W) per the final rearrange pattern.
    """
    del crossattn_emb, crossattn_mask
    B, _, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape

    # Patch-grid sizes, computed once with explicit grouping so the view and
    # the rearrange below agree on them.
    T_patches = T_before_patchify // self.patch_temporal
    H_patches = H_before_patchify // self.patch_spatial
    W_patches = W_before_patchify // self.patch_spatial

    x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
    x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
    # This is to ensure x_BT_HW_D has the correct shape because
    # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
    x_BT_HW_D = x_BT_HW_D.view(
        B * T_patches,
        H_patches * W_patches,
        -1,
    )
    x_B_D_T_H_W = rearrange(
        x_BT_HW_D,
        "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
        p1=self.patch_spatial,
        p2=self.patch_spatial,
        H=H_patches,
        W=W_patches,
        t=self.patch_temporal,
        B=B,
    )
    return x_B_D_T_H_W
|
||||||
|
|
||||||
|
def forward_before_blocks(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    crossattn_emb: torch.Tensor,
    crossattn_mask: Optional[torch.Tensor] = None,
    fps: Optional[torch.Tensor] = None,
    image_size: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    scalar_feature: Optional[torch.Tensor] = None,
    data_type: Optional[DataType] = DataType.VIDEO,
    latent_condition: Optional[torch.Tensor] = None,
    latent_condition_sigma: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    """Prepare every input the transformer blocks need.

    Args:
        x: (B, C, T, H, W) tensor of spatial-temp inputs
        timesteps: (B, ) tensor of timesteps
        crossattn_emb: (B, N, D) tensor of cross-attention embeddings
        crossattn_mask: (B, N) tensor of cross-attention masks
        fps: Forwarded to ``prepare_embedded_sequence``.
        image_size: Unused here.
        padding_mask: Forwarded to ``prepare_embedded_sequence``.
        scalar_feature: Must be None; anything else is not implemented.
        data_type: Must be a ``DataType`` instance.
        latent_condition: Forwarded to ``prepare_embedded_sequence``.
        latent_condition_sigma: Forwarded to ``prepare_embedded_sequence``.
        **kwargs: Ignored.

    Returns:
        dict with keys ``x``, ``affline_emb_B_D``, ``crossattn_emb``,
        ``crossattn_mask``, ``rope_emb_L_1_1_D``, ``adaln_lora_B_3D``,
        ``original_shape`` and ``extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D``.

    Raises:
        NotImplementedError: If ``scalar_feature`` is given.
        ValueError: If the first block's ``x_format`` is neither "THWBD"
            nor "BTHWD".
    """
    del kwargs
    assert isinstance(
        data_type, DataType
    ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
    original_shape = x.shape
    x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
        x,
        fps=fps,
        padding_mask=padding_mask,
        latent_condition=latent_condition,
        latent_condition_sigma=latent_condition_sigma,
    )

    # t_embedder[0] builds the raw timestep features, t_embedder[1] maps
    # them to the embedding plus the AdaLN-LoRA tensor.
    timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
    affline_emb_B_D = timesteps_B_D

    if scalar_feature is not None:
        raise NotImplementedError("Scalar feature is not implemented yet.")

    affline_emb_B_D = self.affline_norm(affline_emb_B_D)

    if self.use_cross_attn_mask:
        # Non-float masks (1 = keep, 0 = drop) become additive biases:
        # 0 for kept positions, -finfo.max for dropped ones.
        if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
            crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
        crossattn_mask = crossattn_mask[:, None, None, :]  # .to(dtype=torch.bool)  # [B, 1, 1, length]
    else:
        crossattn_mask = None

    # ModuleDict is keyed by strings ("block0", ...); integer indexing
    # raises KeyError, so read the format once by name.
    x_format = self.blocks["block0"].x_format
    if x_format == "THWBD":
        x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
            )
        crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")

        # NOTE(review): the mask was expanded to 4-D above, which the 2-D
        # pattern below does not match -- confirm the intended layout
        # before enabling use_cross_attn_mask with THWBD blocks.
        if crossattn_mask is not None:
            crossattn_mask = rearrange(crossattn_mask, "B M -> M B")

    elif x_format == "BTHWD":
        x = x_B_T_H_W_D
    else:
        raise ValueError(f"Unknown x_format {x_format}")
    output = {
        "x": x,
        "affline_emb_B_D": affline_emb_B_D,
        "crossattn_emb": crossattn_emb,
        "crossattn_mask": crossattn_mask,
        "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
        "adaln_lora_B_3D": adaln_lora_B_3D,
        "original_shape": original_shape,
        "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
    }
    return output
|
||||||
|
|
||||||
|
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    fps: Optional[torch.Tensor] = None,
    image_size: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    scalar_feature: Optional[torch.Tensor] = None,
    data_type: Optional[DataType] = DataType.VIDEO,
    latent_condition: Optional[torch.Tensor] = None,
    latent_condition_sigma: Optional[torch.Tensor] = None,
    condition_video_augment_sigma: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Run the full model: embed, apply every block, decode.

    Args:
        x: (B, C, T, H, W) tensor of spatial-temp inputs
        timesteps: (B, ) tensor of timesteps
        context: (B, N, D) tensor of cross-attention embeddings
        attention_mask: (B, N) tensor of cross-attention masks
        condition_video_augment_sigma: (B,) used in lvg (long video
            generation); forwarded to ``forward_before_blocks``.

    The remaining arguments are forwarded to ``forward_before_blocks``.

    Returns:
        The output of ``decoder_head``.
    """
    crossattn_emb = context
    crossattn_mask = attention_mask

    inputs = self.forward_before_blocks(
        x=x,
        timesteps=timesteps,
        crossattn_emb=crossattn_emb,
        crossattn_mask=crossattn_mask,
        fps=fps,
        image_size=image_size,
        padding_mask=padding_mask,
        scalar_feature=scalar_feature,
        data_type=data_type,
        latent_condition=latent_condition,
        latent_condition_sigma=latent_condition_sigma,
        condition_video_augment_sigma=condition_video_augment_sigma,
        **kwargs,
    )
    x = inputs["x"]
    affline_emb_B_D = inputs["affline_emb_B_D"]
    crossattn_emb = inputs["crossattn_emb"]
    crossattn_mask = inputs["crossattn_mask"]
    rope_emb_L_1_1_D = inputs["rope_emb_L_1_1_D"]
    adaln_lora_B_3D = inputs["adaln_lora_B_3D"]
    original_shape = inputs["original_shape"]

    # The extra embedding is None when extra_per_block_abs_pos_emb is off,
    # so cast it only after the None check.
    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]
    if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.to(x.dtype)
        assert (
            x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
        ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

    # ModuleDict is keyed by strings; self.blocks[0] would raise KeyError.
    x_format = self.blocks["block0"].x_format
    for block in self.blocks.values():
        assert (
            x_format == block.x_format
        ), f"First block has x_format {x_format}, got {block.x_format}"

        x = block(
            x,
            affline_emb_B_D,
            crossattn_emb,
            crossattn_mask,
            rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            adaln_lora_B_3D=adaln_lora_B_3D,
            extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        )

    # Only the "THWBD" layout needs converting back; in "BTHWD" the blocks
    # already operate on (B, T, H, W, D).
    if x_format == "THWBD":
        x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
    else:
        x_B_T_H_W_D = x

    x_B_D_T_H_W = self.decoder_head(
        x_B_T_H_W_D=x_B_T_H_W_D,
        emb_B_D=affline_emb_B_D,
        crossattn_emb=None,
        origin_shape=original_shape,
        crossattn_mask=None,
        adaln_lora_B_3D=adaln_lora_B_3D,
    )

    return x_B_D_T_H_W
|
207
comfy/ldm/cosmos/position_embedding.py
Normal file
207
comfy/ldm/cosmos/position_embedding.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """Divide ``x`` by its root-mean-square magnitude over ``dim``.

    Args:
        x (torch.Tensor): Tensor to normalize.
        dim (list, optional): Dimensions reduced when computing the norm.
            Defaults to every dimension except the first.
        eps (float, optional): Constant added to the scaled norm before the
            division.

    Returns:
        torch.Tensor: ``x / (eps + norm * sqrt(norm.numel() / x.numel()))``,
        with the divisor cast to ``x.dtype``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # The norm is always accumulated in float32, whatever x's dtype is.
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # sqrt(#norm entries / #input entries) == 1 / sqrt(#reduced elements),
    # which turns the L2 norm into a root-mean-square value.
    rms_scale = math.sqrt(magnitude.numel() / x.numel())
    divisor = torch.add(eps, magnitude, alpha=rms_scale)
    return x / divisor.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPositionEmb(nn.Module):
    """Base class for positional embeddings derived from a token tensor's shape.

    Subclasses implement ``generate_embeddings``; ``forward`` only extracts
    the input shape and delegates to it.
    """

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None) -> torch.Tensor:
        """Return the embeddings for ``x_B_T_H_W_C``.

        Only the shape of ``x_B_T_H_W_C`` is used; the work is delegated to
        ``generate_embeddings``.

        Args:
            x_B_T_H_W_C: Tensor whose ``.shape`` is passed on.
            fps: Optional frame-rate information, passed on unchanged.
                Defaults to None.
            device: Passed on unchanged.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        return self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None):
        """Build the embeddings for the given shape; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class VideoRopePosition3DEmb(VideoPositionEmb):
    """3D rotary position embedding with separate T, H and W frequency bands."""

    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        """Set up position indices, per-axis frequency ranges and NTK factors.

        Args:
            head_dim: Per-head channel count, split between the T, H and W bands.
            len_h: Maximum supported height in tokens.
            len_w: Maximum supported width in tokens.
            len_t: Temporal length, used with ``len_h``/``len_w`` to size
                the position buffer.
            base_fps: Reference frame rate used to rescale temporal positions.
            h_extrapolation_ratio: Ratio used to derive the height NTK factor.
            w_extrapolation_ratio: Ratio used to derive the width NTK factor.
            t_extrapolation_ratio: Ratio used to derive the time NTK factor.
            device: Device on which the buffers are created.
        """
        del kwargs
        super().__init__()
        # One shared index buffer, long enough for the largest axis.
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        # H and W each get an even share of roughly a third of the
        # channels; T takes whatever remains.
        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        # NOTE(review): these divide by (dim - 2), which is zero when a band
        # has exactly two channels -- confirm such head sizes never occur.
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
    ):
        """
        Generate rotary embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second, as a number or a tensor.
                When None, temporal positions are used unscaled. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
            device: Device the buffers are moved to before use.

        Returns:
            torch.Tensor: float32 tensor of shape (T * H * W, F, 2, 2), where F is the number
            of temporal, height and width frequencies concatenated in that order. Each
            trailing 2x2 block is [[cos, -sin], [sin, cos]] of the corresponding angle.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        seq_t = self.seq[:T].to(device=device)
        if fps is None:  # image case
            half_emb_t = torch.outer(seq_t, temporal_freqs)
        else:
            if torch.is_tensor(fps):
                # The checks above guarantee a single effective frame rate
                # (uniform fps, B == 1, or T == 1 where the only position
                # is 0). Keep one element so a (B,) tensor cannot fail to
                # broadcast against the (T,) positions.
                fps = fps.reshape(-1)[:1].to(device=seq_t.device, dtype=seq_t.dtype)
            half_emb_t = torch.outer(seq_t / fps * self.base_fps, temporal_freqs)

        # Per angle, the four entries of a 2x2 rotation matrix.
        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ]
            , dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
||||||
|
|
||||||
|
|
||||||
|
class LearnablePosEmbAxis(VideoPositionEmb):
    """Learnable absolute position embedding with one table per axis (T, H, W)."""

    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): we currently only support "crop"; ideally, when we need extrapolation capacity, we should adjust the frequency or use other more advanced methods. They are not implemented yet.
            model_channels (int): Width of each embedding row.
            len_h (int): Number of rows in the height table.
            len_w (int): Number of rows in the width table.
            len_t (int): Number of rows in the time table.
            device: Device on which the parameters are allocated.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        # Allocated uninitialised (torch.empty); no initialisation happens in this class.
        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None) -> torch.Tensor:
        """Return the normalized sum of the cropped per-axis embeddings.

        Args:
            B_T_H_W_C: Input size (Batch, Time, Height, Width, Channels).
            fps: Unused; accepted for interface compatibility.
            device: Device the parameter slices are moved to.

        Returns:
            Tensor of shape (B, T, H, W, model_channels), normalized over
            the last dimension.

        Raises:
            ValueError: If ``self.interpolation`` is not "crop".
        """
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            # "crop": take the leading rows of each table.
            emb_h_H = self.pos_emb_h[:H].to(device=device)
            emb_w_W = self.pos_emb_w[:W].to(device=device)
            emb_t_T = self.pos_emb_t[:T].to(device=device)
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)
|
124
comfy/ldm/cosmos/vae.py
Normal file
124
comfy/ldm/cosmos/vae.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from .cosmos_tokenizer.layers3d import (
|
||||||
|
EncoderFactorized,
|
||||||
|
DecoderFactorized,
|
||||||
|
CausalConv3d,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityDistribution(torch.nn.Module):
    """Pass-through "distribution" used for the plain auto-encoder formulation."""

    def forward(self, parameters):
        """Return ``parameters`` unchanged plus two placeholder tensors.

        The second element is a pair of one-element zero tensors, taking the
        place of the ``(mean, logvar)`` pair that ``GaussianDistribution``
        returns.
        """
        placeholder_a = torch.tensor([0.0])
        placeholder_b = torch.tensor([0.0])
        return parameters, (placeholder_a, placeholder_b)
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian whose parameters arrive as ``[mean, logvar]`` stacked on dim 1."""

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        """Store the clamp range applied to the log-variance."""
        super().__init__()
        self.min_logvar, self.max_logvar = min_logvar, max_logvar

    def sample(self, mean, logvar):
        """Draw ``mean + exp(0.5 * logvar) * noise`` with standard normal noise."""
        noise_scale = torch.exp(0.5 * logvar)
        noise = torch.randn_like(mean)
        return mean + noise_scale * noise

    def forward(self, parameters):
        """Split ``parameters`` in two along dim 1 and sample.

        Returns:
            ``(sample, (mean, logvar))`` where ``logvar`` has been clamped
            to ``[min_logvar, max_logvar]``.
        """
        mean, raw_logvar = torch.chunk(parameters, 2, dim=1)
        logvar = torch.clamp(raw_logvar, self.min_logvar, self.max_logvar)
        drawn = self.sample(mean, logvar)
        return drawn, (mean, logvar)
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuousFormulation(Enum):
    """Maps a formulation name to the distribution class that implements it."""

    # Variational auto-encoder: latents are sampled from a Gaussian.
    VAE = GaussianDistribution
    # Plain auto-encoder: encoder output is passed through unchanged.
    AE = IdentityDistribution
|
||||||
|
|
||||||
|
|
||||||
|
class CausalContinuousVideoTokenizer(nn.Module):
    """Causal video tokenizer: factorized encoder/decoder around a continuous latent."""

    def __init__(
        self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
    ) -> None:
        """Build the encoder, decoder, 1x1x1 projection convs and latent statistics.

        Args:
            z_channels: Channel count fed to the decoder.
            z_factor: Multiplier applied to the encoder output channels and
                to the ``quant_conv`` output channels.
            latent_channels: Channel count of the latent.
            **kwargs: Forwarded to ``EncoderFactorized`` and
                ``DecoderFactorized``; ``name`` and ``temporal_compression``
                are also read here.
        """
        super().__init__()
        self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
        self.latent_channels = latent_channels
        self.sigma_data = 0.5

        encoder_out_channels = z_factor * z_channels
        self.encoder = EncoderFactorized(z_channels=encoder_out_channels, **kwargs)

        # Only the decoder sees this override: the encoder above was already
        # built from the unmodified kwargs.
        if kwargs.get("temporal_compression", 4) == 4:
            kwargs["channels_mult"] = [2, 4]
        self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs)

        self.quant_conv = CausalConv3d(
            encoder_out_channels,
            z_factor * latent_channels,
            kernel_size=1,
            padding=0,
        )
        self.post_quant_conv = CausalConv3d(
            latent_channels, z_channels, kernel_size=1, padding=0
        )

        # The formulation is fixed to the identity (auto-encoder) distribution.
        self.distribution = IdentityDistribution()

        # Counted before latent_mean / latent_std are registered below.
        num_parameters = sum(param.numel() for param in self.parameters())
        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.info(
            f"z_channels={z_channels}, latent_channels={self.latent_channels}."
        )

        latent_temporal_chunk = 16
        stats_len = self.latent_channels * latent_temporal_chunk
        self.latent_mean = nn.Parameter(torch.zeros([stats_len], dtype=torch.float32))
        self.latent_std = nn.Parameter(torch.ones([stats_len], dtype=torch.float32))

    def _latent_stats(self, z):
        """Return ``(mean, std)`` shaped to broadcast against ``z``.

        Each statistic is viewed as (channels, -1), cropped to ``z``'s
        temporal length, reshaped to (1, channels, -1, 1, 1) and moved to
        ``z``'s dtype and device.
        """
        channels, frames = z.shape[1], z.shape[2]

        def _shape_like_latent(stat):
            cropped = stat.view(channels, -1)[:, :frames]
            return cropped.reshape([1, channels, -1, 1, 1]).to(dtype=z.dtype, device=z.device)

        return _shape_like_latent(self.latent_mean), _shape_like_latent(self.latent_std)

    def encode(self, x):
        """Encode ``x`` and return the latent standardized and scaled by ``sigma_data``."""
        moments = self.quant_conv(self.encoder(x))
        z, _ = self.distribution(moments)
        mean, std = self._latent_stats(z)
        return ((z - mean) / std) * self.sigma_data

    def decode(self, z):
        """Undo the latent scaling and standardization, then decode."""
        mean, std = self._latent_stats(z)

        z = z / self.sigma_data
        z = z * std + mean
        z = self.post_quant_conv(z)
        return self.decoder(z)
|
||||||
|
|
@ -89,7 +89,7 @@ class FeedForward(nn.Module):
|
|||||||
def Normalize(in_channels, dtype=None, device=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -142,6 +142,13 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
sim = sim.softmax(dim=-1)
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||||
|
|
||||||
|
if skip_output_reshape:
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
)
|
||||||
|
else:
|
||||||
out = (
|
out = (
|
||||||
out.unsqueeze(0)
|
out.unsqueeze(0)
|
||||||
.reshape(b, heads, -1, dim_head)
|
.reshape(b, heads, -1, dim_head)
|
||||||
@ -151,7 +158,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states.to(dtype)
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
if skip_output_reshape:
|
||||||
|
hidden_states = hidden_states.unflatten(0, (-1, heads))
|
||||||
|
else:
|
||||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
attn_precision = get_attn_precision(attn_precision)
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -326,6 +335,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if skip_output_reshape:
|
||||||
|
r1 = (
|
||||||
|
r1.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
)
|
||||||
|
else:
|
||||||
r1 = (
|
r1 = (
|
||||||
r1.unsqueeze(0)
|
r1.unsqueeze(0)
|
||||||
.reshape(b, heads, -1, dim_head)
|
.reshape(b, heads, -1, dim_head)
|
||||||
@ -342,7 +357,7 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
b = q.shape[0]
|
b = q.shape[0]
|
||||||
dim_head = q.shape[-1]
|
dim_head = q.shape[-1]
|
||||||
# check to make sure xformers isn't broken
|
# check to make sure xformers isn't broken
|
||||||
@ -395,6 +410,9 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||||||
|
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||||
|
|
||||||
|
if skip_output_reshape:
|
||||||
|
out = out.permute(0, 2, 1, 3)
|
||||||
|
else:
|
||||||
out = (
|
out = (
|
||||||
out.reshape(b, -1, heads * dim_head)
|
out.reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
@ -408,7 +426,7 @@ else:
|
|||||||
SDP_BATCH_LIMIT = 2**31
|
SDP_BATCH_LIMIT = 2**31
|
||||||
|
|
||||||
|
|
||||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
else:
|
else:
|
||||||
@ -429,6 +447,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
|
|
||||||
if SDP_BATCH_LIMIT >= b:
|
if SDP_BATCH_LIMIT >= b:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
|
if not skip_output_reshape:
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
tensor_layout="HND"
|
tensor_layout="HND"
|
||||||
@ -473,9 +492,13 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
|
|
||||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||||
if tensor_layout == "HND":
|
if tensor_layout == "HND":
|
||||||
|
if not skip_output_reshape:
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if skip_output_reshape:
|
||||||
|
out = out.transpose(1, 2)
|
||||||
else:
|
else:
|
||||||
out = out.reshape(b, -1, heads * dim_head)
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
@ -33,6 +33,7 @@ import comfy.ldm.audio.embedders
|
|||||||
import comfy.ldm.flux.model
|
import comfy.ldm.flux.model
|
||||||
import comfy.ldm.lightricks.model
|
import comfy.ldm.lightricks.model
|
||||||
import comfy.ldm.hunyuan_video.model
|
import comfy.ldm.hunyuan_video.model
|
||||||
|
import comfy.ldm.cosmos.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@ -856,3 +857,19 @@ class HunyuanVideo(BaseModel):
|
|||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class CosmosVideo(BaseModel):
    """Cosmos video model wrapper whose diffusion network is ``GeneralDIT``."""

    def __init__(self, model_config, model_type=ModelType.EDM, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.cosmos.model.GeneralDIT,
        )

    def extra_conds(self, **kwargs):
        """Return the base conditioning extended with Cosmos-specific entries.

        ``attention_mask`` and ``c_crossattn`` are added only when the
        matching keyword argument is present and not ``None``; ``fps`` is
        always added, taken from ``frame_rate`` (``None`` when absent).
        """
        conds = super().extra_conds(**kwargs)

        # (keyword argument, conditioning key) pairs, in the order the
        # entries are inserted into the result.
        regular_entries = (
            ("attention_mask", "attention_mask"),
            ("cross_attn", "c_crossattn"),
        )
        for kwarg_name, cond_key in regular_entries:
            value = kwargs.get(kwarg_name, None)
            if value is not None:
                conds[cond_key] = comfy.conds.CONDRegular(value)

        frame_rate = kwargs.get("frame_rate", None)
        conds['fps'] = comfy.conds.CONDConstant(frame_rate)
        return conds
|
||||||
|
@ -239,6 +239,50 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
dit_config["micro_condition"] = False
|
dit_config["micro_condition"] = False
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "cosmos"
|
||||||
|
dit_config["max_img_h"] = 240
|
||||||
|
dit_config["max_img_w"] = 240
|
||||||
|
dit_config["max_frames"] = 128
|
||||||
|
dit_config["in_channels"] = 16
|
||||||
|
dit_config["out_channels"] = 16
|
||||||
|
dit_config["patch_spatial"] = 2
|
||||||
|
dit_config["patch_temporal"] = 1
|
||||||
|
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
|
||||||
|
dit_config["block_config"] = "FA-CA-MLP"
|
||||||
|
dit_config["concat_padding_mask"] = True
|
||||||
|
dit_config["pos_emb_cls"] = "rope3d"
|
||||||
|
dit_config["pos_emb_learnable"] = False
|
||||||
|
dit_config["pos_emb_interpolation"] = "crop"
|
||||||
|
dit_config["block_x_format"] = "THWBD"
|
||||||
|
dit_config["affline_emb_norm"] = True
|
||||||
|
dit_config["use_adaln_lora"] = True
|
||||||
|
dit_config["adaln_lora_dim"] = 256
|
||||||
|
|
||||||
|
if dit_config["model_channels"] == 4096:
|
||||||
|
# 7B
|
||||||
|
dit_config["num_blocks"] = 28
|
||||||
|
dit_config["num_heads"] = 32
|
||||||
|
dit_config["extra_per_block_abs_pos_emb"] = True
|
||||||
|
dit_config["rope_h_extrapolation_ratio"] = 1.0
|
||||||
|
dit_config["rope_w_extrapolation_ratio"] = 1.0
|
||||||
|
dit_config["rope_t_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||||
|
else: # 5120
|
||||||
|
# 14B
|
||||||
|
dit_config["num_blocks"] = 36
|
||||||
|
dit_config["num_heads"] = 40
|
||||||
|
dit_config["extra_per_block_abs_pos_emb"] = True
|
||||||
|
dit_config["rope_h_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["rope_w_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["rope_t_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_h_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_w_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_t_extrapolation_ratio"] = 2.0
|
||||||
|
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -393,6 +437,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
def unet_prefix_from_state_dict(state_dict):
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||||
"model.model.", #audio models
|
"model.model.", #audio models
|
||||||
|
"net.", #cosmos
|
||||||
]
|
]
|
||||||
counts = {k: 0 for k in candidates}
|
counts = {k: 0 for k in candidates}
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
|
25
comfy/sd.py
25
comfy/sd.py
@ -11,6 +11,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
|
|||||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||||
import comfy.ldm.genmo.vae.model
|
import comfy.ldm.genmo.vae.model
|
||||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||||
|
import comfy.ldm.cosmos.vae
|
||||||
import yaml
|
import yaml
|
||||||
import math
|
import math
|
||||||
|
|
||||||
@ -34,6 +35,7 @@ import comfy.text_encoders.long_clipl
|
|||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
import comfy.text_encoders.lt
|
import comfy.text_encoders.lt
|
||||||
import comfy.text_encoders.hunyuan_video
|
import comfy.text_encoders.hunyuan_video
|
||||||
|
import comfy.text_encoders.cosmos
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -376,6 +378,19 @@ class VAE:
|
|||||||
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
elif "decoder.unpatcher3d.wavelets" in sd:
|
||||||
|
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
|
||||||
|
self.upscale_index_formula = (8, 8, 8)
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
|
||||||
|
self.downscale_index_formula = (8, 8, 8)
|
||||||
|
self.latent_dim = 3
|
||||||
|
self.latent_channels = 16
|
||||||
|
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
|
||||||
|
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
|
||||||
|
#TODO: these values are a bit off because this is not a standard VAE
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||||
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
@ -641,6 +656,7 @@ class CLIPType(Enum):
|
|||||||
LTXV = 8
|
LTXV = 8
|
||||||
HUNYUAN_VIDEO = 9
|
HUNYUAN_VIDEO = 9
|
||||||
PIXART = 10
|
PIXART = 10
|
||||||
|
COSMOS = 11
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@ -658,6 +674,7 @@ class TEModel(Enum):
|
|||||||
T5_XL = 5
|
T5_XL = 5
|
||||||
T5_BASE = 6
|
T5_BASE = 6
|
||||||
LLAMA3_8 = 7
|
LLAMA3_8 = 7
|
||||||
|
T5_XXL_OLD = 8
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
@ -672,6 +689,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.T5_XXL
|
return TEModel.T5_XXL
|
||||||
elif weight.shape[-1] == 2048:
|
elif weight.shape[-1] == 2048:
|
||||||
return TEModel.T5_XL
|
return TEModel.T5_XL
|
||||||
|
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||||
|
return TEModel.T5_XXL_OLD
|
||||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||||
@ -681,9 +700,10 @@ def detect_te_model(sd):
|
|||||||
|
|
||||||
def t5xxl_detect(clip_data):
|
def t5xxl_detect(clip_data):
|
||||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||||
|
weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"
|
||||||
|
|
||||||
for sd in clip_data:
|
for sd in clip_data:
|
||||||
if weight_name in sd:
|
if weight_name in sd or weight_name_old in sd:
|
||||||
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
@ -740,6 +760,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
else: #CLIPType.MOCHI
|
else: #CLIPType.MOCHI
|
||||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||||
|
elif te_model == TEModel.T5_XXL_OLD:
|
||||||
|
clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
|
||||||
elif te_model == TEModel.T5_XL:
|
elif te_model == TEModel.T5_XL:
|
||||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||||
|
@ -14,6 +14,7 @@ import comfy.text_encoders.flux
|
|||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
import comfy.text_encoders.lt
|
import comfy.text_encoders.lt
|
||||||
import comfy.text_encoders.hunyuan_video
|
import comfy.text_encoders.hunyuan_video
|
||||||
|
import comfy.text_encoders.cosmos
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -823,6 +824,37 @@ class HunyuanVideo(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
|
class Cosmos(supported_models_base.BASE):
    """Supported-model entry for Cosmos (``image_model`` value ``"cosmos"``)."""

    unet_config = {"image_model": "cosmos"}

    sampling_settings = {
        "sigma_data": 0.5,
        "sigma_max": 80.0,
        "sigma_min": 0.002,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Cosmos1CV8x8x8

    memory_usage_factor = 2.4  #TODO

    supported_inference_dtypes = [
        torch.bfloat16,
        torch.float16,
        torch.float32,
    ]  #TODO

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``CosmosVideo`` model for this config on ``device``."""
        return model_base.CosmosVideo(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer / text-encoder pair for this model.

        T5 settings are detected from ``state_dict`` under the first text
        encoder prefix followed by ``t5xxl.transformer.``.
        """
        t5_prefix = "{}t5xxl.transformer.".format(self.text_encoder_key_prefix[0])
        detected = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, t5_prefix)
        tokenizer_cls = comfy.text_encoders.cosmos.CosmosT5Tokenizer
        encoder_cls = comfy.text_encoders.cosmos.te(**detected)
        return supported_models_base.ClipTarget(tokenizer_cls, encoder_cls)
|
||||||
|
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
42
comfy/text_encoders/cosmos.py
Normal file
42
comfy/text_encoders/cosmos.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.t5
|
||||||
|
import os
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLModel(sd1_clip.SDClipModel):
    """T5 text model configured from the bundled ``t5_old_config_xxl.json``."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        # The JSON config lives next to this module.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(module_dir, "t5_old_config_xxl.json")

        # Re-expose the T5-specific scaled-fp8 option under the generic
        # "scaled_fp8" key; copy first so the caller's dict is not modified.
        scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
        if scaled_fp8 is not None:
            model_options = model_options.copy()
            model_options["scaled_fp8"] = scaled_fp8

        # The single attention_mask flag drives all three mask-related options.
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=comfy.text_encoders.t5.T5,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            zero_out_masked=attention_mask,
            model_options=model_options,
        )
|
||||||
|
|
||||||
|
class CosmosT5XXL(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper registering ``T5XXLModel`` under the name ``t5xxl``."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="t5xxl",
            clip_model=T5XXLModel,
            model_options=model_options,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """``T5TokenizerFast``-based tokenizer loaded from the bundled ``t5_tokenizer`` directory."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # Tokenizer files live in a directory next to this module.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        tokenizer_dir = os.path.join(module_dir, "t5_tokenizer")
        super().__init__(
            tokenizer_dir,
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=1024,
            embedding_key='t5xxl',
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=512,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
    """``SD1Tokenizer`` wrapper registering ``T5XXLTokenizer`` under the clip name ``t5xxl``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            clip_name="t5xxl",
            tokenizer=T5XXLTokenizer,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Return a ``CosmosT5XXL`` subclass with the given defaults baked in.

    ``dtype_t5`` is used when the instance is created without an explicit
    ``dtype``; ``t5xxl_scaled_fp8`` is injected into ``model_options`` unless
    the caller already supplied that key.
    """
    class CosmosTEModel_(CosmosT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            caller_set_fp8 = "t5xxl_scaled_fp8" in model_options
            if not (t5xxl_scaled_fp8 is None or caller_set_fp8):
                # Copy before writing so the caller's dict is not modified.
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8

            effective_dtype = dtype_t5 if dtype is None else dtype
            super().__init__(device=device, dtype=effective_dtype, model_options=model_options)

    return CosmosTEModel_
|
22
comfy/text_encoders/t5_old_config_xxl.json
Normal file
22
comfy/text_encoders/t5_old_config_xxl.json
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"d_ff": 65536,
|
||||||
|
"d_kv": 128,
|
||||||
|
"d_model": 1024,
|
||||||
|
"decoder_start_token_id": 0,
|
||||||
|
"dropout_rate": 0.1,
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"dense_act_fn": "relu",
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"is_encoder_decoder": true,
|
||||||
|
"is_gated_act": false,
|
||||||
|
"layer_norm_epsilon": 1e-06,
|
||||||
|
"model_type": "t5",
|
||||||
|
"num_decoder_layers": 24,
|
||||||
|
"num_heads": 128,
|
||||||
|
"num_layers": 24,
|
||||||
|
"output_past": true,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"relative_attention_num_buckets": 32,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"vocab_size": 32128
|
||||||
|
}
|
23
comfy_extras/nodes_cosmos.py
Normal file
23
comfy_extras/nodes_cosmos.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import nodes
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
class EmptyCosmosLatentVideo:
    """Node that produces an all-zero latent video tensor with 16 channels."""

    @classmethod
    def INPUT_TYPES(s):
        max_res = nodes.MAX_RESOLUTION
        required = {
            "width": ("INT", {"default": 1280, "min": 16, "max": max_res, "step": 16}),
            "height": ("INT", {"default": 704, "min": 16, "max": max_res, "step": 16}),
            "length": ("INT", {"default": 121, "min": 1, "max": max_res, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video"

    def generate(self, width, height, length, batch_size=1):
        """Return ``({"samples": zeros},)`` sized from the pixel-space inputs.

        The tensor shape is ``[batch_size, 16, ((length - 1) // 8) + 1,
        height // 8, width // 8]``, allocated on the intermediate device.
        """
        latent_frames = ((length - 1) // 8) + 1
        latent_shape = [batch_size, 16, latent_frames, height // 8, width // 8]
        target_device = comfy.model_management.intermediate_device()
        samples = torch.zeros(latent_shape, device=target_device)
        return ({"samples": samples},)
|
||||||
|
|
||||||
|
# Node name -> implementing class for the nodes defined in this module.
NODE_CLASS_MAPPINGS = {"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo}
|
5
nodes.py
5
nodes.py
@ -912,7 +912,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -922,7 +922,7 @@ class CLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
|
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
|
||||||
|
|
||||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||||
if type == "stable_cascade":
|
if type == "stable_cascade":
|
||||||
@ -2225,6 +2225,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
"nodes_hooks.py",
|
"nodes_hooks.py",
|
||||||
"nodes_load_3d.py",
|
"nodes_load_3d.py",
|
||||||
|
"nodes_cosmos.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
Loading…
Reference in New Issue
Block a user