Add argument to skip the output reshaping in the attention functions.

comfyanonymous 2025-01-10 06:27:37 -05:00
parent ff838657fa
commit 129d8908f7
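
The change is the same in every backend (attention_basic, attention_sub_quad, attention_split, attention_xformers, attention_pytorch, attention_sage): with skip_output_reshape=True the function returns the per-head output in (batch, heads, seq_len, dim_head) layout instead of merging the heads back into the usual (batch, seq_len, heads * dim_head) tensor. A minimal sketch of the two layouts, using plain PyTorch rather than the repository helpers (shapes are illustrative, not taken from the diff):

import torch

b, heads, seq, dim_head = 2, 8, 64, 40
q = torch.randn(b, heads, seq, dim_head)
k = torch.randn(b, heads, seq, dim_head)
v = torch.randn(b, heads, seq, dim_head)

# scaled_dot_product_attention already produces the per-head layout,
# which is what skip_output_reshape=True now passes through untouched.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)     # torch.Size([2, 8, 64, 40])

# The default (and previous) behavior merges the heads back into the channel dim.
merged = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(merged.shape)  # torch.Size([2, 64, 320])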

@@ -89,7 +89,7 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     if skip_reshape:
@@ -142,16 +142,23 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     sim = sim.softmax(dim=-1)
     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
     return out
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     if skip_reshape:
@@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     )
     hidden_states = hidden_states.to(dtype)
-    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
+    if skip_output_reshape:
+        hidden_states = hidden_states.unflatten(0, (-1, heads))
+    else:
+        hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
     if skip_reshape:
@@ -326,12 +335,18 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     del q, k, v
-    r1 = (
-        r1.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+        )
+    else:
+        r1 = (
+            r1.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
     return r1
 BROKEN_XFORMERS = False
@@ -342,7 +357,7 @@ try:
 except:
     pass
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -395,9 +410,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
-    out = (
-        out.reshape(b, -1, heads * dim_head)
-    )
+    if skip_output_reshape:
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )
     return out
@@ -408,7 +426,7 @@ else:
     SDP_BATCH_LIMIT = 2**31
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -429,9 +447,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
         out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
         for i in range(0, b, SDP_BATCH_LIMIT):
@@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout="HND"
@@ -473,11 +492,15 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     if tensor_layout == "HND":
-        out = (
-            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-        )
+        if not skip_output_reshape:
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
     else:
-        out = out.reshape(b, -1, heads * dim_head)
+        if skip_output_reshape:
+            out = out.transpose(1, 2)
+        else:
+            out = out.reshape(b, -1, heads * dim_head)
     return out
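
Note on why the skip_output_reshape branches differ across backends: xformers' memory_efficient_attention and sageattn with tensor_layout="NHD" return their output as (batch, seq_len, heads, dim_head), so reaching the per-head (batch, heads, seq_len, dim_head) layout still requires a permute/transpose, while scaled_dot_product_attention and the sage "HND" layout are already in that form and can be returned as-is. A small standalone sketch of just those layout conversions (shapes are illustrative):

import torch

b, heads, seq, dim_head = 2, 8, 64, 40

# xformers / sage "NHD" style raw output: (batch, seq_len, heads, dim_head)
raw_nhd = torch.randn(b, seq, heads, dim_head)
per_head = raw_nhd.permute(0, 2, 1, 3)             # skip_output_reshape=True -> (b, heads, seq, dim_head)
merged = raw_nhd.reshape(b, -1, heads * dim_head)  # default -> (b, seq, heads * dim_head)

# SDPA / sage "HND" style raw output is already (batch, heads, seq_len, dim_head)
raw_hnd = torch.randn(b, heads, seq, dim_head)
merged_hnd = raw_hnd.transpose(1, 2).reshape(b, -1, heads * dim_head)  # default merge path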