Single and dual clip loader nodes support SD3.

You can use the CLIPLoader to use the t5xxl only or the DualCLIPLoader to
use CLIP-L and CLIP-G only for sd3.
This commit is contained in:
comfyanonymous 2024-06-11 23:27:39 -04:00
parent 0e49211a11
commit 69c8d6d8a6
4 changed files with 31 additions and 11 deletions

View File

@ -370,6 +370,7 @@ def load_style_model(ckpt_path):
class CLIPType(Enum): class CLIPType(Enum):
STABLE_DIFFUSION = 1 STABLE_DIFFUSION = 1
STABLE_CASCADE = 2 STABLE_CASCADE = 2
SD3 = 3
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = [] clip_data = []
@ -399,12 +400,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer clip_target.tokenizer = sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
else: else:
clip_target.clip = sd1_clip.SD1ClipModel clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2: elif len(clip_data) == 2:
clip_target.clip = sdxl_clip.SDXLClipModel if clip_type == CLIPType.SD3:
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3: elif len(clip_data) == 3:
clip_target.clip = sd3_clip.SD3ClipModel clip_target.clip = sd3_clip.SD3ClipModel
clip_target.tokenizer = sd3_clip.SD3Tokenizer clip_target.tokenizer = sd3_clip.SD3Tokenizer

View File

@ -142,3 +142,9 @@ class SD3ClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd) return self.clip_l.load_sd(sd)
else: else:
return self.t5xxl.load_sd(sd) return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
return SD3ClipModel_

View File

@ -522,11 +522,7 @@ class SD3(supported_models_base.BASE):
t5 = True t5 = True
dtype_t5 = state_dict[t5_key].dtype dtype_t5 = state_dict[t5_key].dtype
class SD3ClipModel(sd3_clip.SD3ClipModel): return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3] models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]

View File

@ -818,7 +818,7 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"type": (["stable_diffusion", "stable_cascade"], ), "type": (["stable_diffusion", "stable_cascade", "sd3"], ),
}} }}
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
@ -829,6 +829,8 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade": if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
clip_path = folder_paths.get_full_path("clip", clip_name) clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
@ -837,17 +839,24 @@ class CLIPLoader:
class DualCLIPLoader: class DualCLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
"clip_name2": (folder_paths.get_filename_list("clip"), ),
"type": (["sdxl", "sd3"], ),
}} }}
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2): def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings")) if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)
class CLIPVisionLoader: class CLIPVisionLoader: