From 0e49211a110fff099ebafa927ee7b9416ff9feaa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 11 Jun 2024 17:03:26 -0400 Subject: [PATCH] Load the SD3 T5xxl model in the same dtype stored in the checkpoint. --- comfy/model_management.py | 17 +++++++++++++++++ comfy/sd.py | 8 +++++++- comfy/sd1_clip.py | 4 ++++ comfy/sd3_clip.py | 18 +++++++++++++++--- comfy/sdxl_clip.py | 1 + comfy/supported_models.py | 7 +++++-- 6 files changed, 49 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 57aa8bca..dbd0dbac 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,6 +639,23 @@ def supports_dtype(device, dtype): #TODO return True return False +def supports_cast(device, dtype): #TODO + if dtype == torch.float32: + return True + if dtype == torch.float16: + return True + if is_device_mps(device): + return False + if directml_enabled: #TODO: test this + return False + if dtype == torch.bfloat16: + return True + if dtype == torch.float8_e4m3fn: + return True + if dtype == torch.float8_e5m2: + return True + return False + def device_supports_non_blocking(device): if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking diff --git a/comfy/sd.py b/comfy/sd.py index 11764077..a7b4dbcf 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -98,13 +98,19 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() params['device'] = offload_device - params['dtype'] = model_management.text_encoder_dtype(load_device) + dtype = model_management.text_encoder_dtype(load_device) + params['dtype'] = dtype self.cond_stage_model = clip(**(params)) + for dt in self.cond_stage_model.dtypes: + if not model_management.supports_cast(load_device, dt): + load_device = offload_device + self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None + logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device)) def clone(self): n = CLIP(no_init=True) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 2729f14d..911af0a7 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -511,6 +511,10 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) + self.dtypes = set() + if dtype is not None: + self.dtypes.add(dtype) + def set_clip_options(self, options): getattr(self, self.clip).set_clip_options(options) diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index 595381fc..cbbbe53d 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -44,24 +44,36 @@ class SD3Tokenizer: return self.clip_g.untokenize(token_weight_pair) class SD3ClipModel(torch.nn.Module): - def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None): + def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): super().__init__() + self.dtypes = set() if clip_l: self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) + self.dtypes.add(dtype) else: self.clip_l = None if clip_g: self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) + self.dtypes.add(dtype) else: self.clip_g = None if t5: - self.t5xxl = T5XXLModel(device=device, dtype=dtype) + if dtype_t5 is None: + dtype_t5 = dtype + elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype): + dtype_t5 = dtype + + if not comfy.model_management.supports_cast(device, dtype_t5): + dtype_t5 = dtype + + self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) + self.dtypes.add(dtype_t5) else: self.t5xxl = None - logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5)) + logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5)) def set_clip_options(self, options): if self.clip_l is not None: diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index e62d1ed8..1257cba1 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -39,6 +39,7 @@ class SDXLClipModel(torch.nn.Module): super().__init__() self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) + self.dtypes = set([dtype]) def set_clip_options(self, options): self.clip_l.set_clip_options(options) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 481ecaa6..a49df7a3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -511,17 +511,20 @@ class SD3(supported_models_base.BASE): clip_l = False clip_g = False t5 = False + dtype_t5 = None pref = self.text_encoder_key_prefix[0] if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: clip_l = True if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: clip_g = True - if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict: + t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) + if t5_key in state_dict: t5 = True + dtype_t5 = state_dict[t5_key].dtype class SD3ClipModel(sd3_clip.SD3ClipModel): def __init__(self, device="cpu", dtype=None): - super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype) + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)