Fix pytorch 2.0 cross attention not working.

comfyanonymous 2023-03-05 14:14:54 -05:00
parent f9d09c266f
commit 798c90e1c0


@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
    if "--use-pytorch-cross-attention" in sys.argv:
        print("Using pytorch cross attention")
        torch.backends.cuda.enable_math_sdp(False)
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        CrossAttention = CrossAttentionPytorch
    else:
        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
@@ -497,6 +499,7 @@ else:
    print("Using xformers cross attention")
    CrossAttention = MemoryEfficientCrossAttention

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
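
For context: the three `torch.backends.cuda.enable_*_sdp` toggles control which kernel `torch.nn.functional.scaled_dot_product_attention` may dispatch to. This commit disables the plain math fallback and keeps the flash and memory-efficient kernels enabled, so SDPA is forced onto an accelerated path (and will raise an error rather than silently fall back if neither kernel supports the inputs). The sketch below is a hypothetical illustration of a cross attention layer built on that PyTorch 2.0 API; the class name `SDPCrossAttention` and its parameters are assumptions for illustration only and are not the repo's `CrossAttentionPytorch` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPCrossAttention(nn.Module):
    # Hypothetical sketch of cross attention on top of PyTorch 2.0's
    # scaled_dot_product_attention; not ComfyUI's CrossAttentionPytorch.
    def __init__(self, dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = dim if context_dim is None else context_dim
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context=None):
        context = x if context is None else context
        b, n, _ = x.shape

        def split_heads(t):
            # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
            return t.view(b, t.shape[1], self.heads, -1).transpose(1, 2)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(context))
        v = split_heads(self.to_v(context))
        # Dispatches to the flash / memory-efficient / math kernel according
        # to the torch.backends.cuda.enable_*_sdp switches set above.
        out = F.scaled_dot_product_attention(q, k, v)
        # (batch, heads, seq, dim_head) -> (batch, seq, heads * dim_head)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# Usage sketch: latent tokens attending to text embeddings
# (dimensions here are illustrative, not taken from the repo).
attn = SDPCrossAttention(dim=320, context_dim=768)
x = torch.randn(2, 4096, 320)
ctx = torch.randn(2, 77, 768)
print(attn(x, ctx).shape)  # torch.Size([2, 4096, 320])
```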