mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Make supported_dtypes a priority list.
This commit is contained in:
parent
cb7c4b4be3
commit
6969fc9ba4
@@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if model_params * 2 > free_model_memory:
         return fp8_dtype

-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16

     return torch.float32

 # None means no manual cast
@@ -583,12 +593,12 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None

-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-
-    else:
-        return torch.float32
+    for dt in supported_dtypes:
+        if dt == torch.float16 and fp16_supported:
+            return torch.float16
+        if dt == torch.bfloat16 and bf16_supported:
+            return torch.bfloat16
+
+    return torch.float32

 def text_encoder_offload_device():
|
Loading…
Reference in New Issue
Block a user