diff --git a/comfy/model_management.py b/comfy/model_management.py index 0d5b0730..855e8911 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -647,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor pass if fp8_dtype is not None: + if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive + return fp8_dtype + free_model_memory = maximum_vram_for_weights(device) if model_params * 2 > free_model_memory: return fp8_dtype