mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support multiple text encoder configurations on SD3.
This commit is contained in:
parent
1c34d338d7
commit
5889b7ca0a
@ -482,7 +482,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
vae = VAE(sd=vae_sd)
|
vae = VAE(sd=vae_sd)
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
clip_target = model_config.clip_target()
|
clip_target = model_config.clip_target(state_dict=sd)
|
||||||
if clip_target is not None:
|
if clip_target is not None:
|
||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
|
@ -5,6 +5,7 @@ import comfy.t5
|
|||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import logging
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||||
@ -43,20 +44,39 @@ class SD3Tokenizer:
|
|||||||
return self.clip_g.untokenize(token_weight_pair)
|
return self.clip_g.untokenize(token_weight_pair)
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||||
|
else:
|
||||||
|
self.clip_l = None
|
||||||
|
|
||||||
|
if clip_g:
|
||||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.clip_g = None
|
||||||
|
|
||||||
|
if t5:
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.t5xxl = None
|
||||||
|
|
||||||
|
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
|
if self.clip_l is not None:
|
||||||
self.clip_l.set_clip_options(options)
|
self.clip_l.set_clip_options(options)
|
||||||
|
if self.clip_g is not None:
|
||||||
self.clip_g.set_clip_options(options)
|
self.clip_g.set_clip_options(options)
|
||||||
|
if self.t5xxl is not None:
|
||||||
self.t5xxl.set_clip_options(options)
|
self.t5xxl.set_clip_options(options)
|
||||||
|
|
||||||
def reset_clip_options(self):
|
def reset_clip_options(self):
|
||||||
self.clip_g.reset_clip_options()
|
if self.clip_l is not None:
|
||||||
self.clip_l.reset_clip_options()
|
self.clip_l.reset_clip_options()
|
||||||
|
if self.clip_g is not None:
|
||||||
|
self.clip_g.reset_clip_options()
|
||||||
|
if self.t5xxl is not None:
|
||||||
self.t5xxl.reset_clip_options()
|
self.t5xxl.reset_clip_options()
|
||||||
|
|
||||||
def encode_token_weights(self, token_weight_pairs):
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
@ -64,22 +84,43 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
token_weight_pairs_g = token_weight_pairs["g"]
|
token_weight_pairs_g = token_weight_pairs["g"]
|
||||||
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
||||||
lg_out = None
|
lg_out = None
|
||||||
|
pooled = None
|
||||||
|
out = None
|
||||||
|
|
||||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
if self.clip_l is not None:
|
||||||
|
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
|
else:
|
||||||
|
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if self.clip_g is not None:
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
if lg_out is not None:
|
||||||
|
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||||
|
else:
|
||||||
|
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||||
|
else:
|
||||||
|
g_out = None
|
||||||
|
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if lg_out is not None:
|
||||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||||
out = lg_out
|
out = lg_out
|
||||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||||
else:
|
|
||||||
pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
|
|
||||||
|
|
||||||
|
if self.t5xxl is not None:
|
||||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||||
else:
|
else:
|
||||||
out = t5_out
|
out = t5_out
|
||||||
|
|
||||||
|
if out is None:
|
||||||
|
out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if pooled is None:
|
||||||
|
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
return out, pooled
|
return out, pooled
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
|
@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
|
|||||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||||
|
|
||||||
class SD20(supported_models_base.BASE):
|
class SD20(supported_models_base.BASE):
|
||||||
@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
|
|||||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
||||||
|
|
||||||
class SD21UnclipL(SD20):
|
class SD21UnclipL(SD20):
|
||||||
@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||||
return state_dict_g
|
return state_dict_g
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
||||||
|
|
||||||
class SDXL(supported_models_base.BASE):
|
class SDXL(supported_models_base.BASE):
|
||||||
@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
|
|||||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||||
return state_dict_g
|
return state_dict_g
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||||
|
|
||||||
class SSD1B(SDXL):
|
class SSD1B(SDXL):
|
||||||
@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
|
|||||||
out = model_base.SVD_img2vid(self, device=device)
|
out = model_base.SVD_img2vid(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class SV3D_u(SVD_img2vid):
|
class SV3D_u(SVD_img2vid):
|
||||||
@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
|
|||||||
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class SD_X4Upscaler(SD20):
|
class SD_X4Upscaler(SD20):
|
||||||
@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
|||||||
out = model_base.StableCascade_C(self, device=device)
|
out = model_base.StableCascade_C(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||||
|
|
||||||
class Stable_Cascade_B(Stable_Cascade_C):
|
class Stable_Cascade_B(Stable_Cascade_C):
|
||||||
@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE):
|
|||||||
|
|
||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.SD3
|
latent_format = latent_formats.SD3
|
||||||
text_encoder_key_prefix = ["text_encoders."] #TODO?
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
out = model_base.SD3(self, device=device)
|
out = model_base.SD3(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
|
clip_l = False
|
||||||
|
clip_g = False
|
||||||
|
t5 = False
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
|
clip_l = True
|
||||||
|
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
|
clip_g = True
|
||||||
|
if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
|
t5 = True
|
||||||
|
|
||||||
|
class SD3ClipModel(sd3_clip.SD3ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None):
|
||||||
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
|
||||||
|
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
|
||||||
|
Loading…
Reference in New Issue
Block a user