mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
add unique ID per word/embedding for tokenizer
This commit is contained in:
parent
f5f701388b
commit
8489cba140
@ -224,60 +224,85 @@ class SD1Tokenizer:
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.embedding_directory = embedding_directory
|
||||
self.max_word_length = 8
|
||||
self.embedding_identifier = "embedding:"
|
||||
|
||||
def tokenize_with_weights(self, text):
|
||||
def _try_get_embedding(self, name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
embedding_name = name[len(self.embedding_identifier):].strip('\n')
|
||||
embed = load_embed(embedding_name, self.embedding_directory)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory)
|
||||
return (embed, embedding_name[len(stripped):])
|
||||
return (embed, "")
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str):
|
||||
'''
|
||||
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
||||
Tokens can both be integer tokens and pre computed CLIP tensors.
|
||||
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
||||
Returned list has the dimensions NxM where M is the input size of CLIP
|
||||
'''
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
#tokenize words
|
||||
tokens = []
|
||||
for t in parsed_weights:
|
||||
to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
|
||||
while len(to_tokenize) > 0:
|
||||
word = to_tokenize.pop(0)
|
||||
temp_tokens = []
|
||||
embedding_identifier = "embedding:"
|
||||
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(embedding_identifier):].strip('\n')
|
||||
embed = load_embed(embedding_name, self.embedding_directory)
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
#if we find an embedding, deal with the embedding
|
||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||
embed, leftover = self._try_get_embedding(word)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory)
|
||||
if embed is not None:
|
||||
to_tokenize.insert(0, embedding_name[len(stripped):])
|
||||
|
||||
if embed is not None:
|
||||
if len(embed.shape) == 1:
|
||||
temp_tokens += [(embed, t[1])]
|
||||
else:
|
||||
for x in range(embed.shape[0]):
|
||||
temp_tokens += [(embed[x], t[1])]
|
||||
print(f"warning, embedding:{word} does not exist, ignoring")
|
||||
else:
|
||||
print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
|
||||
elif len(word) > 0:
|
||||
tt = self.tokenizer(word)["input_ids"][1:-1]
|
||||
for x in tt:
|
||||
temp_tokens += [(x, t[1])]
|
||||
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
|
||||
if len(embed.shape) == 1:
|
||||
tokens.append([(embed, weight)])
|
||||
else:
|
||||
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
|
||||
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
|
||||
if leftover != "":
|
||||
word = leftover
|
||||
else:
|
||||
continue
|
||||
#parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
|
||||
|
||||
#reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = []
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
#start a new batch if there is not enough room
|
||||
if len(t_group) + len(batch) > self.max_tokens_per_section:
|
||||
remaining_length = self.max_tokens_per_section - len(batch)
|
||||
#fill remaining space depending on length of tokens
|
||||
if len(t_group) > self.max_word_length:
|
||||
#put part of group of tokens in the batch
|
||||
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
||||
t_group = t_group[remaining_length:]
|
||||
else:
|
||||
#filler tokens
|
||||
batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
|
||||
batch = []
|
||||
batched_tokens.append(batch)
|
||||
#put current group of tokens in the batch
|
||||
batch.extend([(t,w,i+1) for t,w in t_group])
|
||||
|
||||
#fill last batch
|
||||
batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
|
||||
|
||||
#add start and end tokens
|
||||
batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
|
||||
return batched_tokens
|
||||
|
||||
#try not to split words in different sections
|
||||
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
|
||||
for x in range(tokens_left):
|
||||
tokens += [(self.end_token, 1.0)]
|
||||
tokens += temp_tokens
|
||||
|
||||
out_tokens = []
|
||||
for x in range(0, len(tokens), self.max_tokens_per_section):
|
||||
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
|
||||
o_token += [(self.end_token, 1.0)]
|
||||
if self.pad_with_end:
|
||||
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
||||
else:
|
||||
o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
|
||||
|
||||
out_tokens += [o_token]
|
||||
|
||||
return out_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
Loading…
Reference in New Issue
Block a user