Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add a separate optimized_attention_masked function.
commit bb064c9796 (parent 7e09e889e3)
@@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
     )

     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.transpose(1, 2).reshape(b, -1, heads * dim_head)
     )
     return out


 optimized_attention = attention_basic
+optimized_attention_masked = attention_basic

 if model_management.xformers_enabled():
     print("Using xformers cross attention")
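Note (not part of the diff): the masked path works because torch.nn.functional.scaled_dot_product_attention already accepts an attn_mask argument, so removing the NotImplementedError guard above is enough for attention_pytorch to handle masks. A minimal sketch with illustrative tensor shapes:

    import torch

    b, heads, tokens, dim_head = 1, 8, 4, 64
    q = torch.randn(b, heads, tokens, dim_head)
    k = torch.randn(b, heads, tokens, dim_head)
    v = torch.randn(b, heads, tokens, dim_head)

    # Boolean mask broadcast over (batch, heads): True = attend, False = ignore.
    mask = torch.ones(tokens, tokens, dtype=torch.bool)
    mask[:, -1] = False  # hide the last key from every query

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    print(out.shape)  # torch.Size([1, 8, 4, 64])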
@@ -309,6 +308,9 @@ else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad

+if model_management.pytorch_attention_enabled():
+    optimized_attention_masked = attention_pytorch
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
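Note (not part of the diff): the module-level assignments follow a default-then-upgrade pattern: optimized_attention_masked starts as the always-available attention_basic and is only swapped for attention_pytorch when that backend is enabled. A minimal sketch of the pattern with hypothetical stand-in names:

    def attention_reference(q, k, v, heads, mask=None):
        return "reference (always available)"

    def attention_fused(q, k, v, heads, mask=None):
        return "fused backend"

    def fused_backend_enabled():
        # Hypothetical capability probe, standing in for
        # model_management.pytorch_attention_enabled().
        return True

    # Safe default first; faster implementation only if supported.
    optimized_attention_masked = attention_reference
    if fused_backend_enabled():
        optimized_attention_masked = attention_fused

    print(optimized_attention_masked(None, None, None, heads=8))  # "fused backend"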
@@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
         else:
             v = self.to_v(context)

-        out = optimized_attention(q, k, v, self.heads, mask)
+        if mask is None:
+            out = optimized_attention(q, k, v, self.heads)
+        else:
+            out = optimized_attention_masked(q, k, v, self.heads, mask)
         return self.to_out(out)
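Note (not part of the diff): the forward-path change routes calls by mask presence, so backends that only implement unmasked attention keep serving the common case while masked requests go through a function known to accept a mask. A self-contained sketch of the dispatch (stand-in functions, not ComfyUI code):

    import torch

    def optimized_attention(q, k, v, heads):
        # stand-in for the unmasked fast path
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    def optimized_attention_masked(q, k, v, heads, mask):
        # stand-in for the masked path (attention_pytorch in this commit)
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    def cross_attention_forward(q, k, v, heads, mask=None):
        if mask is None:
            return optimized_attention(q, k, v, heads)
        return optimized_attention_masked(q, k, v, heads, mask)

    q = k = v = torch.randn(1, 8, 4, 64)
    print(cross_attention_forward(q, k, v, heads=8).shape)   # unmasked path
    print(cross_attention_forward(q, k, v, heads=8,
                                  mask=torch.ones(4, 4, dtype=torch.bool)).shape)  # masked path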