mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
remove attention abstraction (#5324)
This commit is contained in:
parent
8ce2a1052c
commit
5a8a48931a
@@ -5,7 +5,7 @@ from typing import Dict, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .. import attention
|
from ..attention import optimized_attention
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from .util import timestep_embedding
|
from .util import timestep_embedding
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
|
|||||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||||
return qkv[0], qkv[1], qkv[2]
|
return qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
def optimized_attention(qkv, num_heads):
|
|
||||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||||
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
qkv = self.pre_attention(x)
|
q, k, v = self.pre_attention(x)
|
||||||
x = optimized_attention(
|
x = optimized_attention(
|
||||||
qkv, num_heads=self.num_heads
|
q, k, v, heads=self.num_heads
|
||||||
)
|
)
|
||||||
x = self.post_attention(x)
|
x = self.post_attention(x)
|
||||||
return x
|
return x
|
||||||
@@ -531,8 +529,8 @@ class DismantledBlock(nn.Module):
|
|||||||
assert not self.pre_only
|
assert not self.pre_only
|
||||||
qkv, intermediates = self.pre_attention(x, c)
|
qkv, intermediates = self.pre_attention(x, c)
|
||||||
attn = optimized_attention(
|
attn = optimized_attention(
|
||||||
qkv,
|
qkv[0], qkv[1], qkv[2],
|
||||||
num_heads=self.attn.num_heads,
|
heads=self.attn.num_heads,
|
||||||
)
|
)
|
||||||
return self.post_attention(attn, *intermediates)
|
return self.post_attention(attn, *intermediates)
|
||||||
|
|
||||||
@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
|
|||||||
qkv = tuple(o)
|
qkv = tuple(o)
|
||||||
|
|
||||||
attn = optimized_attention(
|
attn = optimized_attention(
|
||||||
qkv,
|
qkv[0], qkv[1], qkv[2],
|
||||||
num_heads=x_block.attn.num_heads,
|
heads=x_block.attn.num_heads,
|
||||||
)
|
)
|
||||||
context_attn, x_attn = (
|
context_attn, x_attn = (
|
||||||
attn[:, : context_qkv[0].shape[1]],
|
attn[:, : context_qkv[0].shape[1]],
|
||||||
@@ -642,7 +640,7 @@ class SelfAttentionContext(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
qkv = self.qkv(x)
|
qkv = self.qkv(x)
|
||||||
q, k, v = split_qkv(qkv, self.dim_head)
|
q, k, v = split_qkv(qkv, self.dim_head)
|
||||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
|
||||||
return self.proj(x)
|
return self.proj(x)
|
||||||
|
|
||||||
class ContextProcessorBlock(nn.Module):
|
class ContextProcessorBlock(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user