Allow FP16 accumulation with --fast (#6453)

Currently only applies to PyTorch nightly releases. (>=20250208)
This commit is contained in:
catboxanon 2025-02-08 17:00:56 -05:00 committed by GitHub
parent af93c8d1ee
commit 43a74c0de1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -241,6 +241,12 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
try:
if is_nvidia() and args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
except:
pass
try:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)