Add a separate optimized_attention_masked function.

This commit is contained in:
comfyanonymous 2023-10-16 02:31:24 -04:00
parent 7e09e889e3
commit bb064c9796

View File

@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if exists(mask):
raise NotImplementedError
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled():
print("Using xformers cross attention")
@ -309,6 +308,9 @@ else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
else:
v = self.to_v(context)
out = optimized_attention(q, k, v, self.heads, mask)
if mask is None:
out = optimized_attention(q, k, v, self.heads)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask)
return self.to_out(out)