Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-06-09 04:52:07 +08:00)
Use normal ComfyUI attention in ACE-Steps model (#8023)
* Use normal ComfyUI attention in ACE-Steps model
* Let optimized_attention handle output reshape for ACE
This commit is contained in:
parent 28f178a840
commit 42da274717
@@ -19,6 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 
 import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention
 
 class Attention(nn.Module):
     def __init__(
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
     def apply_rotary_emb(
         self,
         x: torch.Tensor,
@@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = optimized_attention(
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+        ).to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
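For reference, a minimal PyTorch-only sketch of the shape handling this last hunk changes: the old path ran F.scaled_dot_product_attention on (batch, heads, seq_len, head_dim) tensors and merged the heads back by hand, while the new call passes skip_reshape=True so the attention helper returns the merged (batch, seq_len, heads * head_dim) tensor itself. The attention_like_helper below is a hypothetical stand-in, not ComfyUI's actual optimized_attention.

# Sketch only; "attention_like_helper" mimics the calling convention of
# optimized_attention(..., skip_reshape=True) but is NOT the real implementation.
import torch
import torch.nn.functional as F


def sdpa_then_reshape(query, key, value, attention_mask=None):
    # Old path from the diff: SDPA on (batch, heads, seq_len, head_dim),
    # then merge the heads back manually.
    batch_size, heads, _, head_dim = query.shape
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    return hidden_states.to(query.dtype)


def attention_like_helper(query, key, value, heads, mask=None, skip_reshape=True):
    # Stand-in for the new call: skip_reshape=True means q/k/v arrive already
    # split into heads, and the helper returns (batch, seq_len, heads * head_dim).
    assert skip_reshape and query.shape[1] == heads
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
    return out.transpose(1, 2).reshape(query.shape[0], -1, heads * query.shape[-1]).to(query.dtype)


if __name__ == "__main__":
    q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))  # (batch, heads, seq_len, head_dim)
    old = sdpa_then_reshape(q, k, v)
    new = attention_like_helper(q, k, v, heads=q.shape[1])
    assert old.shape == new.shape == (2, 16, 8 * 64)
    assert torch.allclose(old, new)

Both paths produce the same output shape, which is why the manual transpose/reshape lines could be dropped from the processor once the reshape moved into the attention helper.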