Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-05-10 14:36:07 +00:00
Don't hardcode length of context_img in wan code.
parent 1fc00ba4b6
commit 0d720e4367
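
The change, in short: the image part of the cross-attention context was assumed to be exactly 257 tokens (the length of the stock CLIP image embedding). It is now measured from clip_fea.shape[-2] and threaded through the attention blocks as context_img_len. A minimal runnable sketch of that idea, outside the real model code (the split_context helper and the toy shapes are illustrative, not part of the repo):

    import torch

    def split_context(context, context_img_len):
        # context: [B, L_img + L_txt, C]; the first context_img_len tokens
        # are image tokens, the remainder are text tokens.
        return context[:, :context_img_len], context[:, context_img_len:]

    clip_fea = torch.randn(1, 257, 1280)   # bs x 257 x dim in the stock model
    context_img_len = clip_fea.shape[-2]   # derived from the tensor, no longer hardcoded
    context = torch.cat([clip_fea, torch.randn(1, 512, 1280)], dim=1)
    context_img, context_txt = split_context(context, context_img_len)
    assert context_img.shape[1] == context_img_len
    assert context_txt.shape[1] == 512
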
@@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
 
 class WanT2VCrossAttention(WanSelfAttention):
 
-    def forward(self, x, context):
+    def forward(self, x, context, **kwargs):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
@@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, context):
+    def forward(self, x, context, context_img_len):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
         """
-        context_img = context[:, :257]
-        context = context[:, 257:]
+        context_img = context[:, :context_img_len]
+        context = context[:, context_img_len:]
 
         # compute query, key, value
         q = self.norm_q(self.q(x))
@@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
         e,
         freqs,
         context,
+        context_img_len=None,
     ):
         r"""
         Args:
@@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
         x = x + y * e[2]
 
         # cross-attention & ffn
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
         y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
         x = x + y * e[5]
         return x
@@ -420,9 +421,12 @@ class WanModel(torch.nn.Module):
         # context
         context = self.text_embedding(context)
 
-        if clip_fea is not None and self.img_emb is not None:
-            context_clip = self.img_emb(clip_fea) # bs x 257 x dim
-            context = torch.concat([context_clip, context], dim=1)
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+                context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
@@ -430,12 +434,12 @@ class WanModel(torch.nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                     return out
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context)
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
 
         # head
         x = self.head(x, e)
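
A note on the **kwargs added to WanT2VCrossAttention.forward: WanAttentionBlock now passes context_img_len to self.cross_attn unconditionally, so the text-to-video path must accept and ignore the keyword while the image-to-video path consumes it. A simplified sketch with toy stand-in classes (not the real attention modules):

    import torch

    class T2VCrossAttention:
        def forward(self, x, context, **kwargs):
            # Text-to-video: no image tokens, so context_img_len is ignored.
            return x

    class I2VCrossAttention:
        def forward(self, x, context, context_img_len):
            # Image-to-video: image tokens occupy the first context_img_len positions.
            context_img = context[:, :context_img_len]
            context = context[:, context_img_len:]
            return x

    x, context = torch.randn(1, 16, 8), torch.randn(1, 300, 8)
    for attn in (T2VCrossAttention(), I2VCrossAttention()):
        attn.forward(x, context, context_img_len=257)  # one call site serves both paths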