mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
Use real-valued numbers instead of complex numbers for RoPE in the Lumina model.
This commit is contained in:
parent
a57d635c5f
commit
60653004e5
@ -9,6 +9,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
|
||||||
|
|
||||||
def modulate(x, scale):
|
def modulate(x, scale):
|
||||||
@ -92,10 +93,9 @@ class JointAttention(nn.Module):
|
|||||||
and key tensor with rotary embeddings.
|
and key tensor with rotary embeddings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
|
||||||
freqs_cis = freqs_cis.unsqueeze(2)
|
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
return t_out.reshape(*x_in.shape).type_as(x_in)
|
||||||
return x_out.type_as(x_in)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -130,6 +130,7 @@ class JointAttention(nn.Module):
|
|||||||
|
|
||||||
xq = self.q_norm(xq)
|
xq = self.q_norm(xq)
|
||||||
xk = self.k_norm(xk)
|
xk = self.k_norm(xk)
|
||||||
|
|
||||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||||
|
|
||||||
@ -480,7 +481,8 @@ class NextDiT(nn.Module):
|
|||||||
assert (dim // n_heads) == sum(axes_dims)
|
assert (dim // n_heads) == sum(axes_dims)
|
||||||
self.axes_dims = axes_dims
|
self.axes_dims = axes_dims
|
||||||
self.axes_lens = axes_lens
|
self.axes_lens = axes_lens
|
||||||
self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
|
# self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
|
||||||
|
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
|
|
||||||
@ -550,7 +552,7 @@ class NextDiT(nn.Module):
|
|||||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||||
|
|
||||||
freqs_cis = self.rope_embedder(position_ids)
|
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
|
||||||
|
|
||||||
# build freqs_cis for cap and image individually
|
# build freqs_cis for cap and image individually
|
||||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user