mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Properly tokenize the template for hunyuan video.
This commit is contained in:
parent
cd6f615038
commit
ca457f7ba1
@ -10,6 +10,7 @@ import comfy.clip_model
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import numbers
|
import numbers
|
||||||
|
import re
|
||||||
|
|
||||||
def gen_empty_tokens(special_tokens, length):
|
def gen_empty_tokens(special_tokens, length):
|
||||||
start_token = special_tokens.get("start", None)
|
start_token = special_tokens.get("start", None)
|
||||||
@ -429,13 +430,14 @@ class SDTokenizer:
|
|||||||
self.end_token = None
|
self.end_token = None
|
||||||
|
|
||||||
empty = self.tokenizer('')["input_ids"]
|
empty = self.tokenizer('')["input_ids"]
|
||||||
|
self.tokenizer_adds_end_token = has_end_token
|
||||||
if has_start_token:
|
if has_start_token:
|
||||||
self.tokens_start = 1
|
self.tokens_start = 1
|
||||||
self.start_token = empty[0]
|
self.start_token = empty[0]
|
||||||
if has_end_token:
|
if end_token is not None:
|
||||||
if end_token is not None:
|
self.end_token = end_token
|
||||||
self.end_token = end_token
|
else:
|
||||||
else:
|
if has_end_token:
|
||||||
self.end_token = empty[1]
|
self.end_token = empty[1]
|
||||||
else:
|
else:
|
||||||
self.tokens_start = 0
|
self.tokens_start = 0
|
||||||
@ -468,7 +470,7 @@ class SDTokenizer:
|
|||||||
Takes a potential embedding name and tries to retrieve it.
|
Takes a potential embedding name and tries to retrieve it.
|
||||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||||
'''
|
'''
|
||||||
split_embed = embedding_name.split(' ')
|
split_embed = embedding_name.split()
|
||||||
embedding_name = split_embed[0]
|
embedding_name = split_embed[0]
|
||||||
leftover = ' '.join(split_embed[1:])
|
leftover = ' '.join(split_embed[1:])
|
||||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||||
@ -491,18 +493,18 @@ class SDTokenizer:
|
|||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
parsed_weights = token_weights(text, 1.0)
|
parsed_weights = token_weights(text, 1.0)
|
||||||
|
|
||||||
#tokenize words
|
# tokenize words
|
||||||
tokens = []
|
tokens = []
|
||||||
for weighted_segment, weight in parsed_weights:
|
for weighted_segment, weight in parsed_weights:
|
||||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
|
to_tokenize = unescape_important(weighted_segment)
|
||||||
split = to_tokenize.split(' {}'.format(self.embedding_identifier))
|
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
||||||
to_tokenize = [split[0]]
|
to_tokenize = [split[0]]
|
||||||
for i in range(1, len(split)):
|
for i in range(1, len(split)):
|
||||||
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||||
|
|
||||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||||
for word in to_tokenize:
|
for word in to_tokenize:
|
||||||
#if we find an embedding, deal with the embedding
|
# if we find an embedding, deal with the embedding
|
||||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||||
embed, leftover = self._try_get_embedding(embedding_name)
|
embed, leftover = self._try_get_embedding(embedding_name)
|
||||||
@ -519,7 +521,7 @@ class SDTokenizer:
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
end = 999999999999
|
end = 999999999999
|
||||||
if self.end_token is not None:
|
if self.tokenizer_adds_end_token:
|
||||||
end = -1
|
end = -1
|
||||||
#parse word
|
#parse word
|
||||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
|
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
|
||||||
|
@ -22,7 +22,7 @@ def llama_detect(state_dict, prefix=""):
|
|||||||
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
||||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
|
||||||
|
|
||||||
class LLAMAModel(sd1_clip.SDClipModel):
|
class LLAMAModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||||
@ -38,9 +38,7 @@ class HunyuanVideoTokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>
|
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
|
||||||
|
|
||||||
Describe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>""" # 93 tokens
|
|
||||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@ -86,6 +84,9 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
|||||||
if v[0] == 128007: # <|end_header_id|>
|
if v[0] == 128007: # <|end_header_id|>
|
||||||
template_end = i
|
template_end = i
|
||||||
|
|
||||||
|
if llama_out.shape[1] > (template_end + 2):
|
||||||
|
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
|
||||||
|
template_end += 2
|
||||||
llama_out = llama_out[:, template_end:]
|
llama_out = llama_out[:, template_end:]
|
||||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
|
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
|
||||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||||
|
Loading…
Reference in New Issue
Block a user