mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Load the SD3 T5xxl model in the same dtype stored in the checkpoint.
This commit is contained in:
parent
5889b7ca0a
commit
0e49211a11
@ -639,6 +639,23 @@ def supports_dtype(device, dtype): #TODO
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_cast(device, dtype): #TODO
|
||||||
|
if dtype == torch.float32:
|
||||||
|
return True
|
||||||
|
if dtype == torch.float16:
|
||||||
|
return True
|
||||||
|
if is_device_mps(device):
|
||||||
|
return False
|
||||||
|
if directml_enabled: #TODO: test this
|
||||||
|
return False
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
return True
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return True
|
||||||
|
if dtype == torch.float8_e5m2:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def device_supports_non_blocking(device):
|
def device_supports_non_blocking(device):
|
||||||
if is_device_mps(device):
|
if is_device_mps(device):
|
||||||
return False #pytorch bug? mps doesn't support non blocking
|
return False #pytorch bug? mps doesn't support non blocking
|
||||||
|
@ -98,13 +98,19 @@ class CLIP:
|
|||||||
load_device = model_management.text_encoder_device()
|
load_device = model_management.text_encoder_device()
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_management.text_encoder_offload_device()
|
||||||
params['device'] = offload_device
|
params['device'] = offload_device
|
||||||
params['dtype'] = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
params['dtype'] = dtype
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
|
||||||
|
for dt in self.cond_stage_model.dtypes:
|
||||||
|
if not model_management.supports_cast(load_device, dt):
|
||||||
|
load_device = offload_device
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
|
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
|
@ -511,6 +511,10 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||||
|
|
||||||
|
self.dtypes = set()
|
||||||
|
if dtype is not None:
|
||||||
|
self.dtypes.add(dtype)
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
getattr(self, self.clip).set_clip_options(options)
|
getattr(self, self.clip).set_clip_options(options)
|
||||||
|
|
||||||
|
@ -44,24 +44,36 @@ class SD3Tokenizer:
|
|||||||
return self.clip_g.untokenize(token_weight_pair)
|
return self.clip_g.untokenize(token_weight_pair)
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.dtypes = set()
|
||||||
if clip_l:
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||||
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_l = None
|
self.clip_l = None
|
||||||
|
|
||||||
if clip_g:
|
if clip_g:
|
||||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||||
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_g = None
|
self.clip_g = None
|
||||||
|
|
||||||
if t5:
|
if t5:
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
if dtype_t5 is None:
|
||||||
|
dtype_t5 = dtype
|
||||||
|
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
|
||||||
|
dtype_t5 = dtype
|
||||||
|
|
||||||
|
if not comfy.model_management.supports_cast(device, dtype_t5):
|
||||||
|
dtype_t5 = dtype
|
||||||
|
|
||||||
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||||
|
self.dtypes.add(dtype_t5)
|
||||||
else:
|
else:
|
||||||
self.t5xxl = None
|
self.t5xxl = None
|
||||||
|
|
||||||
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))
|
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
if self.clip_l is not None:
|
if self.clip_l is not None:
|
||||||
|
@ -39,6 +39,7 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||||
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
self.clip_l.set_clip_options(options)
|
self.clip_l.set_clip_options(options)
|
||||||
|
@ -511,17 +511,20 @@ class SD3(supported_models_base.BASE):
|
|||||||
clip_l = False
|
clip_l = False
|
||||||
clip_g = False
|
clip_g = False
|
||||||
t5 = False
|
t5 = False
|
||||||
|
dtype_t5 = None
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
clip_l = True
|
clip_l = True
|
||||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
clip_g = True
|
clip_g = True
|
||||||
if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
|
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||||
|
if t5_key in state_dict:
|
||||||
t5 = True
|
t5 = True
|
||||||
|
dtype_t5 = state_dict[t5_key].dtype
|
||||||
|
|
||||||
class SD3ClipModel(sd3_clip.SD3ClipModel):
|
class SD3ClipModel(sd3_clip.SD3ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None):
|
||||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||||
|
|
||||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
|
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user