diff --git a/comfy/sd.py b/comfy/sd.py index cb147fa4..11764077 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -482,7 +482,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o vae = VAE(sd=vae_sd) if output_clip: - clip_target = model_config.clip_target() + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index bbbf6aff..595381fc 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -5,6 +5,7 @@ import comfy.t5 import torch import os import comfy.model_management +import logging class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): @@ -43,42 +44,82 @@ class SD3Tokenizer: return self.clip_g.untokenize(token_weight_pair) class SD3ClipModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None): + def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) - self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) - self.t5xxl = T5XXLModel(device=device, dtype=dtype) + if clip_l: + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) + else: + self.clip_l = None + + if clip_g: + self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) + else: + self.clip_g = None + + if t5: + self.t5xxl = T5XXLModel(device=device, dtype=dtype) + else: + self.t5xxl = None + + logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5)) def set_clip_options(self, options): - self.clip_l.set_clip_options(options) - self.clip_g.set_clip_options(options) - self.t5xxl.set_clip_options(options) + if self.clip_l is not None: + self.clip_l.set_clip_options(options) + if self.clip_g is not None: + self.clip_g.set_clip_options(options) + if self.t5xxl is not None: + self.t5xxl.set_clip_options(options) def reset_clip_options(self): - self.clip_g.reset_clip_options() - self.clip_l.reset_clip_options() - self.t5xxl.reset_clip_options() + if self.clip_l is not None: + self.clip_l.reset_clip_options() + if self.clip_g is not None: + self.clip_g.reset_clip_options() + if self.t5xxl is not None: + self.t5xxl.reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs_l = token_weight_pairs["l"] token_weight_pairs_g = token_weight_pairs["g"] token_weight_pars_t5 = token_weight_pairs["t5xxl"] lg_out = None - if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: - l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) - g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) - lg_out = torch.cat([l_out, g_out], dim=-1) - lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) - out = lg_out - pooled = torch.cat((l_pooled, g_pooled), dim=-1) - else: - pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device()) + pooled = None + out = None - t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) - if lg_out is not None: - out = torch.cat([lg_out, t5_out], dim=-2) - else: - out = t5_out + if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: + if self.clip_l is not None: + lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + else: + l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device()) + + if self.clip_g is not None: + g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) + if lg_out is not None: + lg_out = torch.cat([lg_out, g_out], dim=-1) + else: + lg_out = torch.nn.functional.pad(g_out, (768, 0)) + else: + g_out = None + g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) + + if lg_out is not None: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + out = lg_out + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + + if self.t5xxl is not None: + t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) + if lg_out is not None: + out = torch.cat([lg_out, t5_out], dim=-2) + else: + out = t5_out + + if out is None: + out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device()) + + if pooled is None: + pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) return out, pooled diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6bb76c96..481ecaa6 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE): replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) class SD20(supported_models_base.BASE): @@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE): state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) class SD21UnclipL(SD20): @@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE): state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) class SDXL(supported_models_base.BASE): @@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE): state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) class SSD1B(SDXL): @@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE): out = model_base.SVD_img2vid(self, device=device) return out - def clip_target(self): + def clip_target(self, state_dict={}): return None class SV3D_u(SVD_img2vid): @@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE): out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) return out - def clip_target(self): + def clip_target(self, state_dict={}): return None class SD_X4Upscaler(SD20): @@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE): out = model_base.StableCascade_C(self, device=device) return out - def clip_target(self): + def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) class Stable_Cascade_B(Stable_Cascade_C): @@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 - text_encoder_key_prefix = ["text_encoders."] #TODO? + text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): out = model_base.SD3(self, device=device) return out - def clip_target(self): - return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO? + def clip_target(self, state_dict={}): + clip_l = False + clip_g = False + t5 = False + pref = self.text_encoder_key_prefix[0] + if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: + clip_l = True + if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: + clip_g = True + if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict: + t5 = True + + class SD3ClipModel(sd3_clip.SD3ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype) + + return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel) models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]