mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add attention mask support to sub quad attention.
This commit is contained in:
parent
0c2c9fbdfa
commit
aaa9017302
@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
kv_chunk_size_min=kv_chunk_size_min,
|
||||
use_checkpoint=False,
|
||||
upcast_attention=upcast_attention,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
@ -61,6 +61,7 @@ def _summarize_chunk(
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
upcast_attention: bool,
|
||||
mask,
|
||||
) -> AttnChunk:
|
||||
if upcast_attention:
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
@ -84,6 +85,8 @@ def _summarize_chunk(
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
attn_weights -= max_score
|
||||
if mask is not None:
|
||||
attn_weights += mask
|
||||
torch.exp(attn_weights, out=attn_weights)
|
||||
exp_weights = attn_weights.to(value.dtype)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
@ -96,11 +99,12 @@ def _query_chunk_attention(
|
||||
value: Tensor,
|
||||
summarize_chunk: SummarizeChunk,
|
||||
kv_chunk_size: int,
|
||||
mask,
|
||||
) -> Tensor:
|
||||
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
|
||||
_, _, v_channels_per_head = value.shape
|
||||
|
||||
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
||||
def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
|
||||
key_chunk = dynamic_slice(
|
||||
key_t,
|
||||
(0, 0, chunk_idx),
|
||||
@ -111,10 +115,13 @@ def _query_chunk_attention(
|
||||
(0, chunk_idx, 0),
|
||||
(batch_x_heads, kv_chunk_size, v_channels_per_head)
|
||||
)
|
||||
return summarize_chunk(query, key_chunk, value_chunk)
|
||||
if mask is not None:
|
||||
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
|
||||
|
||||
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
|
||||
|
||||
chunks: List[AttnChunk] = [
|
||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
]
|
||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||
chunk_values, chunk_weights, chunk_max = acc_chunk
|
||||
@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
upcast_attention: bool,
|
||||
mask,
|
||||
) -> Tensor:
|
||||
if upcast_attention:
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
|
||||
beta=0,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
attn_scores += mask
|
||||
try:
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
@ -183,6 +193,7 @@ def efficient_dot_product_attention(
|
||||
kv_chunk_size_min: Optional[int] = None,
|
||||
use_checkpoint=True,
|
||||
upcast_attention=False,
|
||||
mask = None,
|
||||
):
|
||||
"""Computes efficient dot-product attention given query, transposed key, and value.
|
||||
This is efficient version of attention presented in
|
||||
@ -209,13 +220,22 @@ def efficient_dot_product_attention(
|
||||
if kv_chunk_size_min is not None:
|
||||
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
||||
|
||||
if mask is not None and len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
def get_query_chunk(chunk_idx: int) -> Tensor:
|
||||
return dynamic_slice(
|
||||
query,
|
||||
(0, chunk_idx, 0),
|
||||
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
|
||||
)
|
||||
|
||||
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
|
||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||
@ -237,6 +257,7 @@ def efficient_dot_product_attention(
|
||||
query=query,
|
||||
key_t=key_t,
|
||||
value=value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
@ -246,6 +267,7 @@ def efficient_dot_product_attention(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key_t=key_t,
|
||||
value=value,
|
||||
mask=get_mask_chunk(i * query_chunk_size)
|
||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
||||
], dim=1)
|
||||
return res
|
||||
|
Loading…
Reference in New Issue
Block a user