mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Make SDTokenizer class work with more types of tokenizers.
This commit is contained in:
parent
a6c83b3cd0
commit
b3fcd64c6c
@ -343,17 +343,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
return embed_out
|
return embed_out
|
||||||
|
|
||||||
class SDTokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.max_tokens_per_section = self.max_length - 2
|
|
||||||
|
|
||||||
empty = self.tokenizer('')["input_ids"]
|
empty = self.tokenizer('')["input_ids"]
|
||||||
|
if has_start_token:
|
||||||
|
self.tokens_start = 1
|
||||||
self.start_token = empty[0]
|
self.start_token = empty[0]
|
||||||
self.end_token = empty[1]
|
self.end_token = empty[1]
|
||||||
|
else:
|
||||||
|
self.tokens_start = 0
|
||||||
|
self.start_token = None
|
||||||
|
self.end_token = empty[0]
|
||||||
self.pad_with_end = pad_with_end
|
self.pad_with_end = pad_with_end
|
||||||
|
self.pad_to_max_length = pad_to_max_length
|
||||||
|
|
||||||
vocab = self.tokenizer.get_vocab()
|
vocab = self.tokenizer.get_vocab()
|
||||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
self.embedding_directory = embedding_directory
|
self.embedding_directory = embedding_directory
|
||||||
@ -414,11 +421,13 @@ class SDTokenizer:
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
#parse word
|
#parse word
|
||||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
|
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
|
||||||
|
|
||||||
#reshape token array to CLIP input size
|
#reshape token array to CLIP input size
|
||||||
batched_tokens = []
|
batched_tokens = []
|
||||||
batch = [(self.start_token, 1.0, 0)]
|
batch = []
|
||||||
|
if self.start_token is not None:
|
||||||
|
batch.append((self.start_token, 1.0, 0))
|
||||||
batched_tokens.append(batch)
|
batched_tokens.append(batch)
|
||||||
for i, t_group in enumerate(tokens):
|
for i, t_group in enumerate(tokens):
|
||||||
#determine if we're going to try and keep the tokens in a single batch
|
#determine if we're going to try and keep the tokens in a single batch
|
||||||
@ -435,16 +444,21 @@ class SDTokenizer:
|
|||||||
#add end token and pad
|
#add end token and pad
|
||||||
else:
|
else:
|
||||||
batch.append((self.end_token, 1.0, 0))
|
batch.append((self.end_token, 1.0, 0))
|
||||||
|
if self.pad_to_max_length:
|
||||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||||
#start new batch
|
#start new batch
|
||||||
batch = [(self.start_token, 1.0, 0)]
|
batch = []
|
||||||
|
if self.start_token is not None:
|
||||||
|
batch.append((self.start_token, 1.0, 0))
|
||||||
batched_tokens.append(batch)
|
batched_tokens.append(batch)
|
||||||
else:
|
else:
|
||||||
batch.extend([(t,w,i+1) for t,w in t_group])
|
batch.extend([(t,w,i+1) for t,w in t_group])
|
||||||
t_group = []
|
t_group = []
|
||||||
|
|
||||||
#fill last batch
|
#fill last batch
|
||||||
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
|
batch.append((self.end_token, 1.0, 0))
|
||||||
|
if self.pad_to_max_length:
|
||||||
|
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||||
|
|
||||||
if not return_word_ids:
|
if not return_word_ids:
|
||||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||||
|
Loading…
Reference in New Issue
Block a user