diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index dbe84b8b..5471763b 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -421,7 +421,7 @@ class WanModel(torch.nn.Module):
         e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 
         # context
-        context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
+        context = self.text_embedding(context)
 
         if clip_fea is not None and self.img_emb is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py
index d98c9ad2..971ac8fa 100644
--- a/comfy/text_encoders/wan.py
+++ b/comfy/text_encoders/wan.py
@@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
 class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
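
The two hunks work as a pair: WanModel stops zero-filling the text-encoder output up to self.text_len in its forward pass, and the UMT5 tokenizer instead pads the token sequence itself to min_length=512 with pad_token=0, so the encoder output already arrives at the length the model expects. A minimal sketch of the shape bookkeeping, assuming self.text_len is 512 to match the new min_length (the tensor sizes below are illustrative, not read from a real checkpoint):

import torch

text_len, dim = 512, 4096  # assumed values for illustration

# Old path: the model received a variable-length encoder output and
# zero-filled it up to text_len before the self.text_embedding projection.
context = torch.randn(1, 77, dim)  # e.g. a 77-token prompt
pad = context.new_zeros(context.size(0), text_len - context.size(1), context.size(2))
old_context = torch.cat([context, pad], dim=1)
assert old_context.shape == (1, text_len, dim)

# New path: the tokenizer pads the token ids to min_length=512 before
# encoding, so the encoder output is already (1, 512, dim) and the model
# can call self.text_embedding(context) directly.

Note that the two paths are not necessarily bit-identical: zero vectors appended after encoding differ from whatever hidden states the encoder produces for pad token 0, so the change moves padding ahead of the text encoder rather than merely relocating the same operation.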