diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index ac0d9c8c..9cd14a53 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
     )

     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.transpose(1, 2).reshape(b, -1, heads * dim_head)
     )
     return out

+
 optimized_attention = attention_basic
+optimized_attention_masked = attention_basic

 if model_management.xformers_enabled():
     print("Using xformers cross attention")
@@ -309,6 +308,9 @@ else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad

+if model_management.pytorch_attention_enabled():
+    optimized_attention_masked = attention_pytorch
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
@@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
         else:
             v = self.to_v(context)

-        out = optimized_attention(q, k, v, self.heads, mask)
+        if mask is None:
+            out = optimized_attention(q, k, v, self.heads)
+        else:
+            out = optimized_attention_masked(q, k, v, self.heads, mask)

         return self.to_out(out)
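
For reference, below is a small standalone sketch (not part of the patch) of the dispatch pattern this diff introduces: unmasked calls keep using the selected optimized_attention backend, while masked calls go through optimized_attention_masked, which forwards attn_mask to torch.nn.functional.scaled_dot_product_attention instead of raising NotImplementedError. The names attention_pytorch_sketch and dispatch, and the additive-mask example, are illustrative stand-ins under assumed shapes, not identifiers or behavior taken verbatim from the repository.

# Standalone sketch mirroring the patched dispatch; assumes q, k, v of shape
# (batch, seq_len, heads * dim_head) and an additive float mask broadcastable
# to (batch, heads, q_len, k_len), as accepted by scaled_dot_product_attention.
import torch


def attention_pytorch_sketch(q, k, v, heads, mask=None):
    b, _, dim_inner = q.shape
    dim_head = dim_inner // heads
    # Split heads: (b, seq, heads * dim_head) -> (b, heads, seq, dim_head)
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    # The mask is forwarded here instead of raising NotImplementedError.
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    # Merge heads back: (b, heads, seq, dim_head) -> (b, seq, heads * dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


def dispatch(q, k, v, heads, mask=None):
    # Mirrors the new CrossAttention.forward branch: unmasked calls keep the
    # fastest selected backend, masked calls use the mask-capable one
    # (attention_basic by default, attention_pytorch when it is enabled).
    optimized_attention = attention_pytorch_sketch         # stand-in backend
    optimized_attention_masked = attention_pytorch_sketch  # stand-in backend
    if mask is None:
        return optimized_attention(q, k, v, heads)
    return optimized_attention_masked(q, k, v, heads, mask)


if __name__ == "__main__":
    b, seq, heads, dim_head = 2, 8, 4, 16
    q = torch.randn(b, seq, heads * dim_head)
    k = torch.randn(b, seq, heads * dim_head)
    v = torch.randn(b, seq, heads * dim_head)
    mask = torch.zeros(b, 1, seq, seq)
    mask[..., seq // 2:] = float("-inf")     # additive mask: hide the second half of the keys
    print(dispatch(q, k, v, heads).shape)        # unmasked path -> (2, 8, 64)
    print(dispatch(q, k, v, heads, mask).shape)  # masked path   -> (2, 8, 64)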