diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 9b5e5332..d64e73a8 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
 
 class WanT2VCrossAttention(WanSelfAttention):
 
-    def forward(self, x, context):
+    def forward(self, x, context, **kwargs):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
@@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, context):
+    def forward(self, x, context, context_img_len):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
         """
-        context_img = context[:, :257]
-        context = context[:, 257:]
+        context_img = context[:, :context_img_len]
+        context = context[:, context_img_len:]
 
         # compute query, key, value
         q = self.norm_q(self.q(x))
@@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
         e,
         freqs,
         context,
+        context_img_len=None,
     ):
         r"""
         Args:
@@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
         x = x + y * e[2]
 
         # cross-attention & ffn
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
         y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
         x = x + y * e[5]
         return x
@@ -420,9 +421,12 @@ class WanModel(torch.nn.Module):
         # context
         context = self.text_embedding(context)
 
-        if clip_fea is not None and self.img_emb is not None:
-            context_clip = self.img_emb(clip_fea) # bs x 257 x dim
-            context = torch.concat([context_clip, context], dim=1)
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+                context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
@@ -430,12 +434,12 @@
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                     return out
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context)
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
 
         # head
         x = self.head(x, e)
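
Net effect of the patch: the boundary between image-conditioning tokens and text tokens in the concatenated cross-attention context is no longer hardcoded to 257 but taken from clip_fea.shape[-2], so CLIP features with a different token count slice correctly, and T2V cross-attention simply ignores the extra keyword. A minimal sketch of the resulting slicing behavior (split_context is a hypothetical helper for illustration, not part of the model code):

import torch

def split_context(context, context_img_len):
    # context: [B, L_img + L_txt, C] when image conditioning is present,
    # [B, L_txt, C] (with context_img_len None) otherwise.
    if context_img_len is None:
        return None, context
    return context[:, :context_img_len], context[:, context_img_len:]

clip_fea = torch.randn(1, 257, 64)  # image tokens; any length now works
txt = torch.randn(1, 512, 64)       # text tokens
context = torch.cat([clip_fea, txt], dim=1)
img_ctx, txt_ctx = split_context(context, clip_fea.shape[-2])
assert img_ctx.shape[1] == clip_fea.shape[1]
assert txt_ctx.shape[1] == txt.shape[1]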