Use normal ComfyUI attention in ACE-Steps model (#8023)

* Use normal ComfyUI attention in ACE-Steps model

* Let optimized_attention handle output reshape for ACE
Author: blepping
Date: 2025-05-09 11:51:02 -06:00 (committed by GitHub)
parent 28f178a840
commit 42da274717

@@ -19,6 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 
 import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention
 
 class Attention(nn.Module):
     def __init__(
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
     def apply_rotary_emb(
         self,
         x: torch.Tensor,
@@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = optimized_attention(
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+        ).to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
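
For context, a minimal standalone sketch of why the manual transpose/reshape lines can be dropped: with skip_reshape=True, comfy's optimized_attention is assumed to accept (batch, heads, seq, head_dim) inputs and return (batch, seq, heads * head_dim), which is exactly what the removed F.scaled_dot_product_attention plus transpose/reshape path produced. The optimized_attention_stub below is a hypothetical stand-in for illustration, not the real ComfyUI helper, which also dispatches to whichever attention backend is active.

```python
import torch
import torch.nn.functional as F


def old_path(query, key, value, attention_mask=None):
    # Previous ACE code path: SDPA followed by a manual transpose/reshape
    # back to (batch, seq, heads * head_dim).
    batch_size, heads, _, head_dim = query.shape
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
    return hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)


def optimized_attention_stub(q, k, v, heads, mask=None, skip_reshape=False):
    # Hypothetical stand-in for comfy.ldm.modules.attention.optimized_attention:
    # with skip_reshape=True the inputs are assumed to already be
    # (batch, heads, seq, head_dim), and the output is (batch, seq, heads * head_dim).
    if not skip_reshape:
        raise NotImplementedError("sketch only covers the skip_reshape=True case")
    b, h, s, d = q.shape
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, s, h * d)


q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Both paths produce the same (batch, seq, heads * head_dim) tensor.
assert torch.allclose(
    old_path(q, k, v),
    optimized_attention_stub(q, k, v, heads=q.shape[1], skip_reshape=True),
    atol=1e-6,
)
```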