Mirror of https://github.com/comfyanonymous/ComfyUI.git
Refactor some of the TE detection code.

commit d1cdf51e1b
parent b4626ab93e

comfy/sd.py (39 lines changed)
@@ -406,6 +406,32 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
 
+
+class TEModel(Enum):
+    CLIP_L = 1
+    CLIP_H = 2
+    CLIP_G = 3
+    T5_XXL = 4
+    T5_XL = 5
+    T5_BASE = 6
+
+def detect_te_model(sd):
+    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+        return TEModel.CLIP_G
+    if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
+        return TEModel.CLIP_H
+    if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
+        return TEModel.CLIP_L
+    if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
+        weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
+        if weight.shape[-1] == 4096:
+            return TEModel.T5_XXL
+        elif weight.shape[-1] == 2048:
+            return TEModel.T5_XL
+    if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
+        return TEModel.T5_BASE
+    return None
+
 def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = state_dicts
 
     class EmptyClass:
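For orientation, a minimal sketch (not part of the commit) of how the new helper behaves: detect_te_model looks only at state-dict key names, plus one weight shape to tell T5-XXL from T5-XL, so it can be exercised with dummy tensors. The check order matters, since a CLIP-G state dict also contains the layers.22 and layers.0 keys, so the deepest layer is tested first. The tensor shapes below are illustrative assumptions, and the sketch assumes a ComfyUI checkout at or after this commit is importable.

# Sketch only: exercising detect_te_model with dummy state dicts.
# Assumes comfy.sd from this commit is importable; tensor shapes are illustrative.
import torch
from comfy.sd import TEModel, detect_te_model

# CLIP variants are identified purely by which encoder layers appear in the keys.
clip_l_sd = {"text_model.encoder.layers.0.mlp.fc1.weight": torch.zeros(3072, 768)}
assert detect_te_model(clip_l_sd) == TEModel.CLIP_L

# T5 checkpoints share key names, so the last dim of wi_1 disambiguates XXL vs XL.
t5_xxl_sd = {"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": torch.zeros(10240, 4096)}
t5_xl_sd = {"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": torch.zeros(5120, 2048)}
assert detect_te_model(t5_xxl_sd) == TEModel.T5_XXL
assert detect_te_model(t5_xl_sd) == TEModel.T5_XL

# Unrecognized state dicts return None; the caller then falls back to the SD1 CLIP path.
assert detect_te_model({}) is None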
@@ -421,30 +447,29 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
     clip_target = EmptyClass()
     clip_target.params = {}
     if len(clip_data) == 1:
-        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+        te_model = detect_te_model(clip_data[0])
+        if te_model == TEModel.CLIP_G:
             if clip_type == CLIPType.STABLE_CASCADE:
                 clip_target.clip = sdxl_clip.StableCascadeClipModel
                 clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
             else:
                 clip_target.clip = sdxl_clip.SDXLRefinerClipModel
                 clip_target.tokenizer = sdxl_clip.SDXLTokenizer
-        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+        elif te_model == TEModel.CLIP_H:
             clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
             clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
-        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
+        elif te_model == TEModel.T5_XXL:
             weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
             dtype_t5 = weight.dtype
-            if weight.shape[-1] == 4096:
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
-            elif weight.shape[-1] == 2048:
+        elif te_model == TEModel.T5_XL:
             clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
             clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
-        elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
+        elif te_model == TEModel.T5_BASE:
             clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
         else:
-            w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
             clip_target.clip = sd1_clip.SD1ClipModel
             clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
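And a hedged sketch of the call path the second hunk simplifies: a single text-encoder state dict now flows through one detect_te_model call and an enum comparison per branch, instead of repeated raw key lookups. The file paths below are placeholders, not paths taken from the repository.

# Sketch only: loading one text encoder through the refactored dispatch.
# Paths are placeholders; any single CLIP/T5 checkpoint ComfyUI supports would do.
import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("models/clip/clip_l.safetensors", safe_load=True)  # placeholder path
clip = comfy.sd.load_text_encoder_state_dicts(
    [sd],
    embedding_directory="models/embeddings",  # placeholder path
    clip_type=comfy.sd.CLIPType.STABLE_DIFFUSION,
)
# With a single state dict, load_text_encoder_state_dicts now calls
# detect_te_model(clip_data[0]) once and picks the clip/tokenizer classes by
# comparing the returned TEModel value instead of re-checking raw key names.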