Add attention mask support to sub quad attention.

comfyanonymous 2024-01-07 04:13:58 -05:00
parent 0c2c9fbdfa
commit aaa9017302
2 changed files with 27 additions and 4 deletions
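For orientation (commentary, not part of the diff): the mask threaded through below is an additive bias on the attention scores. It is added directly via attn_weights += mask / attn_scores += mask, and a 2-D mask gets a leading broadcast dimension so one mask can cover every batch-by-head slice. A minimal sketch of building such a mask, where the sizes and the -inf fill value are illustrative assumptions rather than code from this commit:

    import torch

    # Hypothetical sizes: q_tokens query positions, k_tokens key positions.
    q_tokens, k_tokens = 77, 77

    # Boolean "may attend" matrix turned into an additive bias: 0.0 where
    # attention is allowed, -inf where it is blocked, so exp(score + bias) -> 0.
    allowed = torch.ones(q_tokens, k_tokens, dtype=torch.bool)
    mask = torch.zeros(q_tokens, k_tokens).masked_fill(~allowed, float("-inf"))

    # efficient_dot_product_attention below unsqueezes a 2-D mask to shape
    # [1, q_tokens, k_tokens] so it broadcasts over the batch_x_heads dimension.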


@@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
         kv_chunk_size_min=kv_chunk_size_min,
         use_checkpoint=False,
         upcast_attention=upcast_attention,
+        mask=mask,
     )

     hidden_states = hidden_states.to(dtype)


@@ -61,6 +61,7 @@ def _summarize_chunk(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> AttnChunk:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -84,6 +85,8 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     attn_weights -= max_score
+    if mask is not None:
+        attn_weights += mask
     torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
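A note on the hunk above (commentary, not part of the commit): the mask is added after the per-chunk max has already been subtracted. Because softmax is invariant to subtracting any per-row constant, this still reproduces an ordinary masked softmax once the chunks are renormalized. A small self-contained check with illustrative shapes:

    import torch

    torch.manual_seed(0)
    scores = torch.randn(4, 8)          # one chunk of raw attention scores
    mask = torch.zeros(4, 8)
    mask[:, 5:] = float("-inf")         # block the last few keys, say

    # What _summarize_chunk effectively does: subtract the pre-mask row max,
    # add the mask, exponentiate, then normalize by the accumulated weights.
    max_score, _ = torch.max(scores, -1, keepdim=True)
    exp_weights = torch.exp(scores - max_score + mask)
    chunked = exp_weights / exp_weights.sum(-1, keepdim=True)

    # Reference: plain softmax over the masked scores.
    reference = torch.softmax(scores + mask, dim=-1)
    assert torch.allclose(chunked, reference, atol=1e-6)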
@@ -96,11 +99,12 @@ def _query_chunk_attention(
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
+    mask,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape

-    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
         key_chunk = dynamic_slice(
             key_t,
             (0, 0, chunk_idx),
@@ -111,10 +115,13 @@ def _query_chunk_attention(
             (0, chunk_idx, 0),
             (batch_x_heads, kv_chunk_size, v_channels_per_head)
         )
-        return summarize_chunk(query, key_chunk, value_chunk)
+        if mask is not None:
+            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+
+        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

     chunks: List[AttnChunk] = [
-        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
     ]
     acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
     chunk_values, chunk_weights, chunk_max = acc_chunk
@@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
             beta=0,
         )

+    if mask is not None:
+        attn_scores += mask
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
@@ -183,6 +193,7 @@ def efficient_dot_product_attention(
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
+    mask = None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -209,6 +220,9 @@ def efficient_dot_product_attention(
     if kv_chunk_size_min is not None:
         kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

+    if mask is not None and len(mask.shape) == 2:
+        mask = mask.unsqueeze(0)
+
     def get_query_chunk(chunk_idx: int) -> Tensor:
         return dynamic_slice(
             query,
@@ -216,6 +230,12 @@ def efficient_dot_product_attention(
             (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
         )

+    def get_mask_chunk(chunk_idx: int) -> Tensor:
+        if mask is None:
+            return None
+        chunk = min(query_chunk_size, q_tokens)
+        return mask[:,chunk_idx:chunk_idx + chunk]
+
     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -237,6 +257,7 @@ def efficient_dot_product_attention(
             query=query,
             key_t=key_t,
             value=value,
+            mask=mask,
         )

     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@@ -246,6 +267,7 @@ def efficient_dot_product_attention(
             query=get_query_chunk(i * query_chunk_size),
             key_t=key_t,
             value=value,
+            mask=get_mask_chunk(i * query_chunk_size)
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
     ], dim=1)
     return res
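Taken together (again commentary, not part of the commit): get_mask_chunk slices the mask along the query axis and chunk_scanner slices it along the key axis, so the additive bias stays aligned with each block of scores. A minimal end-to-end sketch of that bookkeeping with illustrative shapes, using a simplified accumulation without the running-max and checkpoint machinery of efficient_dot_product_attention:

    import math
    import torch

    torch.manual_seed(0)
    batch_x_heads, q_tokens, k_tokens, dim = 2, 8, 12, 4
    query = torch.randn(batch_x_heads, q_tokens, dim)
    key_t = torch.randn(batch_x_heads, dim, k_tokens)    # key already transposed
    value = torch.randn(batch_x_heads, k_tokens, dim)
    scale = 1.0 / math.sqrt(dim)

    # Additive mask broadcast over batch_x_heads (the unsqueeze(0) path above).
    mask = torch.zeros(1, q_tokens, k_tokens)
    mask[:, :, -3:] = float("-inf")                       # e.g. padded-out keys

    # Reference: full masked attention in one shot.
    scores = torch.bmm(query, key_t) * scale + mask
    reference = torch.bmm(scores.softmax(dim=-1), value)

    # Chunked: slice queries plus the mask's query axis, then keys plus the
    # mask's key axis, mirroring get_mask_chunk / chunk_scanner above.
    query_chunk_size, kv_chunk_size = 4, 6
    out = torch.zeros_like(reference)
    for qi in range(0, q_tokens, query_chunk_size):
        q_chunk = query[:, qi:qi + query_chunk_size]
        m_q = mask[:, qi:qi + query_chunk_size]           # get_mask_chunk
        num = torch.zeros(batch_x_heads, q_chunk.shape[1], dim)
        den = torch.zeros(batch_x_heads, q_chunk.shape[1], 1)
        for ki in range(0, k_tokens, kv_chunk_size):
            k_chunk = key_t[:, :, ki:ki + kv_chunk_size]
            v_chunk = value[:, ki:ki + kv_chunk_size]
            m_qk = m_q[:, :, ki:ki + kv_chunk_size]       # chunk_scanner's slice
            w = torch.exp(torch.bmm(q_chunk, k_chunk) * scale + m_qk)
            num += torch.bmm(w, v_chunk)
            den += w.sum(dim=-1, keepdim=True)
        out[:, qi:qi + query_chunk_size] = num / den

    assert torch.allclose(out, reference, atol=1e-5)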