Make maximum_batch_area take into account the PyTorch 2.0 attention function.

More conservative xformers maximum_batch_area.
This commit is contained in:
comfyanonymous 2023-05-06 19:58:54 -04:00
parent 678f933d38
commit 6fc4917634

View File

@@ -275,8 +275,17 @@ def xformers_enabled_vae():
return XFORMERS_ENABLED_VAE return XFORMERS_ENABLED_VAE
def pytorch_attention_enabled(): def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_flash_attention():
    """Report whether PyTorch's flash attention is expected to be usable.

    Returns:
        True when the module-level ENABLE_PYTORCH_ATTENTION flag is set and
        torch.version.cuda is truthy; False otherwise.
    """
    if not ENABLE_PYTORCH_ATTENTION:
        return False
    # TODO: more reliable way of checking for flash attention?
    # pytorch flash attention only works on Nvidia, so the presence of a
    # CUDA version string on the torch build is used as the check.
    return bool(torch.version.cuda)
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global xpu_available global xpu_available
global directml_enabled global directml_enabled
@@ -312,9 +321,9 @@ def maximum_batch_area():
return 0 return 0
memory_free = get_free_memory() / (1024 * 1024) memory_free = get_free_memory() / (1024 * 1024)
if xformers_enabled(): if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
area = 50 * memory_free area = 20 * memory_free
else: else:
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
area = ((memory_free - 1024) * 0.9) / (0.6) area = ((memory_free - 1024) * 0.9) / (0.6)