From 6bcf57ff10f4488f7dfc6e3c47ac516967567c22 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Feb 2024 16:15:18 -0500
Subject: [PATCH] Fix attention masks properly for multiple batches.

---
 comfy/ldm/modules/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index f1dca2c2..48399bc0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -118,7 +118,7 @@ def attention_basic(q, k, v, heads, mask=None):
             bs = 1
         else:
             bs = mask.shape[0]
-        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
         sim.add_(mask)

     # attention, what we cannot get enough of
@@ -175,7 +175,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
             bs = 1
         else:
             bs = mask.shape[0]
-        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

     hidden_states = efficient_dot_product_attention(
         query,
@@ -240,7 +240,7 @@ def attention_split(q, k, v, heads, mask=None):
             bs = 1
         else:
             bs = mask.shape[0]
-        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

     # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
     first_op_done = False
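
For context, here is a minimal sketch (not part of the patch; the shapes b=4, heads=8, and sequence length L=16 are made up for illustration) of why expand(-1, ...) mishandles a mask whose batch dimension is smaller than the query batch: the expanded mask ends up with only heads slices instead of b * heads, so sim.add_(mask) cannot broadcast against the (b * heads, L, L) attention scores. Passing b to expand replicates the mask across the full batch, as the patch does in all three attention functions.

import torch

b, heads, L = 4, 8, 16              # hypothetical batch size, head count, sequence length
sim = torch.zeros(b * heads, L, L)  # attention scores laid out as (batch * heads, L, L)
mask = torch.zeros(1, L, L)         # one mask shared by every batch item, so bs = 1
bs = mask.shape[0]

# Before the patch: expand(-1, ...) keeps the batch dim at bs, yielding (heads, L, L).
old = mask.reshape(bs, -1, L, L).expand(-1, heads, -1, -1).reshape(-1, L, L)
# After the patch: expand(b, ...) replicates the mask across the real batch, yielding (b * heads, L, L).
new = mask.reshape(bs, -1, L, L).expand(b, heads, -1, -1).reshape(-1, L, L)

print(old.shape)  # torch.Size([8, 16, 16])  -- does not broadcast against sim
print(new.shape)  # torch.Size([32, 16, 16]) -- matches sim, so sim.add_(new) works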