Don't hardcode length of context_img in wan code.

This commit is contained in:
comfyanonymous 2025-04-17 06:25:39 -04:00
parent 1fc00ba4b6
commit 0d720e4367

View File

@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention): class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context): def forward(self, x, context, **kwargs):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, ))) # self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context): def forward(self, x, context, context_img_len):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C] context(Tensor): Shape [B, L2, C]
""" """
context_img = context[:, :257] context_img = context[:, :context_img_len]
context = context[:, 257:] context = context[:, context_img_len:]
# compute query, key, value # compute query, key, value
q = self.norm_q(self.q(x)) q = self.norm_q(self.q(x))
@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
e, e,
freqs, freqs,
context, context,
context_img_len=None,
): ):
r""" r"""
Args: Args:
@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
x = x + y * e[2] x = x + y * e[2]
# cross-attention & ffn # cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context) x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5] x = x + y * e[5]
return x return x
@ -420,9 +421,12 @@ class WanModel(torch.nn.Module):
# context # context
context = self.text_embedding(context) context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None: context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1) context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
@ -430,12 +434,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block(x, e=e0, freqs=freqs, context=context) x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head # head
x = self.head(x, e) x = self.head(x, e)