diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index e4b0d34a..442a814c 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
 from comfy.ldm.modules.attention import optimized_attention_masked
+from comfy.ldm.flux.layers import EmbedND
 
 
 def modulate(x, scale):
@@ -92,10 +93,9 @@ class JointAttention(nn.Module):
             and key tensor with rotary embeddings.
         """
 
-        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-        freqs_cis = freqs_cis.unsqueeze(2)
-        x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-        return x_out.type_as(x_in)
+        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
+        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
+        return t_out.reshape(*x_in.shape).type_as(x_in)
 
     def forward(
         self,
@@ -130,6 +130,7 @@ class JointAttention(nn.Module):
 
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)
 
+        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
         xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
 
@@ -480,7 +481,8 @@ class NextDiT(nn.Module):
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
+        # self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
 
         self.dim = dim
         self.n_heads = n_heads
@@ -550,7 +552,7 @@ class NextDiT(nn.Module):
             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
 
-        freqs_cis = self.rope_embedder(position_ids)
+        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
 
         # build freqs_cis for cap and image individually
         cap_freqs_cis_shape = list(freqs_cis.shape)
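
Note on the apply_rotary_emb rewrite: the removed complex-number path and the added real-valued path compute the same rotation, provided freqs_cis now arrives as per-frequency 2x2 rotation matrices [[cos, -sin], [sin, cos]] (my reading of the flux-style RoPE convention that EmbedND follows; treat that layout, and every name and shape below, as an assumption). The standalone sketch below builds such matrices from explicit angles and checks the two formulas against each other; it is not code from this patch.

import torch

def apply_rotary_old(x_in, angles):
    # Removed path: view feature pairs as complex numbers and multiply by e^{i*angle}.
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    freqs = torch.polar(torch.ones_like(angles), angles)  # cos(angle) + i*sin(angle)
    return torch.view_as_real(x * freqs).flatten(-2).type_as(x_in)

def apply_rotary_new(x_in, angles):
    # Added path: plain real arithmetic against 2x2 rotation matrices.
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Rotation matrix [[cos, -sin], [sin, cos]] per frequency, so that
    # freqs_cis[..., 0] = [cos, sin] and freqs_cis[..., 1] = [-sin, cos].
    freqs_cis = torch.stack([torch.stack([cos, -sin], dim=-1),
                             torch.stack([sin, cos], dim=-1)], dim=-2)
    t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x_in.shape).type_as(x_in)

x = torch.randn(2, 16, 8, 64)        # (batch, seq, heads, head_dim), sizes illustrative only
angles = torch.randn(2, 16, 8, 32)   # one rotation angle per feature pair
assert torch.allclose(apply_rotary_old(x, angles), apply_rotary_new(x, angles), atol=1e-5)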
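
Note on the added .movedim(1, 2): flux-style attention tensors are (batch, heads, seq, head_dim), so EmbedND appears to emit its singleton broadcast axis at dim 1, whereas Lumina's q/k tensors are (batch, seq, heads, head_dim). Moving that axis to dim 2 lets freqs_cis broadcast over the heads axis. A shape-only sketch, with the EmbedND output layout assumed rather than taken from this patch:

import torch

batch, seq, heads, head_dim = 2, 16, 8, 64            # illustrative sizes only

# Assumed EmbedND output layout: singleton "heads" axis at dim 1,
# i.e. (batch, 1, seq, head_dim // 2, 2, 2).
freqs_cis = torch.zeros(batch, 1, seq, head_dim // 2, 2, 2)
freqs_cis = freqs_cis.movedim(1, 2)                    # -> (batch, seq, 1, head_dim // 2, 2, 2)

# Lumina's q/k are (batch, seq, heads, head_dim); reshape into pairs as in the diff.
t_ = torch.zeros(batch, seq, heads, head_dim).reshape(batch, seq, heads, -1, 1, 2)

# (batch, seq, 1, d/2, 2) broadcasts against (batch, seq, heads, d/2, 1).
out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
print(out.shape)  # torch.Size([2, 16, 8, 32, 2])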