diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index ec411972..3292bd2f 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -352,25 +352,6 @@ class FinalLayer(nn.Module):
         return x


-class RopeEmbedder:
-    def __init__(
-        self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
-    ):
-        super().__init__()
-        self.theta = theta
-        self.axes_dims = axes_dims
-        self.axes_lens = axes_lens
-        self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
-
-    def __call__(self, ids: torch.Tensor):
-        self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
-        result = []
-        for i in range(len(self.axes_dims)):
-            index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
-            result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
-
-
 class NextDiT(nn.Module):
     """
     Diffusion model with a Transformer backbone.
@@ -481,7 +462,6 @@ class NextDiT(nn.Module):
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        # self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
         self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
         self.dim = dim
         self.n_heads = n_heads
@@ -609,7 +589,6 @@ class NextDiT(nn.Module):

         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

-    # def forward(self, x, t, cap_feats, cap_mask):
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
         t = 1.0 - timesteps
@@ -638,37 +617,3 @@ class NextDiT(nn.Module):

         return -x

-    @staticmethod
-    def precompute_freqs_cis(
-        dim: List[int],
-        end: List[int],
-        theta: float = 10000.0,
-    ):
-        """
-        Precompute the frequency tensor for complex exponentials (cis) with
-        given dimensions.
-
-        This function calculates a frequency tensor with complex exponentials
-        using the given dimension 'dim' and the end index 'end'. The 'theta'
-        parameter scales the frequencies. The returned tensor contains complex
-        values in complex64 data type.
-
-        Args:
-            dim (list): Dimension of the frequency tensor.
-            end (list): End index for precomputing frequencies.
-            theta (float, optional): Scaling factor for frequency computation.
-                Defaults to 10000.0.
-
-        Returns:
-            torch.Tensor: Precomputed frequency tensor with complex
-                exponentials.
-        """
-        freqs_cis = []
-        for i, (d, e) in enumerate(zip(dim, end)):
-            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
-            timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
-            freqs = torch.outer(timestep, freqs).float()
-            freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
-            freqs_cis.append(freqs_cis_i)
-
-        return freqs_cis
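For reference, below is a minimal sketch of the multi-axis rotary (cis) computation that the deleted precompute-and-gather path produced, written to compute the values directly from position ids instead of gathering from precomputed per-axis tables. It is illustrative only: `rope_cis_from_ids` is a hypothetical helper, not the `EmbedND` module the model now routes through, and the default `axes_dims` is just the example tuple from the removed code.

```python
import torch

def rope_cis_from_ids(ids: torch.Tensor, axes_dims=(16, 56, 56), theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: complex64 rotary (cis) values computed directly from position ids.

    Output layout matches what the removed RopeEmbedder.__call__ /
    NextDiT.precompute_freqs_cis pair returned, without building lookup tables.
    ids: (B, L, n_axes) integer positions, one column per rotary axis.
    """
    result = []
    for i, d in enumerate(axes_dims):
        # Per-axis inverse frequencies, same formula as the deleted code.
        freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device=ids.device) / d))
        # Angle = position * frequency, broadcast over this axis's d // 2 frequencies.
        angles = ids[..., i].to(torch.float64).unsqueeze(-1) * freqs
        result.append(torch.polar(torch.ones_like(angles), angles).to(torch.complex64))
    # (B, L, sum(d // 2 for d in axes_dims)) complex64, concatenated per axis.
    return torch.cat(result, dim=-1)
```

For example, with `axes_dims=(16, 56, 56)` and `ids` of shape `(1, 4096, 3)`, the result has shape `(1, 4096, 64)`, since 8 + 28 + 28 = 64 complex entries correspond to the 128 rotary channels per head.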