Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 11:23:29 +00:00)
Add attention mask support to sub quad attention.
This commit is contained in:
parent 0c2c9fbdfa
commit aaa9017302
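
The change threads an optional attention mask through the sub-quadratic ("sub quad") attention path. The mask is an additive bias over query/key positions: zero where attention is allowed, a large negative value where it is blocked, added to the raw scores before softmax. A rough, illustrative sketch of such a mask (shapes and the padding scenario are assumptions, not part of this commit):

    import torch

    # Hypothetical sizes: batch*heads, query tokens, key tokens.
    batch_x_heads, q_tokens, k_tokens = 16, 4096, 77

    # Suppose the last key tokens are padding that should be ignored.
    keep = torch.ones(batch_x_heads, k_tokens, dtype=torch.bool)
    keep[:, 70:] = False

    # Additive bias broadcastable against [batch*heads, q_tokens, k_tokens] scores.
    mask = torch.zeros(batch_x_heads, 1, k_tokens)
    mask.masked_fill_(~keep.unsqueeze(1), float("-inf"))

    # scores = scores + mask   # masked columns softmax to zero weight
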
@@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
         kv_chunk_size_min=kv_chunk_size_min,
         use_checkpoint=False,
         upcast_attention=upcast_attention,
+        mask=mask,
     )

     hidden_states = hidden_states.to(dtype)
@@ -61,6 +61,7 @@ def _summarize_chunk(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> AttnChunk:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -84,6 +85,8 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     attn_weights -= max_score
+    if mask is not None:
+        attn_weights += mask
     torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
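
Because the mask is purely additive, applying it after the running-max subtraction inside a chunk is equivalent to masking the raw logits: the subtracted per-row constant cancels when the chunk results are renormalised. A small self-contained check of that equivalence (illustrative only, not code from the diff):

    import torch

    scores = torch.randn(2, 4, 6)             # [batch*heads, q, k] logits for one chunk
    mask = torch.zeros_like(scores)
    mask[..., 4:] = float("-inf")              # block the last two key positions

    # Reference: mask the logits, then softmax.
    ref = torch.softmax(scores + mask, dim=-1)

    # Chunk-style ordering: subtract the max, add the mask, exponentiate, renormalise.
    max_score, _ = torch.max(scores, -1, keepdim=True)
    w = torch.exp(scores - max_score + mask)
    out = w / w.sum(dim=-1, keepdim=True)

    assert torch.allclose(ref, out)
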
@@ -96,11 +99,12 @@ def _query_chunk_attention(
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
+    mask,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape

-    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
         key_chunk = dynamic_slice(
             key_t,
             (0, 0, chunk_idx),
@@ -111,10 +115,13 @@ def _query_chunk_attention(
             (0, chunk_idx, 0),
             (batch_x_heads, kv_chunk_size, v_channels_per_head)
         )
-        return summarize_chunk(query, key_chunk, value_chunk)
+        if mask is not None:
+            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+
+        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

     chunks: List[AttnChunk] = [
-        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
     ]
     acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
     chunk_values, chunk_weights, chunk_max = acc_chunk
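
Each key/value chunk only sees the slice of the mask covering its own key positions, so the bias columns stay aligned with that chunk's score matrix. A toy illustration of the slicing (sizes are assumptions):

    import torch

    batch_x_heads, q_tokens, k_tokens, kv_chunk_size = 4, 8, 12, 4
    mask = torch.zeros(batch_x_heads, q_tokens, k_tokens)

    for chunk_idx in range(0, k_tokens, kv_chunk_size):
        mask_chunk = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
        # Lines up with this chunk's scores: [batch*heads, q_tokens, kv_chunk_size]
        assert mask_chunk.shape == (batch_x_heads, q_tokens, kv_chunk_size)
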
@@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
             beta=0,
         )

+    if mask is not None:
+        attn_scores += mask
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
@@ -183,6 +193,7 @@ def efficient_dot_product_attention(
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
+    mask = None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -209,6 +220,9 @@ def efficient_dot_product_attention(
     if kv_chunk_size_min is not None:
         kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

+    if mask is not None and len(mask.shape) == 2:
+        mask = mask.unsqueeze(0)
+
     def get_query_chunk(chunk_idx: int) -> Tensor:
         return dynamic_slice(
             query,
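
A 2D mask of shape [q_tokens, k_tokens] gets a leading batch dimension so it can broadcast against score tensors of shape [batch*heads, q_tokens, k_tokens]. A quick check of that broadcast (illustrative, with assumed sizes):

    import torch

    batch_x_heads, q_tokens, k_tokens = 6, 8, 12
    mask = torch.zeros(q_tokens, k_tokens)

    if mask is not None and len(mask.shape) == 2:
        mask = mask.unsqueeze(0)                       # [1, q_tokens, k_tokens]

    scores = torch.randn(batch_x_heads, q_tokens, k_tokens)
    assert (scores + mask).shape == scores.shape       # broadcasts over batch*heads
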
@@ -216,6 +230,12 @@ def efficient_dot_product_attention(
             (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
         )

+    def get_mask_chunk(chunk_idx: int) -> Tensor:
+        if mask is None:
+            return None
+        chunk = min(query_chunk_size, q_tokens)
+        return mask[:,chunk_idx:chunk_idx + chunk]
+
     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
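
get_mask_chunk mirrors get_query_chunk: for each query chunk it returns the matching rows of the mask, or None when no mask was supplied. A sketch of the two staying in step (shapes assumed; the standalone helper below only mimics the nested one in the diff):

    import torch

    q_tokens, k_tokens, query_chunk_size = 10, 16, 4
    mask = torch.zeros(1, q_tokens, k_tokens)          # already unsqueezed to 3D

    def get_mask_chunk(chunk_idx):
        chunk = min(query_chunk_size, q_tokens)
        return mask[:, chunk_idx:chunk_idx + chunk]

    for i in range(-(-q_tokens // query_chunk_size)):  # ceil division
        m = get_mask_chunk(i * query_chunk_size)
        assert m.shape[1] <= query_chunk_size          # last chunk may be shorter
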
@@ -237,6 +257,7 @@ def efficient_dot_product_attention(
             query=query,
             key_t=key_t,
             value=value,
+            mask=mask,
         )

     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@@ -246,6 +267,7 @@ def efficient_dot_product_attention(
             query=get_query_chunk(i * query_chunk_size),
             key_t=key_t,
             value=value,
+            mask=get_mask_chunk(i * query_chunk_size)
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
     ], dim=1)
     return res
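
Putting it together, a caller can now hand an additive mask straight to efficient_dot_product_attention, and it is chunked along both the query and key dimensions internally. A hedged usage sketch (tensor sizes are assumptions; the call itself is shown commented out so the snippet stands alone):

    import torch

    batch_x_heads, q_tokens, k_tokens, head_dim = 16, 4096, 77, 64

    query = torch.randn(batch_x_heads, q_tokens, head_dim)
    key_t = torch.randn(batch_x_heads, head_dim, k_tokens)    # key is passed transposed
    value = torch.randn(batch_x_heads, k_tokens, head_dim)

    # 2D additive bias; the function broadcasts it over batch*heads.
    mask = torch.zeros(q_tokens, k_tokens)
    mask[:, 70:] = float("-inf")

    # out = efficient_dot_product_attention(query, key_t, value, mask=mask)
    # out.shape -> [batch_x_heads, q_tokens, head_dim]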