Make wan work with all latent resolutions.

Cleanup some code.
comfyanonymous 2025-02-25 19:56:04 -05:00
parent ea0f939df3
commit 9a66bb972d
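In short: instead of padding every sample in a Python list to a shared seq_len (and carrying seq_lens/grid_sizes/context_lens around), the model now runs on one batched tensor, pads the latent up to the next multiple of the patch size before patchifying, and crops the output back to the requested size, which is what lets any latent resolution work. A minimal sketch of that pad-then-crop idea, with a hypothetical pad_to_multiple helper standing in for comfy.ldm.common_dit.pad_to_patch_size and a toy channel count rather than the real model config:

import torch
import torch.nn.functional as F

def pad_to_multiple(x, patch_size):
    # x: [B, C, T, H, W]; right-pad T/H/W up to the next multiple of the patch size.
    pt, ph, pw = patch_size
    t, h, w = x.shape[2:]
    # F.pad lists pads for the last dims first: (W_left, W_right, H_left, H_right, T_left, T_right)
    return F.pad(x, (0, -w % pw, 0, -h % ph, 0, -t % pt))

x = torch.randn(1, 16, 5, 15, 21)        # latent whose H/W are not multiples of 2
padded = pad_to_multiple(x, (1, 2, 2))   # -> [1, 16, 5, 16, 22]
out = padded[:, :, :5, :15, :21]         # crop back, mirroring the new [:, :, :t, :h, :w] slice
print(padded.shape, out.shape)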


@@ -9,6 +9,7 @@ from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
+import comfy.ldm.common_dit
def sinusoidal_embedding_1d(dim, position):
# preprocess
@@ -67,12 +68,10 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-def forward(self, x, seq_lens, grid_sizes, freqs):
+def forward(self, x, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
-seq_lens(Tensor): Shape [B]
-grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
@@ -100,12 +99,11 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
-def forward(self, x, context, context_lens):
+def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
-context_lens(Tensor): Shape [B]
"""
# compute query, key, value
q = self.norm_q(self.q(x))
@@ -134,12 +132,11 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-def forward(self, x, context, context_lens):
+def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
-context_lens(Tensor): Shape [B]
"""
context_img = context[:, :257]
context = context[:, 257:]
@@ -210,18 +207,13 @@ class WanAttentionBlock(nn.Module):
self,
x,
e,
-seq_lens,
-grid_sizes,
freqs,
context,
-context_lens,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
-seq_lens(Tensor): Shape [B], length of each sequence in batch
-grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
@@ -231,19 +223,19 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
-self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+self.norm1(x) * (1 + e[1]) + e[0],
freqs)
x = x + y * e[2]
# cross-attention & ffn function
-def cross_attn_ffn(x, context, context_lens, e):
-x = x + self.cross_attn(self.norm3(x), context, context_lens)
+def cross_attn_ffn(x, context, e):
+x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
-x = cross_attn_ffn(x, context, context_lens, e)
+x = cross_attn_ffn(x, context, e)
return x
@@ -408,7 +400,6 @@ class WanModel(torch.nn.Module):
x,
t,
context,
-seq_len=None,
clip_fea=None,
y=None,
freqs=None,
@@ -417,12 +408,12 @@ class WanModel(torch.nn.Module):
Forward pass through the diffusion model
Args:
-x (List[Tensor]):
-List of input video tensors, each with shape [C_in, F, H, W]
+x (Tensor):
+List of input video tensors with shape [B, C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
-List of text embeddings each with shape [L, C]
+List of text embeddings each with shape [B, L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
@@ -438,22 +429,12 @@ class WanModel(torch.nn.Module):
assert clip_fea is not None and y is not None
if y is not None:
-x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+x = torch.cat([x, y], dim=0)
# embeddings
-x = [self.patch_embedding(u) for u in x]
-grid_sizes = torch.stack(
-[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
-x = [u.flatten(2).transpose(1, 2) for u in x]
-seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
-if seq_len is not None:
-assert seq_lens.max() <= seq_len
-x = torch.cat([
-torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-dim=1) for u in x
-])
-elif len(x) == 1:
-x = x[0]
+x = self.patch_embedding(x)
+grid_sizes = x.shape[2:]
+x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
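The hunk above replaces the old per-sample patchify (list comprehension plus zero-padding to seq_len) with a single batched path: one patch-embedding pass, grid_sizes read straight off the feature map, and a flatten/transpose into tokens. A rough shape trace, assuming the patch embedding is a Conv3d with kernel and stride equal to the patch size, and using toy sizes rather than the real Wan dimensions:

import torch
import torch.nn as nn

in_dim, dim, patch_size = 16, 64, (1, 2, 2)
patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, in_dim, 5, 16, 22)   # [B, C_in, F, H, W], already padded to the patch size
x = patch_embedding(x)                  # [B, dim, F', H', W'] = [2, 64, 5, 8, 11]
grid_sizes = x.shape[2:]                # (5, 8, 11), kept around for unpatchify
x = x.flatten(2).transpose(1, 2)        # [B, L, dim] with L = 5 * 8 * 11 = 440
print(grid_sizes, x.shape)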
@@ -461,13 +442,7 @@ class WanModel(torch.nn.Module):
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
-context_lens = None
-context = self.text_embedding(
-torch.cat([
-torch.cat(
-[u, u.new_zeros(u.size(0), self.text_len - u.size(1), u.size(2))], dim=1)
-for u in context
-], dim=0))
+context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
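The text conditioning gets the same treatment: context is now one batched tensor that is zero-padded up to self.text_len in a single cat, rather than padding each list entry separately. Roughly as below, where the text length and embedding width are illustrative values, not the model's actual config:

import torch

text_len = 512
context = torch.randn(2, 77, 4096)  # [B, L, C] from the text encoder
pad = context.new_zeros(context.size(0), text_len - context.size(1), context.size(2))
context = torch.cat([context, pad], dim=1)
print(context.shape)  # torch.Size([2, 512, 4096])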
@@ -476,11 +451,8 @@ class WanModel(torch.nn.Module):
# arguments
kwargs = dict(
e=e0,
-seq_lens=seq_lens,
-grid_sizes=grid_sizes,
freqs=freqs,
-context=context,
-context_lens=context_lens)
+context=context)
for block in self.blocks:
x = block(x, **kwargs)
@@ -495,6 +467,7 @@ class WanModel(torch.nn.Module):
def forward(self, x, timestep, context, y=None, image=None, **kwargs):
bs, c, t, h, w = x.shape
+x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -506,7 +479,7 @@ class WanModel(torch.nn.Module):
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
-return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
+return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""
@@ -521,14 +494,13 @@ class WanModel(torch.nn.Module):
Returns:
List[Tensor]:
-Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+Reconstructed video tensors with shape [L, C_out, F, H / 8, W / 8]
"""
c = self.out_dim
-out = []
-for u, v in zip(x, grid_sizes.tolist()):
-u = u[:math.prod(v)].view(*v, *self.patch_size, c)
-u = torch.einsum('fhwpqrc->cfphqwr', u)
-u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-out.append(u)
-return out
+u = x
+b = u.shape[0]
+u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
+u = torch.einsum('bfhwpqrc->bcfphqwr', u)
+u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+return u
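unpatchify now also works on the whole batch at once: the token dimension is trimmed to prod(grid_sizes), reshaped into the (F', H', W') grid of patches, and the einsum interleaves each patch axis with its grid axis before the final reshape. A small shape check of that reordering, again with toy sizes rather than the real model dimensions:

import math
import torch

patch_size = (1, 2, 2)
grid_sizes = (5, 8, 11)        # (F', H', W') saved after patch_embedding
b, c = 2, 16                   # batch size and out_dim
L = math.prod(grid_sizes)

x = torch.randn(b, L, math.prod(patch_size) * c)    # per-token flattened patches
u = x[:, :L].view(b, *grid_sizes, *patch_size, c)   # [B, F', H', W', pf, ph, pw, C]
u = torch.einsum('bfhwpqrc->bcfphqwr', u)           # pair each grid axis with its patch axis
u = u.reshape(b, c, *[g * p for g, p in zip(grid_sizes, patch_size)])
print(u.shape)  # torch.Size([2, 16, 5, 16, 22])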