diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index c0ee0cf2..45396d86 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -76,7 +76,8 @@ def _summarize_chunk( ) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() - exp_weights = torch.exp(attn_weights - max_score) + torch.exp(attn_weights - max_score, out=attn_weights) + exp_weights = attn_weights exp_values = torch.bmm(exp_weights, value) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)