diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 3b812b44..ed3dc229 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -364,7 +364,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
@@ -380,6 +380,14 @@ class SDTokenizer:
             self.tokens_start = 0
             self.start_token = None
             self.end_token = empty[0]
+
+        if pad_token is not None:
+            self.pad_token = pad_token
+        elif pad_with_end:
+            self.pad_token = self.end_token
+        else:
+            self.pad_token = 0
+
         self.pad_with_end = pad_with_end
         self.pad_to_max_length = pad_to_max_length
 
@@ -412,10 +420,6 @@ class SDTokenizer:
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
-        if self.pad_with_end:
-            pad_token = self.end_token
-        else:
-            pad_token = 0
 
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
@@ -467,7 +471,7 @@ class SDTokenizer:
                     else:
                         batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
-                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -480,9 +484,9 @@ class SDTokenizer:
         #fill last batch
         batch.append((self.end_token, 1.0, 0))
         if self.pad_to_max_length:
-            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
-            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
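A minimal usage sketch of the new pad_token argument, assuming only what the diff above shows: pad_token, when given, takes priority; otherwise pad_with_end selects between the end token and 0. The subclass name and constructor values below are illustrative assumptions, not part of this diff.

# Illustrative only: ExamplePadZeroTokenizer is a hypothetical subclass, not part of this diff.
from comfy.sd1_clip import SDTokenizer

class ExamplePadZeroTokenizer(SDTokenizer):
    def __init__(self, embedding_directory=None):
        # Pad sequences with an explicit token id (0) instead of the end token;
        # pad_token overrides the pad_with_end fallback in SDTokenizer.__init__.
        super().__init__(embedding_directory=embedding_directory,
                         pad_with_end=False, pad_token=0)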