mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Support llama hunyuan video text encoder in scaled fp8 format.
This commit is contained in:
parent
f4cdedea62
commit
d6656b0c0c
10
comfy/sd.py
10
comfy/sd.py
@ -603,6 +603,14 @@ def t5xxl_detect(clip_data):
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def llama_detect(clip_data):
|
||||||
|
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
||||||
|
|
||||||
|
for sd in clip_data:
|
||||||
|
if weight_name in sd:
|
||||||
|
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
@ -669,7 +677,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||||
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip() #TODO
|
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
|
@ -783,9 +783,9 @@ class HunyuanVideo(supported_models_base.BASE):
|
|||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
# pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
# t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip()) #TODO
|
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
|
||||||
|
|
||||||
|
@ -6,6 +6,19 @@ import torch
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def llama_detect(state_dict, prefix=""):
|
||||||
|
out = {}
|
||||||
|
t5_key = "{}model.norm.weight".format(prefix)
|
||||||
|
if t5_key in state_dict:
|
||||||
|
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||||
|
|
||||||
|
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||||
|
if scaled_fp8_key in state_dict:
|
||||||
|
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
|
||||||
|
Loading…
Reference in New Issue
Block a user