Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-03-12 22:02:14 +00:00
Fix some OOM issues with split and sub quad attention.

commit a373367b0c
parent 7fbb217d3a
@@ -222,9 +222,14 @@ def attention_split(q, k, v, heads, mask=None):
 
     mem_free_total = model_management.get_free_memory(q.device)
 
+    if _ATTN_PRECISION =="fp32":
+        element_size = 4
+    else:
+        element_size = q.element_size()
+
     gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-    modifier = 3 if q.element_size() == 2 else 2.5
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
+    modifier = 3 if element_size == 2 else 2.5
     mem_required = tensor_size * modifier
     steps = 1
 
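The hunk above changes how split attention estimates the memory needed for the attention score matrix: when _ATTN_PRECISION is "fp32" the scores are computed in float32, so the estimate now uses 4 bytes per element instead of q.element_size(), which understated the requirement for fp16/bf16 inputs and could trigger OOM. Below is a minimal sketch of the resulting estimate; estimate_split_steps is a hypothetical helper, and the power-of-two chunk count at the end is an assumption about how a caller would use the estimate rather than part of this diff.

import math
import torch

def estimate_split_steps(q, k, mem_free_total, attn_precision="fp32"):
    # Hypothetical helper mirroring the estimate in the hunk above.
    # When attention is upcast to fp32, each element of the
    # (batch*heads, q_tokens, k_tokens) score matrix costs 4 bytes,
    # regardless of the storage dtype of q.
    element_size = 4 if attn_precision == "fp32" else q.element_size()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size

    # Overhead factor for temporaries, carried over from the original code:
    # 3x for 2-byte dtypes, 2.5x otherwise.
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier

    steps = 1
    if mem_required > mem_free_total:
        # Assumption: split the query dimension into a power-of-two number
        # of chunks so each chunk's score matrix fits into free memory.
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps

# Example with assumed shapes (batch*heads, tokens, dim_head): the fp32 score
# matrix alone is 1 GiB, so with 2 GiB free the estimate asks for 2 chunks.
q = torch.empty(16, 4096, 64, dtype=torch.float16)
k = torch.empty(16, 4096, 64, dtype=torch.float16)
print(estimate_split_steps(q, k, mem_free_total=2 * 1024 ** 3))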
@@ -83,7 +83,8 @@ def _summarize_chunk(
     )
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
-    torch.exp(attn_weights - max_score, out=attn_weights)
+    attn_weights -= max_score
+    torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
     max_score = max_score.squeeze(-1)
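The hunk above addresses peak memory in sub-quadratic attention's chunk summarization: attn_weights - max_score allocates a full attention-sized temporary before the exponential, whereas subtracting in place and then calling torch.exp with out=attn_weights keeps the work inside the existing buffer. A minimal sketch of the two patterns, with hypothetical function names, assuming attn_weights is a freshly materialized float tensor that may be overwritten:

import torch

def exp_shifted_out_of_place(attn_weights, max_score):
    # Old pattern: the subtraction materializes a second tensor the same
    # size as attn_weights before exp writes the result back into it.
    torch.exp(attn_weights - max_score, out=attn_weights)
    return attn_weights

def exp_shifted_in_place(attn_weights, max_score):
    # New pattern: subtract in place, then exponentiate in place, so no
    # extra attention-sized temporary exists at the peak.
    attn_weights -= max_score
    torch.exp(attn_weights, out=attn_weights)
    return attn_weights

# Both compute the numerically stable softmax numerator exp(w - max(w)).
attn_weights = torch.randn(2, 8, 8)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
a = exp_shifted_out_of_place(attn_weights.clone(), max_score)
b = exp_shifted_in_place(attn_weights.clone(), max_score)
print(torch.allclose(a, b))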