From 1a612e1c74ecb845350bbeab7554992e9f2c175c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 2 Mar 2023 17:01:20 -0500 Subject: [PATCH] Add some pytorch scaled_dot_product_attention code for testing. --use-pytorch-cross-attention to use it. --- comfy/ldm/modules/attention.py | 56 ++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 5ee7d5ae..00a20782 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -442,14 +442,64 @@ class MemoryEfficientCrossAttention(nn.Module): ) return self.to_out(out) +class CrossAttentionPytorch(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + + return self.to_out(out) + import sys -if XFORMERS_IS_AVAILBLE == False: +if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv: if "--use-split-cross-attention" in sys.argv: print("Using split optimization for cross attention") CrossAttention = CrossAttentionDoggettx else: - print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") - CrossAttention = CrossAttentionBirchSan + if "--use-pytorch-cross-attention" in sys.argv: + print("Using pytorch cross attention") + torch.backends.cuda.enable_math_sdp(False) + CrossAttention = CrossAttentionPytorch + else: + print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") + CrossAttention = CrossAttentionBirchSan else: print("Using xformers cross attention") CrossAttention = MemoryEfficientCrossAttention