Properly tokenize the template for HunyuanVideo.

This commit is contained in:
comfyanonymous 2024-12-17 16:18:35 -05:00
parent cd6f615038
commit ca457f7ba1
2 changed files with 17 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ import comfy.clip_model
import json import json
import logging import logging
import numbers import numbers
import re
def gen_empty_tokens(special_tokens, length): def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None) start_token = special_tokens.get("start", None)
@@ -429,13 +430,14 @@ class SDTokenizer:
self.end_token = None self.end_token = None
empty = self.tokenizer('')["input_ids"] empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
if has_start_token: if has_start_token:
self.tokens_start = 1 self.tokens_start = 1
self.start_token = empty[0] self.start_token = empty[0]
if has_end_token: if end_token is not None:
if end_token is not None: self.end_token = end_token
self.end_token = end_token else:
else: if has_end_token:
self.end_token = empty[1] self.end_token = empty[1]
else: else:
self.tokens_start = 0 self.tokens_start = 0
@@ -468,7 +470,7 @@ class SDTokenizer:
Takes a potential embedding name and tries to retrieve it. Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
''' '''
split_embed = embedding_name.split(' ') split_embed = embedding_name.split()
embedding_name = split_embed[0] embedding_name = split_embed[0]
leftover = ' '.join(split_embed[1:]) leftover = ' '.join(split_embed[1:])
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key) embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
@@ -491,18 +493,18 @@ class SDTokenizer:
text = escape_important(text) text = escape_important(text)
parsed_weights = token_weights(text, 1.0) parsed_weights = token_weights(text, 1.0)
#tokenize words # tokenize words
tokens = [] tokens = []
for weighted_segment, weight in parsed_weights: for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ") to_tokenize = unescape_important(weighted_segment)
split = to_tokenize.split(' {}'.format(self.embedding_identifier)) split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
to_tokenize = [split[0]] to_tokenize = [split[0]]
for i in range(1, len(split)): for i in range(1, len(split)):
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i])) to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
to_tokenize = [x for x in to_tokenize if x != ""] to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize: for word in to_tokenize:
#if we find an embedding, deal with the embedding # if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n') embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name) embed, leftover = self._try_get_embedding(embedding_name)
@@ -519,7 +521,7 @@ class SDTokenizer:
else: else:
continue continue
end = 999999999999 end = 999999999999
if self.end_token is not None: if self.tokenizer_adds_end_token:
end = -1 end = -1
#parse word #parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]]) tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])

View File

@@ -22,7 +22,7 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length) super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel): class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
@@ -38,9 +38,7 @@ class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|> self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
Describe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>""" # 93 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -86,6 +84,9 @@ class HunyuanVideoClipModel(torch.nn.Module):
if v[0] == 128007: # <|end_header_id|> if v[0] == 128007: # <|end_header_id|>
template_end = i template_end = i
if llama_out.shape[1] > (template_end + 2):
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
template_end += 2
llama_out = llama_out[:, template_end:] llama_out = llama_out[:, template_end:]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:] llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):