Don't expand mask when not necessary.

Expanding seems to slow down inference.
comfyanonymous 2024-12-16 18:22:50 -05:00
parent 61b50720d0
commit 19ee5d9d8b


@@ -423,8 +423,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         # add a heads dimension if there isn't already one
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
-        mask = mask.expand(b, heads, -1, -1)
 
     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@@ -434,11 +432,16 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     else:
         out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
         for i in range(0, b, SDP_BATCH_LIMIT):
+            m = mask
+            if mask is not None:
+                if mask.shape[0] > 1:
+                    m = mask[i : i + SDP_BATCH_LIMIT]
+
             out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
                 q[i : i + SDP_BATCH_LIMIT],
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
-                attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
+                attn_mask=m,
                 dropout_p=0.0, is_causal=False
             ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out