Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 11:23:29 +00:00)
Allow specifying the padding token for the tokenizer.
This commit is contained in:
parent ce649d61c0
commit 80c4590998
@@ -364,7 +364,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
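The new pad_token=None keyword in the signature above lets a caller choose the padding id explicitly instead of deriving it from pad_with_end. A minimal usage sketch, assuming ComfyUI's comfy.sd1_clip module is importable; the subclass name and the pad id of 0 are illustrative assumptions, not part of this commit:

# Sketch only: a tokenizer subclass whose model pads with id 0 rather than
# reusing the end-of-text token (the class name is hypothetical).
from comfy.sd1_clip import SDTokenizer

class ZeroPaddedTokenizer(SDTokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__(
            pad_with_end=False,   # do not fall back to the end token for padding
            pad_token=0,          # new keyword added by this commit
            embedding_directory=embedding_directory,
        )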
@@ -380,6 +380,14 @@ class SDTokenizer:
             self.tokens_start = 0
             self.start_token = None
             self.end_token = empty[0]
+
+        if pad_token is not None:
+            self.pad_token = pad_token
+        elif pad_with_end:
+            self.pad_token = self.end_token
+        else:
+            self.pad_token = 0
+
         self.pad_with_end = pad_with_end
         self.pad_to_max_length = pad_to_max_length
 
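The block added above fixes the precedence: an explicit pad_token wins, otherwise pad_with_end reuses the end token, otherwise id 0 is used. A standalone sketch of that resolution order, with a helper name and a default end-token id of my own choosing (49407 is CLIP's usual end-of-text id, used here only for illustration):

def resolve_pad_token(pad_token=None, pad_with_end=True, end_token=49407):
    # Mirrors the constructor logic added in the hunk above.
    if pad_token is not None:
        return pad_token
    if pad_with_end:
        return end_token
    return 0

assert resolve_pad_token() == 49407                # default: pad with the end token
assert resolve_pad_token(pad_with_end=False) == 0  # previous fallback behaviour
assert resolve_pad_token(pad_token=11) == 11       # explicit pad_token takes priority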
@@ -412,10 +420,6 @@ class SDTokenizer:
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
-        if self.pad_with_end:
-            pad_token = self.end_token
-        else:
-            pad_token = 0
 
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
@@ -467,7 +471,7 @@ class SDTokenizer:
                     else:
                         batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
-                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -480,9 +484,9 @@ class SDTokenizer:
         #fill last batch
         batch.append((self.end_token, 1.0, 0))
         if self.pad_to_max_length:
-            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
-            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
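With both call sites above reading self.pad_token, a short prompt is padded with whatever id the constructor resolved. A standalone sketch of the padding step on (token, weight, word_id) triples; the function name, token ids, and lengths are hypothetical:

def pad_batch(batch, pad_token, max_length=77, pad_to_max_length=True, min_length=None):
    # Pad a list of (token, weight, word_id) triples the same way the
    # tokenizer pads its final batch in the hunk above.
    batch = list(batch)
    if pad_to_max_length:
        batch.extend([(pad_token, 1.0, 0)] * (max_length - len(batch)))
    if min_length is not None and len(batch) < min_length:
        batch.extend([(pad_token, 1.0, 0)] * (min_length - len(batch)))
    return batch

short = [(320, 1.0, 1), (2368, 1.0, 2)]  # two illustrative token ids
print(pad_batch(short, pad_token=0, max_length=5))
# [(320, 1.0, 1), (2368, 1.0, 2), (0, 1.0, 0), (0, 1.0, 0), (0, 1.0, 0)]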