Fix attention masks properly for multiple batches.

This commit is contained in:
comfyanonymous 2024-02-17 16:15:18 -05:00
parent 11e3221f1f
commit 6bcf57ff10

View File

@ -118,7 +118,7 @@ def attention_basic(q, k, v, heads, mask=None):
bs = 1 bs = 1
else: else:
bs = mask.shape[0] bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
sim.add_(mask) sim.add_(mask)
# attention, what we cannot get enough of # attention, what we cannot get enough of
@ -175,7 +175,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
bs = 1 bs = 1
else: else:
bs = mask.shape[0] bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
hidden_states = efficient_dot_product_attention( hidden_states = efficient_dot_product_attention(
query, query,
@ -240,7 +240,7 @@ def attention_split(q, k, v, heads, mask=None):
bs = 1 bs = 1
else: else:
bs = mask.shape[0] bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False first_op_done = False