Load the SD3 T5xxl model in the same dtype stored in the checkpoint.

This commit is contained in:
comfyanonymous 2024-06-11 17:03:26 -04:00
parent 5889b7ca0a
commit 0e49211a11
6 changed files with 49 additions and 6 deletions

View File

@ -639,6 +639,23 @@ def supports_dtype(device, dtype): #TODO
return True return True
return False return False
def supports_cast(device, dtype): #TODO
if dtype == torch.float32:
return True
if dtype == torch.float16:
return True
if is_device_mps(device):
return False
if directml_enabled: #TODO: test this
return False
if dtype == torch.bfloat16:
return True
if dtype == torch.float8_e4m3fn:
return True
if dtype == torch.float8_e5m2:
return True
return False
def device_supports_non_blocking(device): def device_supports_non_blocking(device):
if is_device_mps(device): if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking return False #pytorch bug? mps doesn't support non blocking

View File

@ -98,13 +98,19 @@ class CLIP:
load_device = model_management.text_encoder_device() load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device() offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device params['device'] = offload_device
params['dtype'] = model_management.text_encoder_dtype(load_device) dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
def clone(self): def clone(self):
n = CLIP(no_init=True) n = CLIP(no_init=True)

View File

@ -511,6 +511,10 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def set_clip_options(self, options): def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options) getattr(self, self.clip).set_clip_options(options)

View File

@ -44,24 +44,36 @@ class SD3Tokenizer:
return self.clip_g.untokenize(token_weight_pair) return self.clip_g.untokenize(token_weight_pair)
class SD3ClipModel(torch.nn.Module): class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None): def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
super().__init__() super().__init__()
self.dtypes = set()
if clip_l: if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
self.dtypes.add(dtype)
else: else:
self.clip_l = None self.clip_l = None
if clip_g: if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.dtypes.add(dtype)
else: else:
self.clip_g = None self.clip_g = None
if t5: if t5:
self.t5xxl = T5XXLModel(device=device, dtype=dtype) if dtype_t5 is None:
dtype_t5 = dtype
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
dtype_t5 = dtype
if not comfy.model_management.supports_cast(device, dtype_t5):
dtype_t5 = dtype
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes.add(dtype_t5)
else: else:
self.t5xxl = None self.t5xxl = None
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5)) logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
def set_clip_options(self, options): def set_clip_options(self, options):
if self.clip_l is not None: if self.clip_l is not None:

View File

@ -39,6 +39,7 @@ class SDXLClipModel(torch.nn.Module):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_g = SDXLClipG(device=device, dtype=dtype)
self.dtypes = set([dtype])
def set_clip_options(self, options): def set_clip_options(self, options):
self.clip_l.set_clip_options(options) self.clip_l.set_clip_options(options)

View File

@ -511,17 +511,20 @@ class SD3(supported_models_base.BASE):
clip_l = False clip_l = False
clip_g = False clip_g = False
t5 = False t5 = False
dtype_t5 = None
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_l = True clip_l = True
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_g = True clip_g = True
if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict: t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
t5 = True t5 = True
dtype_t5 = state_dict[t5_key].dtype
class SD3ClipModel(sd3_clip.SD3ClipModel): class SD3ClipModel(sd3_clip.SD3ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype) super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel) return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)