Enable AMD fp8 and pytorch attention on some GPUs. (#8474)

Information is from the pytorch source code.
This commit is contained in:
comfyanonymous 2025-06-09 09:50:39 -07:00 committed by GitHub
parent 97755eed46
commit 7f800d04fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -307,10 +307,10 @@ try:
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201"]): # TODO: more arches
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
except: