Prioritize fp16 compute when using allow_fp16_accumulation

comfyanonymous 2025-02-23 04:45:54 -05:00
parent aff16532d4
commit ace899e71a


@@ -256,9 +256,12 @@ if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
+PRIORITIZE_FP16 = False  # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
     if is_nvidia() and args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        PRIORITIZE_FP16 = True  # TODO: limit to cards where it actually boosts performance
 except:
     pass
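
The flag is flipped purely on the basis of is_nvidia() and args.fast; the TODO acknowledges that it should eventually be restricted to GPUs where fp16 accumulation actually wins. A minimal sketch of one way that could be probed, assuming a hypothetical fp16_accumulation_helps() helper with arbitrary matrix size, iteration count, and speedup threshold:

import time
import torch

def fp16_accumulation_helps(device="cuda", size=4096, iters=10, threshold=0.95):
    # Hypothetical probe, not ComfyUI code: time an fp16 matmul with and without
    # fp16 accumulation and only report True if the flag gives a real speedup.
    if not torch.cuda.is_available():
        return False
    if not hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
        return False  # the flag only exists in newer PyTorch builds

    a = torch.randn(size, size, device=device, dtype=torch.float16)
    b = torch.randn(size, size, device=device, dtype=torch.float16)

    def bench():
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            a @ b
        torch.cuda.synchronize()
        return time.perf_counter() - start

    previous = torch.backends.cuda.matmul.allow_fp16_accumulation
    try:
        torch.backends.cuda.matmul.allow_fp16_accumulation = False
        baseline = bench()
        torch.backends.cuda.matmul.allow_fp16_accumulation = True
        accumulated = bench()
    finally:
        torch.backends.cuda.matmul.allow_fp16_accumulation = previous

    return accumulated < baseline * threshold

Until something like that exists, PRIORITIZE_FP16 simply mirrors whether the accumulation flag was set.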
@@ -681,6 +684,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype
+    if PRIORITIZE_FP16:
+        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
+            return torch.float16
     for dt in supported_dtypes:
         if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
             if torch.float16 in supported_dtypes:
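
In unet_dtype(), the new block runs ahead of the generic supported_dtypes loop, so a model whose supported_dtypes lists bfloat16 before float16 still gets float16 when PRIORITIZE_FP16 is set and should_use_fp16() approves the device and model size. A simplified sketch of that precedence, with pick_unet_dtype and the fp16_ok/bf16_ok booleans standing in for the real function and hardware checks:

import torch

PRIORITIZE_FP16 = True  # in ComfyUI this is only True when allow_fp16_accumulation was enabled

def pick_unet_dtype(supported_dtypes, fp16_ok, bf16_ok):
    # New fast path: prefer fp16 even when it is not the first entry in supported_dtypes.
    if PRIORITIZE_FP16 and torch.float16 in supported_dtypes and fp16_ok:
        return torch.float16
    # Previous behaviour: take the first supported dtype that passes its check.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_ok:
            return torch.float16
        if dt == torch.bfloat16 and bf16_ok:
            return torch.bfloat16
    return torch.float32

print(pick_unet_dtype([torch.bfloat16, torch.float16], fp16_ok=True, bf16_ok=True))
# prints torch.float16; without the PRIORITIZE_FP16 branch the loop would return torch.bfloat16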