Fix wan issues when prompt length is long.

This commit is contained in:
comfyanonymous 2025-02-26 20:34:02 -05:00
parent 8e69e2ddfd
commit 3ea3bc8546
2 changed files with 2 additions and 2 deletions

View File

@ -421,7 +421,7 @@ class WanModel(torch.nn.Module):
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim

View File

@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}