Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-14 13:17:32 +00:00)
Add DualClipLoader to load clip models for SDXL.
Update LoadClip to load clip models for SDXL refiner.
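
In practice the loader API changes from a single ckpt_path to a list of ckpt_paths. A minimal sketch of both call shapes, assuming placeholder file paths (none of these filenames come from the commit):

    import comfy.sd

    # One path: SD1 / SD2 / SDXL-refiner text encoder, auto-detected
    # from the keys of the loaded state dict.
    clip = comfy.sd.load_clip(ckpt_paths=["refiner_clip_g.safetensors"],
                              embedding_directory="models/embeddings")

    # Two paths: CLIP-L plus CLIP-G for base SDXL, which is what the
    # new DualCLIPLoader node passes.
    clip = comfy.sd.load_clip(ckpt_paths=["clip_l.safetensors", "clip_g.safetensors"],
                              embedding_directory="models/embeddings")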
parent b7933960bb
commit 20f579d91d

comfy/sd.py: 33 lines changed
@@ -19,6 +19,7 @@ from . import model_detection

 from . import sd1_clip
 from . import sd2_clip
+from . import sdxl_clip

 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -524,7 +525,7 @@ class CLIP:
         return n

     def load_from_state_dict(self, sd):
-        self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
+        self.cond_stage_model.load_sd(sd)

     def add_patches(self, patches, strength=1.0):
         return self.patcher.add_patches(patches, strength)
@@ -555,6 +556,8 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)

+    def load_sd(self, sd):
+        return self.cond_stage_model.load_sd(sd)

 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
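
The new CLIP.load_sd delegates to the cond-stage model's load_sd, which (per the sd1_clip.py hunk further down) ends in torch's load_state_dict(strict=False). That call returns a (missing_keys, unexpected_keys) pair, which is what lets the rewritten load_clip below unpack m, u. A toy demonstration in plain torch, unrelated to the ComfyUI classes:

    import torch

    lin = torch.nn.Linear(4, 4)
    m, u = lin.load_state_dict({"weight": torch.zeros(4, 4),
                                "extra": torch.zeros(1)}, strict=False)
    print(m)  # ['bias']  -- keys the module expected but the dict lacked
    print(u)  # ['extra'] -- keys the dict supplied but the module has no use for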
@@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
     return StyleModel(model)


-def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
+def load_clip(ckpt_paths, embedding_directory=None):
+    clip_data = []
+    for p in ckpt_paths:
+        clip_data.append(utils.load_torch_file(p, safe_load=True))

     class EmptyClass:
         pass

+    for i in range(len(clip_data)):
+        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
+            clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+
     clip_target = EmptyClass()
     clip_target.params = {}
-    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-        clip_target.clip = sd2_clip.SD2ClipModel
-        clip_target.tokenizer = sd2_clip.SD2Tokenizer
+    if len(clip_data) == 1:
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sd2_clip.SD2ClipModel
+            clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        else:
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
     else:
-        clip_target.clip = sd1_clip.SD1ClipModel
-        clip_target.tokenizer = sd1_clip.SD1Tokenizer
+        clip_target.clip = sdxl_clip.SDXLClipModel
+        clip_target.tokenizer = sdxl_clip.SDXLTokenizer

     clip = CLIP(clip_target, embedding_directory=embedding_directory)
-    clip.load_from_state_dict(clip_data)
+    for c in clip_data:
+        m, u = clip.load_sd(c)
+        if len(m) > 0:
+            print("clip missing:", m)
+
+        if len(u) > 0:
+            print("clip unexpected:", u)
     return clip

 def load_gligen(ckpt_path):
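
The key probes work because the text towers differ in depth: CLIP ViT-L (SD1) has 12 transformer layers, OpenCLIP ViT-H (SD2) has 24, and CLIP-G/ViT-bigG (SDXL and its refiner) has 32, so a high layer index can only appear in a deeper checkpoint. The "transformer.resblocks." check above catches OpenCLIP-style dumps and converts them to the transformers-style "text_model." layout before probing. The same decision tree as a standalone sketch (guess_clip_arch is an invented name, not part of the commit):

    def guess_clip_arch(sd):
        # layer index 30 only exists in the 32-layer CLIP-G tower
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return "SDXL refiner (CLIP-G)"
        # layer index 22 exists in the 24-layer ViT-H tower but not in ViT-L
        if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
            return "SD2 (OpenCLIP ViT-H)"
        return "SD1 (CLIP ViT-L, 12 layers)"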
comfy/sd1_clip.py

@@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)

+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
 def parse_parentheses(string):
     result = []
     current_item = ""
comfy/sdxl_clip.py

@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
         self.layer = "hidden"
         self.layer_idx = layer_idx

+    def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        return super().load_sd(sd)
+
 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return torch.cat([l_out, g_out], dim=-1), g_pooled

+    def load_sd(self, sd):
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+            return self.clip_g.load_sd(sd)
+        else:
+            return self.clip_l.load_sd(sd)
+
 class SDXLRefinerClipModel(torch.nn.Module):
     def __init__(self, device="cpu"):
         super().__init__()
@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         return g_out, g_pooled

+    def load_sd(self, sd):
+        return self.clip_g.load_sd(sd)
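
SDXLClipG.load_sd pops "text_projection" out of the dict and copies it in place before delegating, since that tensor hangs off the wrapper rather than living inside the transformer submodule that super().load_sd targets. A self-contained toy version of the pattern (TinyClipG and its shapes are invented for illustration):

    import torch

    class TinyClipG(torch.nn.Module):
        def __init__(self, width=4):
            super().__init__()
            self.register_buffer("text_projection", torch.zeros(width, width))
            self.transformer = torch.nn.Linear(width, width)

        def load_sd(self, sd):
            if "text_projection" in sd:
                # in-place copy into the preallocated tensor
                self.text_projection[:] = sd.pop("text_projection")
            return self.transformer.load_state_dict(sd, strict=False)

    sd = {"text_projection": torch.eye(4),
          "weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
    missing, unexpected = TinyClipG().load_sd(sd)
    assert "text_projection" not in sd  # consumed before the strict=False load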
nodes.py: 21 lines changed
@@ -520,11 +520,27 @@ class CLIPLoader:
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"

     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
+class DualCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
+                              }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_clip(self, clip_name1, clip_name2):
+        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
+        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)

 class CLIPVisionLoader:
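
For reference, a hypothetical direct invocation of the new node outside the graph executor; the two filenames are placeholders that would have to exist in the configured "clip" models folder:

    loader = DualCLIPLoader()
    (clip,) = loader.load_clip(clip_name1="sdxl_clip_l.safetensors",
                               clip_name2="sdxl_clip_g.safetensors")
    cond = clip.encode("a photo of a fox")  # CLIP.encode from comfy/sd.py above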
@@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,