mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-17 17:43:30 +00:00
Add a separate optimized_attention_masked function.
This commit is contained in:
parent
7e09e889e3
commit
bb064c9796
@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
optimized_attention_masked = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
@ -309,6 +308,9 @@ else:
|
||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
if model_management.pytorch_attention_enabled():
|
||||
optimized_attention_masked = attention_pytorch
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||
super().__init__()
|
||||
@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads, mask)
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
else:
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user