remove attention abstraction (#5324)

contentis 2024-10-22 11:02:38 -07:00 committed by GitHub
parent 8ce2a1052c
commit 5a8a48931a


@@ -5,7 +5,7 @@ from typing import Dict, Optional
 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]
 
-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
 
 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
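
For context on the call-site change, here is a small before/after sketch (not part of the diff). It assumes a ComfyUI checkout is on the import path; optimized_attention is the shared helper from comfy.ldm.modules.attention that the new import pulls in, and the tensors follow the (batch, sequence, heads * dim_head) layout used by this file.

import torch
from comfy.ldm.modules.attention import optimized_attention

num_heads, dim_head = 8, 64
qkv = tuple(torch.randn(1, 77, num_heads * dim_head) for _ in range(3))

# Before this commit, the local wrapper unpacked the tuple itself:
#     x = optimized_attention(qkv, num_heads)    # wrapper removed above
# After this commit, call sites unpack q, k, v and pass them directly:
x = optimized_attention(qkv[0], qkv[1], qkv[2], heads=num_heads)
print(x.shape)  # expected: torch.Size([1, 77, 512])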
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x
@ -531,8 +529,8 @@ class DismantledBlock(nn.Module):
assert not self.pre_only assert not self.pre_only
qkv, intermediates = self.pre_attention(x, c) qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention( attn = optimized_attention(
qkv, qkv[0], qkv[1], qkv[2],
num_heads=self.attn.num_heads, heads=self.attn.num_heads,
) )
return self.post_attention(attn, *intermediates) return self.post_attention(attn, *intermediates)
@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)
 
     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ class SelfAttentionContext(nn.Module):
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)
 
 class ContextProcessorBlock(nn.Module):
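
As a rough, self-contained illustration (not ComfyUI's actual implementation) of what an optimized_attention(q, k, v, heads=...) call like the ones above computes, here is a minimal PyTorch equivalent built on scaled_dot_product_attention; the function name is made up for this sketch, and mask/precision handling is omitted.

import torch
import torch.nn.functional as F

def naive_optimized_attention(q, k, v, heads):
    # q, k, v: (batch, sequence, heads * dim_head), matching the calls above
    b, n, _ = q.shape
    # split heads: (b, n, heads * d) -> (b, heads, n, d)
    q, k, v = (t.view(b, -1, heads, t.shape[-1] // heads).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    # merge heads back: (b, heads, n, d) -> (b, n, heads * d)
    return out.transpose(1, 2).reshape(b, n, -1)

q = k = v = torch.randn(1, 77, 8 * 64)
print(naive_optimized_attention(q, k, v, heads=8).shape)  # torch.Size([1, 77, 512])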