Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix attention mask batch size in some attention functions.
commit f8706546f3
parent 3b9969c1c5
@@ -114,7 +114,11 @@ def attention_basic(q, k, v, heads, mask=None):
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
         else:
-            mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape)
+            if len(mask.shape) == 2:
+                bs = 1
+            else:
+                bs = mask.shape[0]
+            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
             sim.add_(mask)

     # attention, what we cannot get enough of
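The same batch-size handling is applied in all three functions changed by this commit. A minimal standalone sketch of what the new reshape computes (expand_attn_mask, the head count, and the mask shapes below are illustrative examples, not names from the repo): a 2D mask (seq_q, seq_k) carries no batch dimension, so bs must be 1, while a 3D mask keeps its own leading batch dimension before being repeated per head.

import torch

def expand_attn_mask(mask, heads):
    # Same logic as the added lines above: a 2D mask has no batch dimension,
    # so the batch size is 1; otherwise the mask's leading dim is the batch.
    if len(mask.shape) == 2:
        bs = 1
    else:
        bs = mask.shape[0]
    # (bs, 1, seq_q, seq_k) -> expand across heads -> flatten to (bs * heads, seq_q, seq_k)
    return mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

heads = 8
mask_2d = torch.zeros(77, 77)       # one mask shared by the whole batch
mask_3d = torch.zeros(2, 77, 77)    # one mask per batch element
print(expand_attn_mask(mask_2d, heads).shape)   # torch.Size([8, 77, 77])
print(expand_attn_mask(mask_3d, heads).shape)   # torch.Size([16, 77, 77])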
@@ -167,7 +171,11 @@ def attention_sub_quad(query, key, value, heads, mask=None):
         query_chunk_size = 512

     if mask is not None:
-        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        if len(mask.shape) == 2:
+            bs = 1
+        else:
+            bs = mask.shape[0]
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

     hidden_states = efficient_dot_product_attention(
         query,
@@ -228,7 +236,11 @@ def attention_split(q, k, v, heads, mask=None):
                            f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

     if mask is not None:
-        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        if len(mask.shape) == 2:
+            bs = 1
+        else:
+            bs = mask.shape[0]
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

     # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
     first_op_done = False
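One plausible failure mode the new bs handling avoids (a hedged sketch inferred from the diff, not code taken from the repo): the old reshape used mask.shape[0] as the batch size, but for a 2D mask that is the query length, so the requested shape cannot even be satisfied.

import torch

mask = torch.zeros(77, 77)   # 2D mask: (seq_q, seq_k), no batch dimension

# Old behaviour: mask.shape[0] (here 77, the query length) is taken as the
# batch size; 77 * n * 77 * 77 elements can never equal the 77 * 77 we have.
try:
    mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1])
except RuntimeError as err:
    print(err)   # e.g. "shape '[77, -1, 77, 77]' is invalid for input of size 5929"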