Mirror of https://github.com/comfyanonymous/ComfyUI.git
Make maximum_batch_area take into account the pytorch 2.0 attention function.
More conservative xformers maximum_batch_area.
parent 678f933d38
commit 6fc4917634
@@ -275,8 +275,17 @@ def xformers_enabled_vae():
     return XFORMERS_ENABLED_VAE
 
 def pytorch_attention_enabled():
+    global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION
 
+def pytorch_attention_flash_attention():
+    global ENABLE_PYTORCH_ATTENTION
+    if ENABLE_PYTORCH_ATTENTION:
+        #TODO: more reliable way of checking for flash attention?
+        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+            return True
+    return False
+
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
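The new check leans on torch.version.cuda, which is a version string such as "11.8" on CUDA builds of PyTorch and None on CPU and ROCm builds. A minimal standalone sketch of the same idea, with the module-level flag assumed to be set elsewhere:

import torch

# Stand-in for the module-level flag toggled by ComfyUI's startup options;
# assumed True here purely for illustration.
ENABLE_PYTORCH_ATTENTION = True

def pytorch_flash_attention_available():
    # torch.version.cuda is a string like "11.8" on CUDA builds and None on
    # CPU/ROCm builds, so a truthy value is a rough "running on Nvidia" check.
    if ENABLE_PYTORCH_ATTENTION and torch.version.cuda:
        return True
    return False

print(pytorch_flash_attention_available())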
@@ -312,9 +321,9 @@ def maximum_batch_area():
         return 0
 
     memory_free = get_free_memory() / (1024 * 1024)
-    if xformers_enabled():
+    if xformers_enabled() or pytorch_attention_flash_attention():
         #TODO: this needs to be tweaked
-        area = 50 * memory_free
+        area = 20 * memory_free
     else:
         #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
         area = ((memory_free - 1024) * 0.9) / (0.6)
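To make the "more conservative" part concrete, a small sketch of the two branches of the formula with a hypothetical 8 GB of free VRAM; the numbers are illustrative only:

# Hypothetical free VRAM, in MB, chosen only to illustrate the formula change.
memory_free = 8 * 1024

area_before = 50 * memory_free                      # 409600: old xformers cap
area_after = 20 * memory_free                       # 163840: new xformers / pytorch flash attention cap
area_fallback = ((memory_free - 1024) * 0.9) / 0.6  # 10752.0: non-accelerated path, unchanged

print(area_before, area_after, area_fallback)

With this change the same, smaller cap applies whenever either xformers or PyTorch 2.0 flash attention is active, instead of only when xformers is enabled.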