From 44db9785310ab6ce7214ba7f04060bdbe572d284 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 10 Dec 2024 23:07:26 -0500 Subject: [PATCH] Fix a few things in text enc code for models with no eos token. --- comfy/sd1_clip.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 887874a0..8ee69eaf 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -199,11 +199,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): attention_mask = None if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks: attention_mask = torch.zeros_like(tokens) - end_token = self.special_tokens.get("end", -1) + end_token = self.special_tokens.get("end", None) + if end_token is None: + cmp_token = self.special_tokens.get("pad", -1) + else: + cmp_token = end_token + for x in range(attention_mask.shape[0]): for y in range(attention_mask.shape[1]): attention_mask[x, y] = 1 - if tokens[x, y] == end_token: + if tokens[x, y] == cmp_token: + if end_token is None: + attention_mask[x, y] = 0 break attention_mask_model = None @@ -522,10 +529,14 @@ class SDTokenizer: for i, t_group in enumerate(tokens): #determine if we're going to try and keep the tokens in a single batch is_large = len(t_group) >= self.max_word_length + if self.end_token is not None: + has_end_token = 1 + else: + has_end_token = 0 while len(t_group) > 0: - if len(t_group) + len(batch) > self.max_length - 1: - remaining_length = self.max_length - len(batch) - 1 + if len(t_group) + len(batch) > self.max_length - has_end_token: + remaining_length = self.max_length - len(batch) - has_end_token #break word in two and add end token if is_large: batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])