mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Don't expand mask when not necessary.
Expanding seems to slow down inference.
This commit is contained in:
parent
61b50720d0
commit
19ee5d9d8b
@@ -423,8 +423,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
# add a heads dimension if there isn't already one
|
# add a heads dimension if there isn't already one
|
||||||
if mask.ndim == 3:
|
if mask.ndim == 3:
|
||||||
mask = mask.unsqueeze(1)
|
mask = mask.unsqueeze(1)
|
||||||
mask = mask.expand(b, heads, -1, -1)
|
|
||||||
|
|
||||||
|
|
||||||
if SDP_BATCH_LIMIT >= b:
|
if SDP_BATCH_LIMIT >= b:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
@@ -434,11 +432,16 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
else:
|
else:
|
||||||
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||||
for i in range(0, b, SDP_BATCH_LIMIT):
|
for i in range(0, b, SDP_BATCH_LIMIT):
|
||||||
|
m = mask
|
||||||
|
if mask is not None:
|
||||||
|
if mask.shape[0] > 1:
|
||||||
|
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||||
|
|
||||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||||
q[i : i + SDP_BATCH_LIMIT],
|
q[i : i + SDP_BATCH_LIMIT],
|
||||||
k[i : i + SDP_BATCH_LIMIT],
|
k[i : i + SDP_BATCH_LIMIT],
|
||||||
v[i : i + SDP_BATCH_LIMIT],
|
v[i : i + SDP_BATCH_LIMIT],
|
||||||
attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
|
attn_mask=m,
|
||||||
dropout_p=0.0, is_causal=False
|
dropout_p=0.0, is_causal=False
|
||||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
Loading…
Reference in New Issue
Block a user