mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Fix mask issue in some attention functions.
This commit is contained in:
parent
8f0009aad0
commit
2fd9c1308a
@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 if len(mask.shape) == 2:
                     s1 += mask[i:end]
                 else:
-                    s1 += mask[:, i:end]
+                    if mask.shape[1] == 1:
+                        s1 += mask
+                    else:
+                        s1 += mask[:, i:end]
 
         s2 = s1.softmax(dim=-1).to(v.dtype)
         del s1
|
@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
     def get_mask_chunk(chunk_idx: int) -> Tensor:
         if mask is None:
             return None
+        if mask.shape[1] == 1:
+            return mask
         chunk = min(query_chunk_size, q_tokens)
         return mask[:,chunk_idx:chunk_idx + chunk]
 
Loading…
Reference in New Issue
Block a user