diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 02e925c8..32612cf3 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -247,6 +247,11 @@ class SD1Tokenizer:
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0
+
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
@@ -277,30 +282,33 @@ class SD1Tokenizer:
 
         #reshape token array to CLIP input size
         batched_tokens = []
-        batch = []
+        batch = [(self.start_token, 1.0, 0)]
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
+
             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_tokens_per_section:
-                    remaining_length = self.max_tokens_per_section - len(batch)
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
                         t_group = t_group[remaining_length:]
+                    #add end token and pad
                     else:
-                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                    batch = []
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
                     batched_tokens.append(batch)
                 else:
                     batch.extend([(t,w,i+1) for t,w in t_group])
                     t_group = []
 
         #fill last batch
-        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
-
-        #add start and end tokens
-        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
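For reference, below is a minimal, self-contained sketch of the packing behaviour this patch introduces: every section opens with the start token, ends with a single end token, and is filled to full length with a configurable pad token instead of repeated end tokens. The START_TOKEN, END_TOKEN, MAX_LENGTH and MAX_WORD_LENGTH values and the pack() helper are illustrative assumptions, not part of the patch or of the SD1Tokenizer API.

# sketch_pack.py -- illustrative only; constants below are assumed CLIP-style values
START_TOKEN = 49406   # assumed CLIP BOS id
END_TOKEN = 49407     # assumed CLIP EOS id
MAX_LENGTH = 77       # assumed CLIP context length
MAX_WORD_LENGTH = 8   # assumed threshold above which a word may be split across sections

def pack(token_groups, pad_with_end=True):
    '''token_groups: list of per-word lists of (token_id, weight) tuples.'''
    pad_token = END_TOKEN if pad_with_end else 0

    batched_tokens = []
    batch = [(START_TOKEN, 1.0, 0)]          # every section now begins with the start token
    batched_tokens.append(batch)
    for i, t_group in enumerate(token_groups):
        is_large = len(t_group) >= MAX_WORD_LENGTH
        while len(t_group) > 0:
            if len(t_group) + len(batch) > MAX_LENGTH - 1:   # reserve one slot for the end token
                remaining_length = MAX_LENGTH - len(batch) - 1
                if is_large:
                    # split the long word, close this section with the end token
                    batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                    batch.append((END_TOKEN, 1.0, 0))
                    t_group = t_group[remaining_length:]
                else:
                    # keep the word whole: end this section and pad it out
                    batch.append((END_TOKEN, 1.0, 0))
                    batch.extend([(pad_token, 1.0, 0)] * remaining_length)
                # start the next section
                batch = [(START_TOKEN, 1.0, 0)]
                batched_tokens.append(batch)
            else:
                batch.extend([(t, w, i + 1) for t, w in t_group])
                t_group = []

    # final section: one end token, then pad tokens up to the full length
    batch.extend([(END_TOKEN, 1.0, 0)] + [(pad_token, 1.0, 0)] * (MAX_LENGTH - len(batch) - 1))
    return batched_tokens

if __name__ == "__main__":
    groups = [[(320, 1.0)], [(1125, 1.2)]]   # e.g. "a (cat:1.2)" as per-word token groups
    out = pack(groups)
    assert all(len(b) == MAX_LENGTH for b in out)            # every section is exactly MAX_LENGTH long
    assert out[0][0][0] == START_TOKEN and out[0][3][0] == END_TOKEN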