From 42da274717ff75640e1fb50f88d5c117a9c50630 Mon Sep 17 00:00:00 2001
From: blepping <157360029+blepping@users.noreply.github.com>
Date: Fri, 9 May 2025 11:51:02 -0600
Subject: [PATCH] Use normal ComfyUI attention in ACE-Steps model (#8023)

* Use normal ComfyUI attention in ACE-Steps model

* Let optimized_attention handle output reshape for ACE
---
 comfy/ldm/ace/attention.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py
index 631d1364..f20a0166 100644
--- a/comfy/ldm/ace/attention.py
+++ b/comfy/ldm/ace/attention.py
@@ -19,6 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 
 import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention
 
 class Attention(nn.Module):
     def __init__(
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
     def apply_rotary_emb(
         self,
         x: torch.Tensor,
@@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = optimized_attention(
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+        ).to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
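
A minimal sketch of the tensor-layout contract this change relies on, assuming comfy.ldm.modules.attention.optimized_attention is importable from a ComfyUI checkout; the tensor sizes below are arbitrary illustration values, not taken from the ACE model. With skip_reshape=True, optimized_attention accepts query/key/value already split into (batch, heads, seq_len, head_dim) and returns (batch, seq_len, heads * head_dim), so the manual transpose/reshape the patch deletes is handled inside the call.

# Sketch only: arbitrary sizes, assumes a ComfyUI install is on the path.
import torch
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention

batch, heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# Old path: run SDPA on the head-split layout, then fold the heads back
# into (batch, seq_len, heads * head_dim) by hand.
old = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
old = old.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# New path: skip_reshape=True tells optimized_attention the inputs are already
# head-split; it returns the folded (batch, seq_len, heads * head_dim) output.
new = optimized_attention(query, key, value, heads=query.shape[1], mask=None, skip_reshape=True)

assert old.shape == new.shape == (batch, seq_len, heads * head_dim)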