From ace899e71a3d8d75f64a016fa398b95fa83e6978 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 23 Feb 2025 04:45:54 -0500
Subject: [PATCH] Prioritize fp16 compute when using allow_fp16_accumulation

---
 comfy/model_management.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 8b6c4a66..f4a63c6d 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -256,9 +256,12 @@ if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
 
+
+PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
     if is_nvidia() and args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
 except:
     pass
 
@@ -681,6 +684,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype
 
+    if PRIORITIZE_FP16:
+        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
+            return torch.float16
+
     for dt in supported_dtypes:
         if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
             if torch.float16 in supported_dtypes: