mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
208 lines
8.7 KiB
Python
208 lines
8.7 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from einops import rearrange, repeat
|
|
from torch import nn
|
|
import math
|
|
|
|
|
|
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """
    Rescale ``x`` so its elements have (approximately) unit mean square over ``dim``.

    The L2 norm over ``dim`` is computed in float32 and multiplied by
    ``sqrt(norm.numel() / x.numel())``, i.e. divided by the square root of the
    number of reduced elements, which turns it into an RMS value. ``x`` is then
    divided by ``eps + rms`` (cast back to ``x.dtype``).

    Args:
        x (torch.Tensor): The input tensor to normalize.
        dim (list, optional): Dimensions to reduce over. A single int is also
            accepted, since the value is passed straight to
            ``torch.linalg.vector_norm``. If None, every dimension except the
            first is reduced.
        eps (float, optional): Added to the RMS before dividing, for numerical stability.

    Returns:
        torch.Tensor: The normalized tensor, same shape as ``x``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    l2 = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # With keepdim=True, l2.numel() / x.numel() == 1 / (elements reduced per output),
    # so alpha scales the L2 norm down to an RMS; eps is then added on top.
    denom = torch.add(eps, l2, alpha=math.sqrt(l2.numel() / x.numel()))
    return x / denom.to(x.dtype)
|
|
|
|
|
|
class VideoPositionEmb(nn.Module):
    """
    Base class for video positional embeddings.

    ``forward`` only reads the shape of its input and delegates to
    ``generate_embeddings``, which subclasses must implement.
    """

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None) -> torch.Tensor:
        """
        It delegates the embedding generation to generate_embeddings function.

        Args:
            x_B_T_H_W_C (torch.Tensor): Input tensor; only its ``.shape`` is used and
                it is passed on as ``B_T_H_W_C``.
            fps (Optional[torch.Tensor]): Frame rate(s), forwarded unchanged. Defaults
                to None so that subclasses can rely on ``fps is None`` checks.
            device: Forwarded unchanged to ``generate_embeddings``.

        Returns:
            Whatever ``generate_embeddings`` returns.
        """
        # NOTE: ``fps`` must default to None, not to the typing object ``Optional[...]``;
        # otherwise omitting it would hand a non-tensor, non-None value to subclasses.
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None):
        """
        Build embeddings for an input of size ``B_T_H_W_C``. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class VideoRopePosition3DEmb(VideoPositionEmb):
    """
    3D (time, height, width) rotary position embedding for video tokens.

    The per-head channel dimension ``head_dim`` is split into two equal spatial
    parts (``dim_h``, ``dim_w``) and a temporal remainder (``dim_t``). Each part
    gets its own RoPE frequency set, with the base (10000) optionally enlarged by
    an NTK-style factor derived from the matching ``*_extrapolation_ratio``.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        # Integer positions 0 .. max(len_h, len_w, len_t) - 1, shared by all three axes.
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        dim = head_dim
        # Each spatial axis gets an even-sized share (multiple of 2); time gets the rest.
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        # Exponents 2i / d used as theta ** -(2i / d); non-persistent, so not stored in state_dict.
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        # Base multipliers: ratio ** (d / (d - 2)); a ratio of 1.0 leaves the base unchanged.
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
    ):
        """
        Generate rotary embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second, as a Python number or a
                tensor. Temporal positions are scaled by ``base_fps / fps``; None (image case)
                leaves them unscaled. The result is shared by the whole batch, so a tensor with
                more than one element is reduced to its first element (see the assertion on
                uniform fps below). Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
            device: Device to move the frequency/position buffers to before computing.

        Returns:
            torch.Tensor: float32 tensor of shape ``(T * H * W, D, 2, 2)``. For every (t, h, w)
            position and frequency it holds the 2x2 rotation matrix ``[[cos, -sin], [sin, cos]]``.
            ``D = dim_t // 2 + dim_h // 2 + dim_w // 2`` (``head_dim // 2`` for even ``head_dim``).
            Along ``D`` the temporal frequencies come first, then height, then width.

        Raises:
            AssertionError: if fps is non-uniform while both B > 1 and T > 1, or if H / W exceed
                the ``len_h`` / ``len_w`` given at construction.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None:  # image case
            half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
        else:
            if isinstance(fps, torch.Tensor) and fps.numel() > 1:
                # The embedding has no batch axis, so exactly one fps value can be applied.
                # Dividing the T positions by a (B,)-shaped fps would either fail to broadcast
                # or yield B temporal rows instead of T. Given the assertion above, the first
                # value is representative: fps is uniform, or T == 1 (the only position is 0),
                # or B == 1.
                fps = fps.reshape(-1)[:1]
            half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

        def _rotation(angles: torch.Tensor) -> torch.Tensor:
            # Flattened 2x2 rotation matrix [cos, -sin, sin, cos] in a new trailing dim.
            cos, sin = torch.cos(angles), torch.sin(angles)
            return torch.stack([cos, -sin, sin, cos], dim=-1)

        half_emb_h = _rotation(half_emb_h)
        half_emb_w = _rotation(half_emb_w)
        half_emb_t = _rotation(half_emb_t)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ],
            dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
|
|
|
|
|
class LearnablePosEmbAxis(VideoPositionEmb):
    """
    Learnable, axis-factorized absolute position embedding for video tokens.

    Separate learnable tables for the T, H and W axes are cropped to the input
    size, broadcast to ``(B, T, H, W, model_channels)``, summed, and normalized
    along the channel axis with ``normalize``.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): How the tables are adapted to the input size. We currently
                only support "crop"; ideally, when extrapolation capacity is needed, the
                frequency should be adjusted or other more advanced methods used. Those are
                not implemented yet.
            model_channels (int): Embedding width (last dimension of every table).
            len_h (int): Number of rows in the height table.
            len_w (int): Number of rows in the width table.
            len_t (int): Number of rows in the time table.
            device: Device on which the parameter tables are allocated.
            **kwargs: Accepted for interface compatibility with other embeddings; ignored.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        # torch.empty leaves the tables uninitialized. NOTE(review): presumably they are
        # always overwritten from a checkpoint; confirm before training from scratch.
        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None) -> torch.Tensor:
        """
        Build the summed axis embeddings for an input of size ``B_T_H_W_C``.

        Args:
            B_T_H_W_C (torch.Size): Input size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor]): Unused by this embedding; accepted so the signature
                matches other position embeddings. Defaults to None.
            device: Device the cropped tables are moved to.

        Returns:
            torch.Tensor: Tensor of shape ``(B, T, H, W, model_channels)``, normalized along
            the last dimension with ``normalize(..., dim=-1, eps=1e-6)``.

        Raises:
            ValueError: if ``self.interpolation`` is not "crop".
        """
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            # Keep only the first T/H/W rows of each table.
            emb_h_H = self.pos_emb_h[:H].to(device=device)
            emb_w_W = self.pos_emb_w[:W].to(device=device)
            emb_t_T = self.pos_emb_t[:T].to(device=device)
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)
|