mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add DualClipLoader to load clip models for SDXL.
Update LoadClip to load clip models for SDXL refiner.
This commit is contained in:
parent
b7933960bb
commit
20f579d91d
41
comfy/sd.py
41
comfy/sd.py
@ -19,6 +19,7 @@ from . import model_detection
|
||||
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
|
||||
def load_model_weights(model, sd):
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
@ -524,7 +525,7 @@ class CLIP:
|
||||
return n
|
||||
|
||||
def load_from_state_dict(self, sd):
|
||||
self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
|
||||
self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def add_patches(self, patches, strength=1.0):
|
||||
return self.patcher.add_patches(patches, strength)
|
||||
@ -555,6 +556,8 @@ class CLIP:
|
||||
tokens = self.tokenize(text)
|
||||
return self.encode_from_tokens(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
class VAE:
|
||||
def __init__(self, ckpt_path=None, device=None, config=None):
|
||||
@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
|
||||
return StyleModel(model)
|
||||
|
||||
|
||||
def load_clip(ckpt_path, embedding_directory=None):
|
||||
clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
def load_clip(ckpt_paths, embedding_directory=None):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(utils.load_torch_file(p, safe_load=True))
|
||||
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
for i in range(len(clip_data)):
|
||||
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
|
||||
clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
|
||||
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = {}
|
||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
clip.load_from_state_dict(clip_data)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
print("clip missing:", m)
|
||||
|
||||
if len(u) > 0:
|
||||
print("clip unexpected:", u)
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
|
@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
|
@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_projection" in sd:
|
||||
self.text_projection[:] = sd.pop("text_projection")
|
||||
return super().load_sd(sd)
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
|
||||
@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
return self.clip_g.load_sd(sd)
|
||||
else:
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
class SDXLRefinerClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu"):
|
||||
super().__init__()
|
||||
@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
return g_out, g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.clip_g.load_sd(sd)
|
||||
|
21
nodes.py
21
nodes.py
@ -520,11 +520,27 @@ class CLIPLoader:
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
class DualCLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2):
|
||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
class CLIPVisionLoader:
|
||||
@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LatentCrop": LatentCrop,
|
||||
"LoraLoader": LoraLoader,
|
||||
"CLIPLoader": CLIPLoader,
|
||||
"DualCLIPLoader": DualCLIPLoader,
|
||||
"CLIPVisionEncode": CLIPVisionEncode,
|
||||
"StyleModelApply": StyleModelApply,
|
||||
"unCLIPConditioning": unCLIPConditioning,
|
||||
|
Loading…
Reference in New Issue
Block a user