diff --git a/comfy/sd.py b/comfy/sd.py index a7b4dbcf..3fd9e0e9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -370,6 +370,7 @@ def load_style_model(ckpt_path): class CLIPType(Enum): STABLE_DIFFUSION = 1 STABLE_CASCADE = 2 + SD3 = 3 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] @@ -399,12 +400,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: + dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype + clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) + clip_target.tokenizer = sd3_clip.SD3Tokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer elif len(clip_data) == 2: - clip_target.clip = sdxl_clip.SDXLClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.SD3: + clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False) + clip_target.tokenizer = sd3_clip.SD3Tokenizer + else: + clip_target.clip = sdxl_clip.SDXLClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif len(clip_data) == 3: clip_target.clip = sd3_clip.SD3ClipModel clip_target.tokenizer = sd3_clip.SD3Tokenizer diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index cbbbe53d..0713eb28 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -142,3 +142,9 @@ class SD3ClipModel(torch.nn.Module): return self.clip_l.load_sd(sd) else: return self.t5xxl.load_sd(sd) + +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None): + class SD3ClipModel_(SD3ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) + return SD3ClipModel_ diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a49df7a3..c8ddf3e2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -522,11 +522,7 @@ class SD3(supported_models_base.BASE): t5 = True dtype_t5 = state_dict[t5_key].dtype - class SD3ClipModel(sd3_clip.SD3ClipModel): - def __init__(self, device="cpu", dtype=None): - super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) - - return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel) + return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5)) models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3] diff --git a/nodes.py b/nodes.py index ef1f8561..6fbeb377 100644 --- a/nodes.py +++ b/nodes.py @@ -818,7 +818,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), - "type": (["stable_diffusion", "stable_cascade"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -829,6 +829,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION if type == "stable_cascade": clip_type = comfy.sd.CLIPType.STABLE_CASCADE + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 clip_path = folder_paths.get_full_path("clip", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) @@ -837,17 +839,24 @@ class CLIPLoader: class DualCLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), + return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), + "clip_name2": (folder_paths.get_filename_list("clip"), ), + "type": (["sdxl", "sd3"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name1, clip_name2): + def load_clip(self, clip_name1, clip_name2, type): clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path2 = folder_paths.get_full_path("clip", clip_name2) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings")) + if type == "sdxl": + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class CLIPVisionLoader: