From 3e8155f7a3d7601838bbc82a8ccf550343bbb132 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 15 Apr 2025 10:32:21 -0400
Subject: [PATCH] More flexible long clip support.

Add clip g long clip support.
Text encoder refactor.
Support llama models with different vocab sizes.
---
 comfy/sd1_clip.py                    | 23 ++++++++++++++---
 comfy/sdxl_clip.py                   | 14 +++++------
 comfy/text_encoders/aura_t5.py       |  2 +-
 comfy/text_encoders/cosmos.py        |  2 +-
 comfy/text_encoders/flux.py          | 10 +++-----
 comfy/text_encoders/genmo.py         |  2 +-
 comfy/text_encoders/hunyuan_video.py | 22 ++++++++++-------
 comfy/text_encoders/hydit.py         |  8 +++---
 comfy/text_encoders/llama.py         | 14 ++++++++++-
 comfy/text_encoders/long_clipl.py    | 37 ++++++++++++++--------------
 comfy/text_encoders/lt.py            |  2 +-
 comfy/text_encoders/lumina2.py       |  2 +-
 comfy/text_encoders/pixart_t5.py     |  2 +-
 comfy/text_encoders/sa_t5.py         |  2 +-
 comfy/text_encoders/sd2_clip.py      |  2 +-
 comfy/text_encoders/sd3_clip.py      | 15 ++++++-----
 comfy/text_encoders/wan.py           |  2 +-
 17 files changed, 95 insertions(+), 66 deletions(-)
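
The heart of this refactor: the old class-swap hooks ("clip_l_tokenizer_class" in tokenizer_data, "clip_l_class" in model_options) are removed in favor of two data-driven override channels keyed by the encoder's name, so any text encoder can use them. A minimal sketch of the two channels for a long CLIP-L with 248 positions (illustrative values, not part of the diff):

    # "{model_name}_model_config" is merged into the encoder's JSON config by
    # SDClipModel.__init__; "{embedding_key}_max_length" is read by SDTokenizer.
    model_options = {"clip_l_model_config": {"max_position_embeddings": 248}}
    tokenizer_data = {"clip_l_max_length": 248}
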
os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) - self.max_length = max_length + self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) self.min_length = min_length self.end_token = None @@ -645,6 +659,7 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) clip_model = model_options.get("{}_class".format(self.clip), clip_model) + model_options = {**model_options, "model_name": self.clip} setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) self.dtypes = set() diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 5b7c8a41..ea7f5d10 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options) @@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class SDXLTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -41,8 +41,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__() - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = set([dtype]) @@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=True, 
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 5b7c8a41..ea7f5d10 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        model_options = {**model_options, "model_name": "clip_g"}
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)

@@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel):

 class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)


 class SDXLTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
@@ -41,8 +41,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
         self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set([dtype])

@@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):

 class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
-        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)

 class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
 class StableCascadeClipG(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        model_options = {**model_options, "model_name": "clip_g"}
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py
index e9ad45a7..cf4252ee 100644
--- a/comfy/text_encoders/aura_t5.py
+++ b/comfy/text_encoders/aura_t5.py
@@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
 class PT5XlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)

 class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py
index 5441c895..a1adb524 100644
--- a/comfy/text_encoders/cosmos.py
+++ b/comfy/text_encoders/cosmos.py
@@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)


 class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
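
The tokenizer changes above and in the files that follow all do the same thing: thread tokenizer_data through to sd1_clip.SDTokenizer so the base class can resolve a per-encoder length override. The new resolution logic, restated as a self-contained sketch:

    # Simplified restatement of the max_length lookup added to SDTokenizer.__init__.
    def resolve_max_length(tokenizer_data, embedding_key, default):
        return tokenizer_data.get("{}_max_length".format(embedding_key), default)

    assert resolve_max_length({"clip_g_max_length": 248}, "clip_g", 77) == 248
    assert resolve_max_length({}, "clip_g", 77) == 77
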
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index a12995ec..0666dde7 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -9,14 +9,13 @@ import os
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)


 class FluxTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
@@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
         self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])
diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py
index 45987a48..9dcf190a 100644
--- a/comfy/text_encoders/genmo.py
+++ b/comfy/text_encoders/genmo.py
@@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)


 class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py
index dbb259e5..33ac2249 100644
--- a/comfy/text_encoders/hunyuan_video.py
+++ b/comfy/text_encoders/hunyuan_video.py
@@ -21,26 +21,31 @@ def llama_detect(state_dict, prefix=""):


 class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
+    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)

 class LLAMAModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
         llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
         if llama_scaled_fp8 is not None:
             model_options = model_options.copy()
             model_options["scaled_fp8"] = llama_scaled_fp8

-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+        textmodel_json_config = {}
+        vocab_size = model_options.get("vocab_size", None)
+        if vocab_size is not None:
+            textmodel_json_config["vocab_size"] = vocab_size
+
+        model_options = {**model_options, "model_name": "llama"}
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


 class HunyuanVideoTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
-        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
+        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
@@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module):
     def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
-        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
         self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
         self.dtypes = set([dtype, dtype_llama])
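
LLAMA3Tokenizer and LLAMAModel are now parameterized (pad_token, special_tokens, plus an optional vocab_size read from model_options), so llama variants with extended vocabularies can reuse these classes instead of duplicating them. A hypothetical wiring, with an illustrative vocab size that does not ship in this patch:

    model_options = {"vocab_size": 128320}                # hypothetical variant
    textmodel_json_config = {}
    vocab_size = model_options.get("vocab_size", None)
    if vocab_size is not None:
        textmodel_json_config["vocab_size"] = vocab_size  # overrides the Llama2 default
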
diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py
index 7da3e9fc..e7273f42 100644
--- a/comfy/text_encoders/hydit.py
+++ b/comfy/text_encoders/hydit.py
@@ -9,24 +9,26 @@ import torch
 class HyditBertModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
+        model_options = {**model_options, "model_name": "hydit_clip"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)

 class HyditBertTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)

 class MT5XLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
+        model_options = {**model_options, "model_name": "mt5xl"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)

 class MT5XLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -35,7 +37,7 @@ class HyditTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
         self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
-        self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
+        self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
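
Note the HyditTokenizer change at the end: the mt5xl sub-tokenizer previously received only the serialized spiece model, but now gets the full tokenizer_data merged in, so a per-encoder override such as an "mt5xl_max_length" entry would reach it too. A sketch of the merge (byte string and length are illustrative):

    tokenizer_data = {"mt5xl.spiece_model": b"...", "mt5xl_max_length": 512}
    mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
    merged = {**tokenizer_data, "spiece_model": mt5_tokenizer_data}
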
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 58710b2b..34eb870e 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -268,11 +268,17 @@ class Llama2_(nn.Module):
         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

         intermediate = None
+        all_intermediate = None
         if intermediate_output is not None:
-            if intermediate_output < 0:
+            if intermediate_output == "all":
+                all_intermediate = []
+                intermediate_output = None
+            elif intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output

         for i, layer in enumerate(self.layers):
+            if all_intermediate is not None:
+                all_intermediate.append(x.unsqueeze(1).clone())
             x = layer(
                 x=x,
                 attention_mask=mask,
@@ -283,6 +289,12 @@ class Llama2_(nn.Module):
                 intermediate = x.clone()

         x = self.norm(x)
+        if all_intermediate is not None:
+            all_intermediate.append(x.unsqueeze(1).clone())
+
+        if all_intermediate is not None:
+            intermediate = torch.cat(all_intermediate, dim=1)
+
         if intermediate is not None and final_layer_norm_intermediate:
             intermediate = self.norm(intermediate)
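
Shape bookkeeping for the "all" path above: the hidden state entering each layer is unsqueezed to [batch, 1, seq, hidden], the final normed output is appended, and everything is concatenated along dim=1. A self-contained sketch with illustrative sizes:

    import torch

    B, T, C, num_layers = 2, 77, 16, 4
    states = [torch.randn(B, T, C).unsqueeze(1) for _ in range(num_layers + 1)]
    stacked = torch.cat(states, dim=1)
    assert stacked.shape == (B, num_layers + 1, T, C)
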
sd.get("clip_l.text_model.embeddings.position_embedding.weight", None) + if w is None: + w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None) + else: + model_name = "clip_g" + if w is None: w = sd.get("text_model.embeddings.position_embedding.weight", None) - if w is not None and w.shape[0] == 248: + if w is not None: + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + model_name = "clip_g" + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + model_name = "clip_l" + else: + model_name = "clip_l" + + if w is not None: tokenizer_data = tokenizer_data.copy() model_options = model_options.copy() - tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_ - model_options["clip_l_class"] = LongClipModel_ + model_config = model_options.get("model_config", {}) + model_config["max_position_embeddings"] = w.shape[0] + model_options["{}_model_config".format(model_name)] = model_config + tokenizer_data["{}_max_length".format(model_name)] = w.shape[0] return tokenizer_data, model_options diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5c2ce583..48ea67e6 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,7 +6,7 @@ import comfy.text_encoders.genmo class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128? + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128? 
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 5c2ce583..48ea67e6 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?


 class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index a7b1d702..674461b7 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -6,7 +6,7 @@ import comfy.text_encoders.llama
 class Gemma2BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py
index d56d57f1..b8de6bc4 100644
--- a/comfy/text_encoders/pixart_t5.py
+++ b/comfy/text_encoders/pixart_t5.py
@@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding

 class PixArtTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py
index 7778ce47..2803926a 100644
--- a/comfy/text_encoders/sa_t5.py
+++ b/comfy/text_encoders/sa_t5.py
@@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
 class T5BaseTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)

 class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py
index 31fc8986..700a23bf 100644
--- a/comfy/text_encoders/sd2_clip.py
+++ b/comfy/text_encoders/sd2_clip.py
@@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):

 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)

 class SD2Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index 3ad2ed93..1727998a 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
             model_options = model_options.copy()
             model_options["scaled_fp8"] = t5xxl_scaled_fp8

+        model_options = {**model_options, "model_name": "t5xxl"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

@@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""):
     return out

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)


 class SD3Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
-        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
-        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
@@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module):
         super().__init__()
         self.dtypes = set()
         if clip_l:
-            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
-            self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None
diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py
index 971ac8fa..d50fa4b2 100644
--- a/comfy/text_encoders/wan.py
+++ b/comfy/text_encoders/wan.py
@@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
 class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
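
End to end, the override keys produced by model_options_long_clip() are honored by the stock classes, with no subclass swapping left anywhere. A hedged usage sketch (assumes a working ComfyUI checkout and its bundled tokenizer files; not part of this diff):

    from comfy import sdxl_clip

    tokenizer_data = {"clip_l_max_length": 248, "clip_g_max_length": 248}
    tok = sdxl_clip.SDXLTokenizer(tokenizer_data=tokenizer_data)
    assert tok.clip_l.max_length == 248
    assert tok.clip_g.max_length == 248
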