diff --git a/comfy/sd.py b/comfy/sd.py
index 6feb0de4..3f36b8c0 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -19,6 +19,7 @@
 from . import model_detection
 from . import sd1_clip
 from . import sd2_clip
+from . import sdxl_clip
 
 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -524,7 +525,7 @@ class CLIP:
         return n
 
     def load_from_state_dict(self, sd):
-        self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
+        self.cond_stage_model.load_sd(sd)
 
     def add_patches(self, patches, strength=1.0):
         return self.patcher.add_patches(patches, strength)
@@ -555,6 +556,8 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)
 
+    def load_sd(self, sd):
+        return self.cond_stage_model.load_sd(sd)
 
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
@@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
     return StyleModel(model)
 
 
-def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
+def load_clip(ckpt_paths, embedding_directory=None):
+    clip_data = []
+    for p in ckpt_paths:
+        clip_data.append(utils.load_torch_file(p, safe_load=True))
+
     class EmptyClass:
         pass
 
+    for i in range(len(clip_data)):
+        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
+            clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+
     clip_target = EmptyClass()
     clip_target.params = {}
-    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-        clip_target.clip = sd2_clip.SD2ClipModel
-        clip_target.tokenizer = sd2_clip.SD2Tokenizer
+    if len(clip_data) == 1:
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sd2_clip.SD2ClipModel
+            clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        else:
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
     else:
-        clip_target.clip = sd1_clip.SD1ClipModel
-        clip_target.tokenizer = sd1_clip.SD1Tokenizer
+        clip_target.clip = sdxl_clip.SDXLClipModel
+        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
 
     clip = CLIP(clip_target, embedding_directory=embedding_directory)
-    clip.load_from_state_dict(clip_data)
+    for c in clip_data:
+        m, u = clip.load_sd(c)
+        if len(m) > 0:
+            print("clip missing:", m)
+
+        if len(u) > 0:
+            print("clip unexpected:", u)
     return clip
 
 def load_gligen(ckpt_path):
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 6a90b389..0ee314ad 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)
 
+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
 def parse_parentheses(string):
     result = []
     current_item = ""
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 7ab8a8ad..f251168d 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
         self.layer = "hidden"
         self.layer_idx = layer_idx
 
+    def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        return super().load_sd(sd)
+
 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return torch.cat([l_out, g_out], dim=-1), g_pooled
 
+    def load_sd(self, sd):
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+            return self.clip_g.load_sd(sd)
+        else:
+            return self.clip_l.load_sd(sd)
+
 class SDXLRefinerClipModel(torch.nn.Module):
     def __init__(self, device="cpu"):
         super().__init__()
@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         return g_out, g_pooled
 
+    def load_sd(self, sd):
+        return self.clip_g.load_sd(sd)
diff --git a/nodes.py b/nodes.py
index ce3e3b1e..c565501a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -520,11 +520,27 @@ class CLIPLoader:
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
 
-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"
 
     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
+class DualCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
+                             }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_clip(self, clip_name1, clip_name2):
+        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
+        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)
 
 class CLIPVisionLoader:
@@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,
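For reference, a minimal sketch of how the new two-checkpoint path can be exercised directly from Python, assuming the ComfyUI repository is importable and that the named checkpoint files exist in the configured "clip" folder (the filenames below are hypothetical placeholders, not part of this patch):

import comfy.sd
import folder_paths

# Hypothetical SDXL text-encoder checkpoints; substitute real filenames from models/clip.
clip_g_path = folder_paths.get_full_path("clip", "clip_g.safetensors")
clip_l_path = folder_paths.get_full_path("clip", "clip_l.safetensors")

# Mirrors DualCLIPLoader.load_clip(): with two state dicts, load_clip() picks
# sdxl_clip.SDXLClipModel / SDXLTokenizer and loads each checkpoint via clip.load_sd().
clip = comfy.sd.load_clip(ckpt_paths=[clip_g_path, clip_l_path],
                          embedding_directory=folder_paths.get_folder_paths("embeddings"))

# SDXLClipModel.load_sd() routes each state dict to clip_g or clip_l based on the
# layers.30 key; at encode time the clip_l and clip_g outputs are concatenated.
tokens = clip.tokenize("a photo of a cat")
cond = clip.encode_from_tokens(tokens)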