align behavior with old tokenize function

BlenderNeko 2023-04-14 21:02:45 +02:00
parent 44fe868b66
commit 752f7a162b


@@ -226,12 +226,11 @@ class SD1Tokenizer:
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
 
-    def _try_get_embedding(self, name:str):
+    def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embedding_name = name[len(self.embedding_identifier):].strip('\n')
         embed = load_embed(embedding_name, self.embedding_directory)
         if embed is None:
             stripped = embedding_name.strip(',')
@@ -259,9 +258,10 @@ class SD1Tokenizer:
             for word in to_tokenize:
                 #if we find an embedding, deal with the embedding
                 if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
-                    embed, leftover = self._try_get_embedding(word)
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        print(f"warning, embedding:{word} does not exist, ignoring")
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                     else:
                         if len(embed.shape) == 1:
                             tokens.append([(embed, weight)])
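
Taken together, the two hunks above move the prefix handling out of _try_get_embedding: the helper now receives the bare embedding name, and the caller strips the "embedding:" identifier (and any trailing newline) before the lookup, so the warning message is also built from the stripped name. A minimal standalone sketch of that calling convention, with a hypothetical EMBEDDINGS dict standing in for load_embed and the embedding directory (both are assumptions for illustration, not the real code):

# Sketch only: EMBEDDINGS replaces load_embed / the embedding directory.
EMBEDDING_IDENTIFIER = "embedding:"
EMBEDDINGS = {"my_style": [0.1, 0.2, 0.3]}  # hypothetical stored embeddings

def try_get_embedding(embedding_name: str):
    """Return (embedding, leftover); embedding may be None."""
    embed = EMBEDDINGS.get(embedding_name)
    if embed is None:
        # fallback: retry with a trailing comma stripped off
        stripped = embedding_name.strip(',')
        if len(stripped) < len(embedding_name):
            embed = EMBEDDINGS.get(stripped)
            return embed, embedding_name[len(stripped):]
    return embed, ""

for word in ["embedding:my_style", "embedding:missing", "photo"]:
    if word.startswith(EMBEDDING_IDENTIFIER):
        # caller strips the identifier; helper only ever sees the bare name
        embedding_name = word[len(EMBEDDING_IDENTIFIER):].strip('\n')
        embed, leftover = try_get_embedding(embedding_name)
        if embed is None:
            print(f"warning, embedding:{embedding_name} does not exist, ignoring")
        else:
            print(f"found {embedding_name!r}, leftover={leftover!r}")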
@@ -280,21 +280,21 @@ class SD1Tokenizer:
         batch = []
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
-            #start a new batch if there is not enough room
-            if len(t_group) + len(batch) > self.max_tokens_per_section:
-                remaining_length = self.max_tokens_per_section - len(batch)
-                #fill remaining space depending on length of tokens
-                if len(t_group) > self.max_word_length:
-                    #put part of group of tokens in the batch
-                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                    t_group = t_group[remaining_length:]
-                else:
-                    #filler tokens
-                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                batch = []
-                batched_tokens.append(batch)
-            #put current group of tokens in the batch
-            batch.extend([(t,w,i+1) for t,w in t_group])
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_tokens_per_section:
+                    remaining_length = self.max_tokens_per_section - len(batch)
+                    if is_large:
+                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        t_group = t_group[remaining_length:]
+                    else:
+                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                    batch = []
+                    batched_tokens.append(batch)
+                else:
+                    batch.extend([(t,w,i+1) for t,w in t_group])
+                    t_group = []
 
         #fill last batch
         batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
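
The last hunk reworks the sectioning so each token group is placed by a while loop instead of a single overflow check: whether a group may be split across sections is decided up front (is_large, now compared with >= rather than > against max_word_length), large groups keep being split until every token has a home, and a small group that does not fit pads out the current section with end tokens and is retried in the next one. The previous version only performed the split once before appending whatever remained, so a group longer than two sections could overflow. Below is a standalone sketch of the new loop over plain integers, with MAX_TOKENS_PER_SECTION, MAX_WORD_LENGTH, and END_TOKEN as assumed illustrative values rather than the real model constants, and the (token, weight, word_index) triples simplified to (token, word_index) pairs:

# Sketch only: constants are illustrative, not the real CLIP limits.
MAX_TOKENS_PER_SECTION = 8
MAX_WORD_LENGTH = 4
END_TOKEN = -1

def batch_token_groups(token_groups):
    batched = []
    batch = []
    batched.append(batch)
    for i, t_group in enumerate(token_groups):
        # decide up front whether this group may be split across sections
        is_large = len(t_group) >= MAX_WORD_LENGTH
        while len(t_group) > 0:
            if len(t_group) + len(batch) > MAX_TOKENS_PER_SECTION:
                remaining_length = MAX_TOKENS_PER_SECTION - len(batch)
                if is_large:
                    # split the group: fill this section, carry the rest over
                    batch.extend((t, i + 1) for t in t_group[:remaining_length])
                    t_group = t_group[remaining_length:]
                else:
                    # small group: pad this section, retry the group in the next
                    batch.extend([(END_TOKEN, 0)] * remaining_length)
                batch = []
                batched.append(batch)
            else:
                # the whole (remaining) group fits in the current section
                batch.extend((t, i + 1) for t in t_group)
                t_group = []
    # fill the last section
    batch.extend([(END_TOKEN, 0)] * (MAX_TOKENS_PER_SECTION - len(batch)))
    return batched

print(batch_token_groups([[1, 2, 3], [4, 5, 6, 7, 8], [9]]))

With these values the five-token group fills out the first section alongside the three-token group, and the final single-token group starts a second section that is padded with end tokens.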