mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-07 03:52:10 +08:00
Use normal ComfyUI attention in ACE-Steps model (#8023)
* Use normal ComfyUI attention in ACE-Steps model * Let optimized_attention handle output reshape for ACE
This commit is contained in:
parent
28f178a840
commit
42da274717
@ -19,6 +19,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
Loading…
x
Reference in New Issue
Block a user