mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 14:40:27 +00:00
Remove some useless code.
This commit is contained in:
parent
f1059b0b82
commit
14880e6dba
@ -352,25 +352,6 @@ class FinalLayer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
|
||||
):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
|
||||
result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
class NextDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
@ -481,7 +462,6 @@ class NextDiT(nn.Module):
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
# self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
@ -609,7 +589,6 @@ class NextDiT(nn.Module):
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
@ -638,37 +617,3 @@ class NextDiT(nn.Module):
|
||||
|
||||
return -x
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(
|
||||
dim: List[int],
|
||||
end: List[int],
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with
|
||||
given dimensions.
|
||||
|
||||
This function calculates a frequency tensor with complex exponentials
|
||||
using the given dimension 'dim' and the end index 'end'. The 'theta'
|
||||
parameter scales the frequencies. The returned tensor contains complex
|
||||
values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (list): Dimension of the frequency tensor.
|
||||
end (list): End index for precomputing frequencies.
|
||||
theta (float, optional): Scaling factor for frequency computation.
|
||||
Defaults to 10000.0.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Precomputed frequency tensor with complex
|
||||
exponentials.
|
||||
"""
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
Loading…
Reference in New Issue
Block a user