Fix attention mask batch size in some attention functions.

comfyanonymous 2024-02-17 15:22:21 -05:00
parent 3b9969c1c5
commit f8706546f3


@@ -114,7 +114,11 @@ def attention_basic(q, k, v, heads, mask=None):
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
         else:
-            mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape)
+            if len(mask.shape) == 2:
+                bs = 1
+            else:
+                bs = mask.shape[0]
+            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
             sim.add_(mask)

     # attention, what we cannot get enough of
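The same batch-size normalization is applied in all three attention functions touched by this commit. A minimal standalone sketch of the committed logic (the helper name _expand_mask is hypothetical, not from the repository):

import torch

def _expand_mask(mask, heads):
    # Hypothetical helper mirroring the committed logic. A 2-D mask of
    # shape (q_len, k_len) carries no batch dimension, so it is treated
    # as batch size 1; the pre-fix code used mask.shape[0] (which for a
    # 2-D mask is the query length) as the batch size, breaking the
    # reshape below.
    if len(mask.shape) == 2:
        bs = 1
    else:
        bs = mask.shape[0]
    # (bs, 1, q_len, k_len) -> broadcast over heads -> flatten to
    # (bs * heads, q_len, k_len), matching the per-head score layout.
    mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1])
    mask = mask.expand(-1, heads, -1, -1)
    return mask.reshape(-1, mask.shape[-2], mask.shape[-1])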
@@ -167,7 +171,11 @@ def attention_sub_quad(query, key, value, heads, mask=None):
         query_chunk_size = 512

     if mask is not None:
-        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        if len(mask.shape) == 2:
+            bs = 1
+        else:
+            bs = mask.shape[0]
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

     hidden_states = efficient_dot_product_attention(
         query,
@@ -228,7 +236,11 @@ def attention_split(q, k, v, heads, mask=None):
                                f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

     if mask is not None:
-        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        if len(mask.shape) == 2:
+            bs = 1
+        else:
+            bs = mask.shape[0]
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

     # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
     first_op_done = False
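A quick shape check of the fixed behavior, using the _expand_mask sketch above with illustrative sizes (batch 2, 8 heads, 77 tokens; all values hypothetical):

mask_2d = torch.zeros(77, 77)     # unbatched mask: the old code used
                                  # bs = 77 here, so the reshape raised
mask_3d = torch.zeros(2, 77, 77)  # batched mask: behavior unchanged

print(_expand_mask(mask_2d, heads=8).shape)  # torch.Size([8, 77, 77])
print(_expand_mask(mask_3d, heads=8).shape)  # torch.Size([16, 77, 77])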