mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
5cbb01bc2f
To use: a "Load CLIP" node with t5xxl and type mochi; a "Load Diffusion Model" node with the mochi dit file; a "Load VAE" node with the mochi vae file; an EmptyMochiLatentVideo node for the latent; and euler + linear_quadratic in the KSampler node.
35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
#original code from https://github.com/genmoai/models under apache 2.0 license
|
|
|
|
# Based on Llama3 Implementation.
|
|
import torch
|
|
|
|
|
|
def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Rotate interleaved feature pairs of ``xqk`` using real arithmetic only (no complex dtypes).

    Elements ``xqk[..., 2i]`` and ``xqk[..., 2i + 1]`` of the last dimension are treated as the
    two coordinates of a 2-D vector. Each vector is rotated by the angle whose cosine and sine
    are given by ``freqs_cos`` and ``freqs_sin``.

    Args:
        xqk (torch.Tensor): Query and/or key tensor. Shape: (B, S, *, num_heads, D)
            May hold only queries, only keys, or both stacked along a batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosines; must broadcast with ``xqk[..., 0::2]``.
        freqs_sin (torch.Tensor): Precomputed sines; must broadcast with ``xqk[..., 0::2]``.

    Returns:
        torch.Tensor: The rotated pairs, re-interleaved along the last dimension and cast back
        to ``xqk``'s dtype. The shape equals ``xqk``'s unless the frequency tensors broadcast
        to a larger shape (NOTE(review): callers presumably pass shapes that don't expand).
    """
    # First and second coordinate of every pair in the last dimension.
    x_first, x_second = xqk[..., 0::2], xqk[..., 1::2]

    # 2-D rotation. The products may promote to the frequencies' dtype,
    # so each result is cast back to the input dtype.
    rotated_first = (x_first * freqs_cos - x_second * freqs_sin).type_as(xqk)
    rotated_second = (x_first * freqs_sin + x_second * freqs_cos).type_as(xqk)

    # Pair the halves along a new trailing axis, then merge that axis away so the
    # layout is interleaved again: [first0, second0, first1, second1, ...].
    return torch.stack((rotated_first, rotated_second), dim=-1).flatten(start_dim=-2)
|