fixed improper padding

BlenderNeko 2023-04-15 19:38:21 +02:00
parent da115bd78d
commit d0b1b6c6bf


@@ -247,6 +247,11 @@ class SD1Tokenizer:
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0
+
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
@@ -277,30 +282,33 @@ class SD1Tokenizer:
         #reshape token array to CLIP input size
         batched_tokens = []
-        batch = []
+        batch = [(self.start_token, 1.0, 0)]
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
 
             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_tokens_per_section:
-                    remaining_length = self.max_tokens_per_section - len(batch)
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
                         t_group = t_group[remaining_length:]
+                    #add end token and pad
                     else:
-                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                    batch = []
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
                     batched_tokens.append(batch)
                 else:
                     batch.extend([(t,w,i+1) for t,w in t_group])
                     t_group = []
 
         #fill last batch
-        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
-
-        #add start and end tokens
-        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
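
For reference, here is a minimal, self-contained sketch of the padding scheme the new code ends up with: every CLIP section opens with the start token, is closed with a single end token, and the leftover slots are filled with pad_token, which is either the end token or 0 depending on pad_with_end. The constants, the pad_sections helper, and the demo token ids below are made up for illustration; this is not the actual SD1Tokenizer implementation.

# Sketch of the padding scheme after this change (assumed constants, hypothetical helper).

START_TOKEN = 49406     # CLIP start-of-text id (assumed)
END_TOKEN = 49407       # CLIP end-of-text id (assumed)
MAX_LENGTH = 77         # CLIP input size (assumed)
MAX_WORD_LENGTH = 8     # words at least this long may be split across sections (assumed)

def pad_sections(token_groups, pad_with_end=True):
    """Pack per-word (token, weight) groups into sections of MAX_LENGTH.

    Returns a list of sections, each a list of (token, weight, word_id) tuples,
    with word id 0 reserved for non-word tokens (start/end/pad).
    """
    pad_token = END_TOKEN if pad_with_end else 0

    batched_tokens = []
    batch = [(START_TOKEN, 1.0, 0)]          # every section opens with the start token
    batched_tokens.append(batch)
    for i, t_group in enumerate(token_groups):
        is_large = len(t_group) >= MAX_WORD_LENGTH
        while len(t_group) > 0:
            # reserve one slot for the end token
            if len(t_group) + len(batch) > MAX_LENGTH - 1:
                remaining = MAX_LENGTH - len(batch) - 1
                if is_large:
                    # break the long word in two and close the section with the end token
                    batch.extend([(t, w, i + 1) for t, w in t_group[:remaining]])
                    batch.append((END_TOKEN, 1.0, 0))
                    t_group = t_group[remaining:]
                else:
                    # close the section with the end token and pad the leftover slots
                    batch.append((END_TOKEN, 1.0, 0))
                    batch.extend([(pad_token, 1.0, 0)] * remaining)
                # start a new section; any unconsumed tokens go there on the next pass
                batch = [(START_TOKEN, 1.0, 0)]
                batched_tokens.append(batch)
            else:
                batch.extend([(t, w, i + 1) for t, w in t_group])
                t_group = []
    # fill the last section: one end token, then pad tokens up to MAX_LENGTH
    batch.extend([(END_TOKEN, 1.0, 0)] + [(pad_token, 1.0, 0)] * (MAX_LENGTH - len(batch) - 1))
    return batched_tokens

# tiny demo with made-up token ids: two "words", the second weighted 1.2
sections = pad_sections([[(320, 1.0)], [(1125, 1.2), (539, 1.2)]])
assert all(len(s) == MAX_LENGTH for s in sections)
print(len(sections), [t for t, _, _ in sections[0][:6]])  # 1 [49406, 320, 1125, 539, 49407, 49407]

Compared to the removed lines, each section is built directly to MAX_LENGTH with the start token up front, one end token, and pad tokens for the tail, instead of padding with repeated end tokens and wrapping every section with start/end afterwards.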