Support multiple text encoder configurations on SD3.

This commit is contained in:
comfyanonymous 2024-06-11 13:14:43 -04:00
parent 1c34d338d7
commit 5889b7ca0a
3 changed files with 91 additions and 35 deletions

View File

@ -482,7 +482,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
vae = VAE(sd=vae_sd) vae = VAE(sd=vae_sd)
if output_clip: if output_clip:
clip_target = model_config.clip_target() clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None: if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd) clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0: if len(clip_sd) > 0:

View File

@ -5,6 +5,7 @@ import comfy.t5
import torch import torch
import os import os
import comfy.model_management import comfy.model_management
import logging
class T5XXLModel(sd1_clip.SDClipModel): class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
@ -43,20 +44,39 @@ class SD3Tokenizer:
return self.clip_g.untokenize(token_weight_pair) return self.clip_g.untokenize(token_weight_pair)
class SD3ClipModel(torch.nn.Module): class SD3ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
super().__init__() super().__init__()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
else:
self.clip_g = None
if t5:
self.t5xxl = T5XXLModel(device=device, dtype=dtype) self.t5xxl = T5XXLModel(device=device, dtype=dtype)
else:
self.t5xxl = None
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))
def set_clip_options(self, options): def set_clip_options(self, options):
if self.clip_l is not None:
self.clip_l.set_clip_options(options) self.clip_l.set_clip_options(options)
if self.clip_g is not None:
self.clip_g.set_clip_options(options) self.clip_g.set_clip_options(options)
if self.t5xxl is not None:
self.t5xxl.set_clip_options(options) self.t5xxl.set_clip_options(options)
def reset_clip_options(self): def reset_clip_options(self):
self.clip_g.reset_clip_options() if self.clip_l is not None:
self.clip_l.reset_clip_options() self.clip_l.reset_clip_options()
if self.clip_g is not None:
self.clip_g.reset_clip_options()
if self.t5xxl is not None:
self.t5xxl.reset_clip_options() self.t5xxl.reset_clip_options()
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
@ -64,22 +84,43 @@ class SD3ClipModel(torch.nn.Module):
token_weight_pairs_g = token_weight_pairs["g"] token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"] token_weight_pars_t5 = token_weight_pairs["t5xxl"]
lg_out = None lg_out = None
pooled = None
out = None
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
lg_out = torch.cat([l_out, g_out], dim=-1) if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:
g_out = None
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
if lg_out is not None:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
out = lg_out out = lg_out
pooled = torch.cat((l_pooled, g_pooled), dim=-1) pooled = torch.cat((l_pooled, g_pooled), dim=-1)
else:
pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
if lg_out is not None: if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2) out = torch.cat([lg_out, t5_out], dim=-2)
else: else:
out = t5_out out = t5_out
if out is None:
out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
return out, pooled return out, pooled
def load_sd(self, sd): def load_sd(self, sd):

View File

@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
replace_prefix = {"clip_l.": "cond_stage_model."} replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
class SD20(supported_models_base.BASE): class SD20(supported_models_base.BASE):
@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict return state_dict
def clip_target(self): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20): class SD21UnclipL(SD20):
@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g return state_dict_g
def clip_target(self): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
class SDXL(supported_models_base.BASE): class SDXL(supported_models_base.BASE):
@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g return state_dict_g
def clip_target(self): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
class SSD1B(SDXL): class SSD1B(SDXL):
@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
out = model_base.SVD_img2vid(self, device=device) out = model_base.SVD_img2vid(self, device=device)
return out return out
def clip_target(self): def clip_target(self, state_dict={}):
return None return None
class SV3D_u(SVD_img2vid): class SV3D_u(SVD_img2vid):
@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out return out
def clip_target(self): def clip_target(self, state_dict={}):
return None return None
class SD_X4Upscaler(SD20): class SD_X4Upscaler(SD20):
@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
out = model_base.StableCascade_C(self, device=device) out = model_base.StableCascade_C(self, device=device)
return out return out
def clip_target(self): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
class Stable_Cascade_B(Stable_Cascade_C): class Stable_Cascade_B(Stable_Cascade_C):
@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.SD3 latent_format = latent_formats.SD3
text_encoder_key_prefix = ["text_encoders."] #TODO? text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device) out = model_base.SD3(self, device=device)
return out return out
def clip_target(self): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO? clip_l = False
clip_g = False
t5 = False
pref = self.text_encoder_key_prefix[0]
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_l = True
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_g = True
if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
t5 = True
class SD3ClipModel(sd3_clip.SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3] models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]