diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 15119297..4f2315ac 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -9,6 +9,7 @@ from einops import repeat
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
+import comfy.ldm.common_dit
 
 def sinusoidal_embedding_1d(dim, position):
     # preprocess
@@ -67,12 +68,10 @@ class WanSelfAttention(nn.Module):
         self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, seq_lens, grid_sizes, freqs):
+    def forward(self, x, freqs):
         r"""
         Args:
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
-            seq_lens(Tensor): Shape [B]
-            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
@@ -100,12 +99,11 @@ class WanT2VCrossAttention(WanSelfAttention):
 
-    def forward(self, x, context, context_lens):
+    def forward(self, x, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
-            context_lens(Tensor): Shape [B]
         """
         # compute query, key, value
         q = self.norm_q(self.q(x))
@@ -134,12 +132,11 @@ class WanI2VCrossAttention(WanSelfAttention):
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, context, context_lens):
+    def forward(self, x, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
-            context_lens(Tensor): Shape [B]
         """
         context_img = context[:, :257]
         context = context[:, 257:]
@@ -210,18 +207,13 @@ class WanAttentionBlock(nn.Module):
         self,
         x,
         e,
-        seq_lens,
-        grid_sizes,
         freqs,
         context,
-        context_lens,
     ):
         r"""
         Args:
             x(Tensor): Shape [B, L, C]
             e(Tensor): Shape [B, 6, C]
-            seq_lens(Tensor): Shape [B], length of each sequence in batch
-            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         # assert e.dtype == torch.float32
@@ -231,19 +223,19 @@ class WanAttentionBlock(nn.Module):
 
         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+            self.norm1(x) * (1 + e[1]) + e[0],
             freqs)
 
         x = x + y * e[2]
 
         # cross-attention & ffn function
-        def cross_attn_ffn(x, context, context_lens, e):
-            x = x + self.cross_attn(self.norm3(x), context, context_lens)
+        def cross_attn_ffn(x, context, e):
+            x = x + self.cross_attn(self.norm3(x), context)
             y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
             x = x + y * e[5]
             return x
 
-        x = cross_attn_ffn(x, context, context_lens, e)
+        x = cross_attn_ffn(x, context, e)
         return x
 
@@ -408,7 +400,6 @@ class WanModel(torch.nn.Module):
         x,
         t,
         context,
-        seq_len=None,
         clip_fea=None,
         y=None,
         freqs=None,
@@ -417,12 +408,12 @@ class WanModel(torch.nn.Module):
         Forward pass through the diffusion model
 
         Args:
-            x (List[Tensor]):
-                List of input video tensors, each with shape [C_in, F, H, W]
+            x (Tensor):
+                List of input video tensors with shape [B, C_in, F, H, W]
             t (Tensor):
                 Diffusion timesteps tensor of shape [B]
             context (List[Tensor]):
-                List of text embeddings each with shape [L, C]
+                List of text embeddings each with shape [B, L, C]
             seq_len (`int`):
                 Maximum sequence length for positional encoding
             clip_fea (Tensor, *optional*):
@@ -438,22 +429,12 @@ class WanModel(torch.nn.Module):
             assert clip_fea is not None and y is not None
 
         if y is not None:
-            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+            x = torch.cat([x, y], dim=1)
 
         # embeddings
-        x = [self.patch_embedding(u) for u in x]
-        grid_sizes = torch.stack(
-            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
-        x = [u.flatten(2).transpose(1, 2) for u in x]
-        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
-        if seq_len is not None:
-            assert seq_lens.max() <= seq_len
-            x = torch.cat([
-                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                          dim=1) for u in x
-            ])
-        elif len(x) == 1:
-            x = x[0]
+        x = self.patch_embedding(x)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
 
         # time embeddings
         e = self.time_embedding(
@@ -461,13 +442,7 @@ class WanModel(torch.nn.Module):
         e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 
         # context
-        context_lens = None
-        context = self.text_embedding(
-            torch.cat([
-                torch.cat(
-                    [u, u.new_zeros(u.size(0), self.text_len - u.size(1), u.size(2))], dim=1)
-                for u in context
-            ], dim=0))
+        context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
 
         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
@@ -476,11 +451,8 @@ class WanModel(torch.nn.Module):
         # arguments
         kwargs = dict(
             e=e0,
-            seq_lens=seq_lens,
-            grid_sizes=grid_sizes,
             freqs=freqs,
-            context=context,
-            context_lens=context_lens)
+            context=context)
 
         for block in self.blocks:
             x = block(x, **kwargs)
@@ -495,6 +467,7 @@ class WanModel(torch.nn.Module):
 
     def forward(self, x, timestep, context, y=None, image=None, **kwargs):
         bs, c, t, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -506,7 +479,7 @@ class WanModel(torch.nn.Module):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         freqs = self.rope_embedder(img_ids).movedim(1, 2)
 
-        return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
+        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
 
     def unpatchify(self, x, grid_sizes):
         r"""
@@ -521,14 +494,13 @@ class WanModel(torch.nn.Module):
 
         Returns:
             List[Tensor]:
-                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+                Reconstructed video tensors with shape [L, C_out, F, H / 8, W / 8]
         """
 
         c = self.out_dim
-        out = []
-        for u, v in zip(x, grid_sizes.tolist()):
-            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
-            u = torch.einsum('fhwpqrc->cfphqwr', u)
-            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-            out.append(u)
-        return out
+        u = x
+        b = u.shape[0]
+        u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
+        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
+        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+        return u
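
Reviewer note: a standalone sketch of the padding behaviour the new `forward()` relies on. `comfy.ldm.common_dit.pad_to_patch_size` is assumed here to pad the trailing (F, H, W) dims of the latent up to the next multiple of `patch_size` (the circular padding mode is an assumption), which is why the output is cropped back with `[:, :, :t, :h, :w]`. The helper name and tensor sizes below are illustrative only, not part of the patch.

```python
import torch
import torch.nn.functional as F

def pad_to_patch_size_sketch(x, patch_size=(1, 2, 2)):
    # Pad the trailing (F, H, W) dims up to the next multiple of patch_size
    # so patch_embedding (a strided conv3d) divides them evenly.
    pad = ()
    for i in range(len(patch_size)):
        extra = (patch_size[i] - x.shape[2 + i] % patch_size[i]) % patch_size[i]
        pad = (0, extra) + pad  # F.pad consumes the pad tuple last-dim first
    return F.pad(x, pad, mode="circular")

latent = torch.randn(1, 16, 5, 21, 21)  # made-up latent with odd H/W
padded = pad_to_patch_size_sketch(latent)
print(padded.shape)  # torch.Size([1, 16, 5, 22, 22])
# forward() later crops the model output back to the original t, h, w.
```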
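And a quick shape check of the batched `unpatchify` path (again a sketch with made-up sizes; Wan uses a (1, 2, 2) patch size), confirming that the view, the `bfhwpqrc->bcfphqwr` einsum, and the reshape turn `[B, L, C_out * prod(patch_size)]` tokens back into `[B, C_out, F*pt, H*ph, W*pw]`:

```python
import math
import torch

b, c_out = 2, 16              # made-up batch size and out_dim
grid_sizes = (4, 6, 6)        # (F, H, W) of the token grid, i.e. x.shape[2:] after patch_embedding
patch_size = (1, 2, 2)

# Tokens as the transformer blocks emit them: [B, L, C_out * prod(patch_size)]
x = torch.randn(b, math.prod(grid_sizes), c_out * math.prod(patch_size))

# Same three steps as the new unpatchify.
u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *patch_size, c_out)
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
u = u.reshape(b, c_out, *[i * j for i, j in zip(grid_sizes, patch_size)])

print(u.shape)  # torch.Size([2, 16, 4, 12, 12]) -> [B, C_out, F*pt, H*ph, W*pw]
```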