pytorch_attention_enabled can now return True when xformers is enabled.

This commit is contained in:
comfyanonymous 2023-10-11 21:29:03 -04:00
parent 20d3852aa1
commit 88733c997f
2 changed files with 7 additions and 4 deletions

View File

@ -355,7 +355,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":

View File

@ -154,14 +154,18 @@ def is_nvidia():
return True
return False
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPE = torch.float32
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16
@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
XFORMERS_IS_AVAILABLE = False
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM