diff --git a/comfy/model_management.py b/comfy/model_management.py
index b0640d67..39df8d9a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -275,8 +275,17 @@ def xformers_enabled_vae():
     return XFORMERS_ENABLED_VAE
 
 def pytorch_attention_enabled():
+    global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION
 
+def pytorch_attention_flash_attention():
+    global ENABLE_PYTORCH_ATTENTION
+    if ENABLE_PYTORCH_ATTENTION:
+        #TODO: more reliable way of checking for flash attention?
+        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+            return True
+    return False
+
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
@@ -312,9 +321,9 @@ def maximum_batch_area():
         return 0
 
     memory_free = get_free_memory() / (1024 * 1024)
-    if xformers_enabled():
+    if xformers_enabled() or pytorch_attention_flash_attention():
         #TODO: this needs to be tweaked
-        area = 50 * memory_free
+        area = 20 * memory_free
     else:
         #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
         area = ((memory_free - 1024) * 0.9) / (0.6)
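
The new `pytorch_attention_flash_attention()` carries a TODO about finding a more reliable check than `torch.version.cuda`. A minimal standalone sketch of one possible refinement, assuming PyTorch >= 2.0 where `torch.backends.cuda.flash_sdp_enabled()` exists; the helper name below is illustrative and not part of the patch:

```python
import torch

def flash_attention_likely_available():
    # CPU and ROCm builds report torch.version.cuda as None; PyTorch's
    # flash attention kernel currently ships only in CUDA builds, which
    # is the same signal the patch relies on.
    if torch.version.cuda is None:
        return False
    # Available in PyTorch >= 2.0: reports whether the flash kernel is
    # enabled for torch.nn.functional.scaled_dot_product_attention.
    # Note this is an on/off toggle, not a hardware-support probe, so it
    # narrows but does not fully answer the TODO.
    return torch.backends.cuda.flash_sdp_enabled()
```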
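For scale on the `maximum_batch_area()` change: with 8192 MB free, the old xformers-only branch budgeted 50 * 8192 = 409,600 area units, the shared xformers/flash branch now budgets 20 * 8192 = 163,840, and the fallback branch yields ((8192 - 1024) * 0.9) / 0.6 = 10,752. Illustrative numbers only; the remaining TODO notes the multiplier still needs tuning.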