Fix sub quadratic attention for SD2 and make it the default optimization.

This commit is contained in:
comfyanonymous 2023-01-25 01:22:43 -05:00
parent 3b38a31cc7
commit 051f472e8f
2 changed files with 60 additions and 26 deletions

View File

@ -175,12 +175,10 @@ class CrossAttentionBirchSan(nn.Module):
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
dtype = query.dtype
# TODO: do we still need to do *everything* in float32, given how we delay the division?
# TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
# if self.upcast_attention:
# query = query.float()
# key_t = key_t.float()
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
@ -198,7 +196,7 @@ class CrossAttentionBirchSan(nn.Module):
query_chunk_size_x = 1024 * 4
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 1.2) // 1024) * 1024
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
@ -220,6 +218,7 @@ class CrossAttentionBirchSan(nn.Module):
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=self.training,
upcast_attention=upcast_attention,
)
hidden_states = hidden_states.to(dtype)
@ -383,7 +382,14 @@ class OriginalCrossAttention(nn.Module):
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class CrossAttention(CrossAttentionDoggettx):
import sys
if "--use-split-cross-attention" in sys.argv:
print("Using split optimization for cross attention")
class CrossAttention(CrossAttentionDoggettx):
pass
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
class CrossAttention(CrossAttentionBirchSan):
pass
class MemoryEfficientCrossAttention(nn.Module):

View File

@ -53,7 +53,20 @@ def _summarize_chunk(
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
) -> AttnChunk:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
query = query.float()
key_t = key_t.float()
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
else:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
@ -112,7 +125,20 @@ def _get_attention_scores_no_kv_chunking(
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
) -> Tensor:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
query = query.float()
key_t = key_t.float()
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
else:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
@ -137,6 +163,7 @@ def efficient_dot_product_attention(
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
):
"""Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in
@ -170,11 +197,12 @@ def efficient_dot_product_attention(
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
scale=scale,
upcast_attention=upcast_attention
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(