Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-12 22:02:14 +00:00)
Fix sub quadratic attention for SD2 and make it the default optimization.
commit 051f472e8f
parent 3b38a31cc7
@@ -175,13 +175,11 @@ class CrossAttentionBirchSan(nn.Module):
         value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)

         dtype = query.dtype
-        # TODO: do we still need to do *everything* in float32, given how we delay the division?
-        # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
-        # if self.upcast_attention:
-        #     query = query.float()
-        #     key_t = key_t.float()
-
-        bytes_per_token = torch.finfo(query.dtype).bits//8
+        upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+        if upcast_attention:
+            bytes_per_token = torch.finfo(torch.float32).bits//8
+        else:
+            bytes_per_token = torch.finfo(query.dtype).bits//8
         batch_x_heads, q_tokens, _ = query.shape
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
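
This hunk is the entry point of the SD2 fix: instead of permanently commented-out upcast code, the forward pass now decides at runtime whether the attention scores must be computed in float32 (when "fp32" attention precision is requested but the tensors are half precision), and the memory planner sizes its chunks for whichever dtype the scores will actually use. A minimal standalone sketch of that decision, not the repo's code; _ATTN_PRECISION is assumed here to be a module-level string set from configuration:

    import torch

    _ATTN_PRECISION = "fp32"  # assumed module-level setting

    def attention_bytes_per_token(query: torch.Tensor) -> int:
        # Upcast only when fp32 precision is requested and the tensors are not already fp32.
        upcast_attention = _ATTN_PRECISION == "fp32" and query.dtype != torch.float32
        if upcast_attention:
            # The QK^T scores will be float32, so budget 4 bytes per score element.
            return torch.finfo(torch.float32).bits // 8
        # Otherwise scores stay in the query's dtype (2 bytes for float16).
        return torch.finfo(query.dtype).bits // 8

    q = torch.zeros(2, 77, 64, dtype=torch.float16)
    print(attention_bytes_per_token(q))  # 4 while _ATTN_PRECISION == "fp32", else 2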
@@ -198,7 +196,7 @@ class CrossAttentionBirchSan(nn.Module):

             query_chunk_size_x = 1024 * 4
             kv_chunk_size_min_x = None
-            kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 1.2) // 1024) * 1024
+            kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
             if kv_chunk_size_x < 1024:
                 kv_chunk_size_x = None

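
Only the headroom multiplier changes here: the KV chunk size derived from the memory budget grows from 1.2x to 2.0x of the per-query-chunk allowance, rounded down to a multiple of 1024. A worked example with assumed numbers (not taken from the commit):

    # Hypothetical budget: 512 MiB for the attention score matmul.
    chunk_threshold_bytes = 512 * 1024 * 1024
    batch_x_heads = 16
    bytes_per_token = 4          # float32 scores after the upcast
    query_chunk_size_x = 1024 * 4

    kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
    if kv_chunk_size_x < 1024:
        kv_chunk_size_x = None   # fall back to the function's own default
    print(kv_chunk_size_x)       # 4096 with the 2.0 factor (the old 1.2 factor gives 2048 here)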
@@ -220,6 +218,7 @@ class CrossAttentionBirchSan(nn.Module):
             kv_chunk_size=kv_chunk_size,
             kv_chunk_size_min=kv_chunk_size_min,
             use_checkpoint=self.training,
+            upcast_attention=upcast_attention,
         )

         hidden_states = hidden_states.to(dtype)
@@ -383,8 +382,15 @@ class OriginalCrossAttention(nn.Module):
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)

-class CrossAttention(CrossAttentionDoggettx):
-    pass
+import sys
+if "--use-split-cross-attention" in sys.argv:
+    print("Using split optimization for cross attention")
+    class CrossAttention(CrossAttentionDoggettx):
+        pass
+else:
+    print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+    class CrossAttention(CrossAttentionBirchSan):
+        pass

 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
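
This is the "make it the default" half of the commit: CrossAttention is no longer an unconditional alias for the Doggettx split implementation; the sub-quadratic (Birch-San) class is bound by default, and --use-split-cross-attention on the command line restores the old path. A standalone sketch of the same import-time selection pattern, with placeholder classes rather than the repo's implementations:

    import sys

    class SplitAttention:              # stands in for CrossAttentionDoggettx
        label = "split"

    class SubQuadraticAttention:       # stands in for CrossAttentionBirchSan
        label = "sub-quadratic"

    # Evaluated once when the module is imported; everything downstream
    # simply instantiates whatever class the name CrossAttention is bound to.
    if "--use-split-cross-attention" in sys.argv:
        CrossAttention = SplitAttention
    else:
        CrossAttention = SubQuadraticAttention

    print(CrossAttention.label)

Because the check runs at import time, the flag has to be present on the command line of the process that imports the module; it cannot be toggled afterwards.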
@@ -53,14 +53,27 @@ def _summarize_chunk(
     key_t: Tensor,
     value: Tensor,
     scale: float,
+    upcast_attention: bool,
 ) -> AttnChunk:
-    attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
-        query,
-        key_t,
-        alpha=scale,
-        beta=0,
-    )
+    if upcast_attention:
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            query = query.float()
+            key_t = key_t.float()
+            attn_weights = torch.baddbmm(
+                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+                query,
+                key_t,
+                alpha=scale,
+                beta=0,
+            )
+    else:
+        attn_weights = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key_t,
+            alpha=scale,
+            beta=0,
+        )
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     exp_weights = torch.exp(attn_weights - max_score)
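
This is the core of the SD2 fix in the sub-quadratic attention helpers: SD2's QK^T logits can overflow float16, so when upcast_attention is set the chunk kernel casts query and key to float32 and disables autocast so the matmul is not silently lowered back to half precision; the caller converts the result back afterwards (hidden_states.to(dtype) above). A standalone sketch of the idea, assuming CUDA tensors shaped [batch*heads, q_tokens, dim] and [batch*heads, dim, k_tokens]; it is not the repo's helper:

    import torch

    def attention_scores(query: torch.Tensor, key_t: torch.Tensor,
                         scale: float, upcast: bool) -> torch.Tensor:
        if upcast:
            # Keep autocast from re-lowering the float32 matmul to fp16/bf16.
            with torch.autocast(enabled=False, device_type="cuda"):
                return torch.baddbmm(
                    torch.empty(1, 1, 1, device=query.device, dtype=torch.float32),
                    query.float(),
                    key_t.float(),
                    alpha=scale,   # scale folded into the matmul
                    beta=0,        # the empty bias tensor is ignored when beta=0
                )
        return torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )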
@@ -112,14 +125,27 @@ def _get_attention_scores_no_kv_chunking(
     key_t: Tensor,
     value: Tensor,
     scale: float,
+    upcast_attention: bool,
 ) -> Tensor:
-    attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
-        query,
-        key_t,
-        alpha=scale,
-        beta=0,
-    )
+    if upcast_attention:
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            query = query.float()
+            key_t = key_t.float()
+            attn_scores = torch.baddbmm(
+                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+                query,
+                key_t,
+                alpha=scale,
+                beta=0,
+            )
+    else:
+        attn_scores = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key_t,
+            alpha=scale,
+            beta=0,
+        )
     attn_probs = attn_scores.softmax(dim=-1)
     del attn_scores
     hidden_states_slice = torch.bmm(attn_probs, value)
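
_get_attention_scores_no_kv_chunking is the fast path taken when the whole key length fits into a single KV chunk; it gets the same upcast branch but then applies a plain softmax, whereas _summarize_chunk above subtracts the running max and keeps exp-weights so that chunks can later be recombined without overflow. A toy illustration (made-up numbers) of why the max-subtraction form matters at low precision:

    import torch

    logits = torch.tensor([[12.0, 11.0, 10.0]])

    # Naive softmax: exp(12) ~ 1.6e5, which already overflows float16 (max ~ 65504).
    naive = torch.exp(logits) / torch.exp(logits).sum(-1, keepdim=True)

    # Max-subtracted form: the largest exponent is exp(0) = 1, so nothing overflows,
    # and the result is mathematically identical.
    max_score = logits.max(-1, keepdim=True).values
    stable = torch.exp(logits - max_score)
    stable = stable / stable.sum(-1, keepdim=True)

    print(torch.allclose(naive, stable))   # True (computed here in float32)
    print(torch.softmax(logits, dim=-1))   # same values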
@@ -137,6 +163,7 @@ def efficient_dot_product_attention(
     kv_chunk_size: Optional[int] = None,
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
+    upcast_attention=False,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
       This is efficient version of attention presented in
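
The public entry point gains an upcast_attention flag that defaults to False, so existing callers keep their previous behaviour. A hypothetical call for orientation; the leading query/key_t/value arguments and their order are inferred from the docstring and the surrounding diff rather than confirmed by this page, so the call itself is left commented out:

    import torch
    # from <the sub-quadratic attention module> import efficient_dot_product_attention

    batch_x_heads, q_tokens, k_tokens, head_dim = 16, 4096, 4096, 64
    query = torch.randn(batch_x_heads, q_tokens, head_dim, dtype=torch.float16)
    key_t = torch.randn(batch_x_heads, head_dim, k_tokens, dtype=torch.float16)
    value = torch.randn(batch_x_heads, k_tokens, head_dim, dtype=torch.float16)

    # hidden_states = efficient_dot_product_attention(
    #     query, key_t, value,
    #     query_chunk_size=1024,
    #     kv_chunk_size=None,        # let the function derive a chunk size
    #     kv_chunk_size_min=None,
    #     use_checkpoint=False,      # inference: no activation checkpointing
    #     upcast_attention=True,     # compute the scores in float32 (the SD2 fix)
    # )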
@@ -170,11 +197,12 @@ def efficient_dot_product_attention(
             (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
         )

-    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
         _get_attention_scores_no_kv_chunking,
-        scale=scale
+        scale=scale,
+        upcast_attention=upcast_attention
     ) if k_tokens <= kv_chunk_size else (
         # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
         partial(
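
The new flag never changes the chunk kernels' call shape: it is pre-bound with functools.partial, exactly like scale, so the chunked loop keeps passing only the per-chunk tensors, and the activation-checkpoint wrapper still composes on top. A generic standalone sketch of that pattern, with placeholder names:

    from functools import partial

    def score_chunk(query_chunk, key_chunk, *, scale, upcast_attention):
        mode = "fp32" if upcast_attention else "native dtype"
        return f"scores({query_chunk}, {key_chunk}) in {mode}, scale={scale}"

    # Configuration is bound once, up front...
    summarize_chunk = partial(score_chunk, scale=0.125, upcast_attention=True)

    # ...so the inner loop only ever supplies the chunk data.
    for q, k in [("q0", "k0"), ("q1", "k1")]:
        print(summarize_chunk(q, k))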