2023-01-03 06:53:32 +00:00
import math
import torch
import torch . nn . functional as F
from torch import nn , einsum
from einops import rearrange , repeat
from typing import Optional , Any
2024-03-11 20:24:47 +00:00
import logging
2023-11-24 00:41:33 +00:00
2024-05-18 14:11:44 +00:00
from . diffusionmodules . util import AlphaBlender , timestep_embedding
2023-01-03 06:53:32 +00:00
from . sub_quadratic_attention import efficient_dot_product_attention
2023-04-15 22:55:17 +00:00
from comfy import model_management
2023-03-03 08:27:33 +00:00
2023-03-13 15:36:48 +00:00
if model_management . xformers_enabled ( ) :
2023-01-03 06:53:32 +00:00
import xformers
import xformers . ops
2023-07-02 15:57:36 +00:00
from comfy . cli_args import args
2023-08-18 20:32:23 +00:00
import comfy . ops
2023-12-12 04:27:13 +00:00
# Default `operations` namespace used by the layer constructors in this module;
# each of them also accepts an `operations=` argument to substitute another one.
ops = comfy.ops.disable_weight_init
2023-08-18 20:32:23 +00:00
2024-05-16 08:09:41 +00:00
def get_attn_precision(attn_precision):
    """Resolve the dtype attention scores should be computed in.

    ``args.dont_upcast_attention`` wins and forces ``None``. Otherwise an explicitly
    requested precision is kept as-is. When none was requested, ``torch.float32`` is
    used if ``args.force_upcast_attention`` is set, and ``None`` is returned if not.
    """
    if args.dont_upcast_attention:
        return None
    if attn_precision is not None:
        return attn_precision
    return torch.float32 if args.force_upcast_attention else None
2023-01-03 06:53:32 +00:00
def exists(val):
    """Return True unless *val* is None (falsy values such as 0 or [] count as existing)."""
    return False if val is None else True
def uniq(arr):
    """Return the distinct elements of *arr* in first-seen order, as a dict keys view."""
    return dict.fromkeys(arr, True).keys()
def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    if not exists(val):
        return d
    return val
2023-01-03 06:53:32 +00:00
def max_neg_value(t):
    """Return the most negative finite value representable in *t*'s floating dtype."""
    info = torch.finfo(t.dtype)
    return -info.max
def init_(tensor):
    """Fill *tensor* in place from U(-b, b), b = 1/sqrt(size of its last dim), and return it."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    A single projection produces ``2 * dim_out`` features. They are split in half along
    the last dim, and the first half is multiplied by ``gelu`` of the second half.
    """

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
        super().__init__()
        self.proj = operations.Linear(dim_in, 2 * dim_out, dtype=dtype, device=device)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Position-wise MLP: dim -> int(dim * mult) -> dim_out (dim_out defaults to dim).

    The input projection is Linear + GELU, or a GEGLU when ``glu`` is True. Dropout
    follows it, then a final Linear maps to ``dim_out``.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
        super().__init__()
        hidden = int(dim * mult)
        out_features = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, hidden, dtype=dtype, device=device, operations=operations)
        else:
            project_in = nn.Sequential(
                operations.Linear(dim, hidden, dtype=dtype, device=device),
                nn.GELU(),
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(hidden, out_features, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.net(x)
2023-07-29 18:51:56 +00:00
def Normalize(in_channels, dtype=None, device=None):
    """Build a 32-group, affine GroupNorm over *in_channels* channels with eps=1e-6."""
    return torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
2023-01-03 06:53:32 +00:00
2024-05-14 16:47:31 +00:00
def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
    """Reference multi-head attention that materializes the full score matrix.

    Args:
        q, k, v: tensors laid out as (batch, tokens, heads * dim_head); they are split
            into ``heads`` heads of size ``q.shape[-1] // heads`` below.
        heads: number of attention heads.
        mask: optional mask.
            - A bool mask is flattened to (batch, keys) and positions that are False
              are filled with the most negative finite value.
            - Any other dtype is treated as an additive bias: 2-D, 3-D (bs, ...) or
              4-D (bs, heads|1, ...), broadcast to batch * heads score matrices.
        attn_precision: requested score precision, resolved via get_attn_precision;
            torch.float32 computes q @ k^T in fp32.

    Returns:
        Tensor of shape (batch, q_tokens, heads * dim_head).
    """
    attn_precision = get_attn_precision(attn_precision)

    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    # (b, n, h*d) -> (b*h, n, d): heads are folded into the batch dim for the einsums.
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if exists(mask):
        if mask.dtype == torch.bool:
            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
            max_neg_value = -torch.finfo(sim.dtype).max
            # One key mask per batch item, repeated for each head; False = masked out.
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
        else:
            if len(mask.shape) == 2:
                bs = 1
            else:
                bs = mask.shape[0]
            # Normalize to (bs, heads|1, q|1, k), broadcast to (b, heads, ...), then fold to (b*heads, q|1, k).
            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
            sim.add_(mask)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    # Probabilities are cast back to v's dtype in case the scores were upcast to fp32.
    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    # (b*h, n, d) -> (b, n, h*d)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out
2023-01-03 06:53:32 +00:00
2024-05-14 16:47:31 +00:00
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
    """Multi-head attention via the chunked ``efficient_dot_product_attention`` kernel.

    Args:
        query, key, value: tensors laid out as (batch, tokens, heads * dim_head).
        heads: number of attention heads.
        mask: optional additive bias. It may be 2-D, 3-D (bs, ...) or 4-D
            (bs, heads|1, ...) and is broadcast to one matrix per batch * head.
        attn_precision: requested score precision, resolved via get_attn_precision.
            torch.float32 on a non-fp32 query enables ``upcast_attention`` in the kernel.

    The query chunk size is picked from free device memory. It is the largest of
    4096/2048/1024/512/256 whose estimated footprint fits, with a fallback of 512.

    Returns:
        Tensor of shape (batch, q_tokens, heads * dim_head), in the query's dtype.
    """
    attn_precision = get_attn_precision(attn_precision)

    b, _, dim_head = query.shape
    dim_head //= heads

    # NOTE(review): `scale` is not passed to the kernel below; kept as-is.
    scale = dim_head ** -0.5
    # query/value: (b*h, n, d); key is laid out transposed as (b*h, d, n) for the kernel.
    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)

    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits//8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits//8
    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape
    # NOTE(review): computed but not used below.
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    # Pick the largest query chunk whose estimated score footprint (with a 4x margin)
    # still leaves room for all keys in one kv chunk.
    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if query_chunk_size is None:
        query_chunk_size = 512

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        # Normalize to (bs, heads|1, q|1, k), broadcast to (b, heads, ...), then fold to (b*heads, q|1, k).
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    hidden_states = hidden_states.to(dtype)

    # (b*h, n, d) -> (b, h, n, d) -> (b, n, h, d) -> (b, n, h*d)
    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
    return hidden_states
2024-05-14 16:47:31 +00:00
def attention_split(q, k, v, heads, mask=None, attn_precision=None):
    """Multi-head attention that slices the query sequence to bound peak memory.

    Args:
        q, k, v: tensors laid out as (batch, tokens, heads * dim_head).
        heads: number of attention heads.
        mask: optional additive bias. It may be 2-D, 3-D (bs, ...) or 4-D
            (bs, heads|1, ...) and is broadcast to one (q_tokens|1, k_tokens) matrix per
            batch * head, the same way as the other backends in this module.
        attn_precision: requested score precision, resolved via get_attn_precision.
            torch.float32 computes the scores with autocast disabled and in fp32.

    The initial number of query slices comes from a free-memory estimate (x3 margin).

    OOM handling:
        - An OOM before the first slice's softmax finishes empties the cache.
        - The first such OOM retries with the same slicing.
        - Later ones double the slice count, up to 64.
        - Any other OOM is re-raised.

    Raises:
        RuntimeError: if the up-front estimate already needs more than 64 slices.

    Returns:
        Tensor of shape (batch, q_tokens, heads * dim_head).
    """
    attn_precision = get_attn_precision(attn_precision)

    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    # (b, n, h*d) -> (b*h, n, d)
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = model_management.get_free_memory(q.device)

    if attn_precision == torch.float32:
        element_size = 4
        upcast = True
    else:
        element_size = q.element_size()
        upcast = False

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        # Always 3-D after this: (b*heads, q_tokens or 1, k_tokens).
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    first_op_done = False
    cleared_cache = False
    while True:
        try:
            # Ceil division so every slice count actually splits the queries; the last
            # slice may be shorter, which slicing handles. max(1, ...) guards q_tokens == 0.
            slice_size = max(1, -(-q.shape[1] // steps))
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if upcast:
                    with torch.autocast(enabled=False, device_type='cuda'):
                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    if mask.shape[1] == 1:
                        # Mask is shared by all query positions: broadcast it as-is.
                        # Slicing it with [i:end] would give an empty tensor for i > 0.
                        s1 += mask
                    else:
                        s1 += mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            if not first_op_done:
                model_management.soft_empty_cache(True)
                if not cleared_cache:
                    cleared_cache = True
                    logging.warning("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
            else:
                raise e

    del q, k, v

    # (b*h, n, d) -> (b, n, h*d)
    r1 = (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return r1
2023-10-11 19:47:53 +00:00
2023-11-24 08:55:35 +00:00
# True when the installed xformers is in the 0.0.21+ "0.0.2x" range that fails on
# large batches (see the comment below); attention_xformers then falls back to PyTorch.
BROKEN_XFORMERS = False
try:
    x_vers = xformers.__version__
    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except (NameError, AttributeError, TypeError):
    # NameError: xformers was not imported (xformers disabled).
    # AttributeError/TypeError: no usable string __version__. Treat these as "not broken".
    pass
2024-05-14 16:47:31 +00:00
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
    """Multi-head attention via ``xformers.ops.memory_efficient_attention``.

    Args:
        q, k, v: tensors laid out as (batch, tokens, heads * dim_head). They are passed
            to xformers as (batch, tokens, heads, dim_head).
        heads: number of attention heads.
        mask: optional additive bias, shaped (q|1, k), (bs, q|1, k) or
            (bs, heads|1, q|1, k) with bs in {1, batch}.
            - It is copied into a buffer whose key axis is padded to a multiple of 8,
              then sliced back to k_tokens.
            - The result is a (batch, heads, q_tokens, k_tokens) bias.
            - Batch/head broadcasting is done with ``expand``, without copying.
        attn_precision: accepted for interface parity; unused by this backend.

    Falls back to ``attention_pytorch`` in two cases: the installed xformers is flagged
    by BROKEN_XFORMERS and batch * heads > 65535, or TorchScript is tracing/scripting.

    Returns:
        Tensor of shape (batch, q_tokens, heads * dim_head).
    """
    b, _, dim_head = q.shape
    dim_head //= heads

    disabled_xformers = False
    if BROKEN_XFORMERS and b * heads > 65535:
        disabled_xformers = True
    if not disabled_xformers and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        disabled_xformers = True
    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask)

    q, k, v = map(
        lambda t: t.reshape(b, -1, heads, dim_head),
        (q, k, v),
    )

    if mask is not None:
        # Promote to 4-D (bs, heads|1, q|1, k) to match the BMHK inputs above.
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        k_tokens = mask.shape[-1]
        # Pad the *key* axis; the bias is then a sliced view of this padded buffer.
        pad = 8 - k_tokens % 8
        mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], k_tokens + pad], dtype=q.dtype, device=q.device)
        mask_out[..., :k_tokens] = mask  # also broadcasts a size-1 query dim
        mask = mask_out[..., :k_tokens].expand(b, heads, -1, -1)

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
    # (b, n, h, d) -> (b, n, h*d)
    return out.reshape(b, -1, heads * dim_head)
2024-05-14 16:47:31 +00:00
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
    """Multi-head attention through ``torch.nn.functional.scaled_dot_product_attention``.

    q, k, v (batch, tokens, heads * dim_head) are viewed as (batch, heads, tokens, dim_head).
    ``mask`` is forwarded unchanged as ``attn_mask``, with no dropout and no causal
    masking. ``attn_precision`` is accepted for interface parity and ignored.
    Returns a (batch, q_tokens, heads * dim_head) tensor.
    """
    batch, _, width = q.shape
    per_head = width // heads
    q, k, v = [t.view(batch, -1, heads, per_head).transpose(1, 2) for t in (q, k, v)]

    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    merged = attended.transpose(1, 2)
    return merged.reshape(batch, -1, heads * per_head)
2023-10-16 06:31:24 +00:00
2023-10-11 19:47:53 +00:00
# Select the default attention backend once, at import time. Priority:
# xformers, then PyTorch SDPA, then split (if requested on the command line), then sub-quadratic.
optimized_attention = attention_basic

if model_management.xformers_enabled():
    logging.info("Using xformers cross attention")
    optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
    logging.info("Using pytorch cross attention")
    optimized_attention = attention_pytorch
elif args.use_split_cross_attention:
    logging.info("Using split optimization for cross attention")
    optimized_attention = attention_split
else:
    logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
    optimized_attention = attention_sub_quad

# Masked attention uses the same backend as unmasked attention.
optimized_attention_masked = optimized_attention
def optimized_attention_for_device(device, mask=False, small_input=False):
    """Pick an attention function for a given device and workload.

    - ``small_input``: PyTorch SDPA when enabled, otherwise ``attention_basic``.
    - CPU device: ``attention_sub_quad``.
    - Otherwise: the import-time selection (masked or unmasked variant).
    """
    if small_input:
        #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
        return attention_pytorch if model_management.pytorch_attention_enabled() else attention_basic

    if device == torch.device("cpu"):
        return attention_sub_quad

    return optimized_attention_masked if mask else optimized_attention
2023-02-22 03:16:13 +00:00
class CrossAttention(nn.Module):
    """Multi-head attention layer with bias-free Q/K/V projections and an output projection.

    Keys and values are projected from ``context``; when ``context`` is None they come
    from ``x`` (self-attention). An explicit ``value`` input overrides the source of V.
    The kernel is the module-level ``optimized_attention`` (no mask) or
    ``optimized_attention_masked`` (with mask), looked up at call time.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.attn_precision = attn_precision

        hidden = heads * dim_head
        kv_dim = default(context_dim, query_dim)
        linear = operations.Linear

        self.to_q = linear(query_dim, hidden, bias=False, dtype=dtype, device=device)
        self.to_k = linear(kv_dim, hidden, bias=False, dtype=dtype, device=device)
        self.to_v = linear(kv_dim, hidden, bias=False, dtype=dtype, device=device)
        self.to_out = nn.Sequential(
            linear(hidden, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, value=None, mask=None):
        q = self.to_q(x)
        source = default(context, x)
        k = self.to_k(source)
        v = self.to_v(source if value is None else value)
        del value

        if mask is not None:
            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        else:
            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        return self.to_out(out)
2023-03-05 19:14:54 +00:00
2023-01-03 06:53:32 +00:00
class BasicTransformerBlock(nn.Module):
    """Transformer block: optional ff_in, self-attention (attn1), optional attn2, then feed-forward.

    Each sub-layer is pre-normalized (LayerNorm) and added back residually. The
    feed-forward residual is only added when ``inner_dim == dim`` (``is_res``).

    ``transformer_options`` can inject patches at several points:
      - "patches": attn1_patch, attn1_output_patch, middle_patch, attn2_patch,
        attn2_output_patch.
      - "patches_replace": attn1/attn2 dicts keyed by ``block`` or by
        ``(block[0], block[1], block_index)``. A match replaces the attention kernel
        between the q/k/v projections and to_out.
    All remaining keys are forwarded to patches as ``extra_options``, together with
    n_heads, dim_head and attn_precision.
    """
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()

        # Boolean flag first; replaced by the FeedForward module below when enabled.
        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        self.is_res = inner_dim == dim
        self.attn_precision = attn_precision

        if self.ff_in:
            self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                # Contradictory options: attn2 cannot be both disabled and switched to self-attention.
                raise ValueError
            else:
                self.attn2 = None
        else:
            context_dim_attn2 = None
            if not switch_temporal_ca_to_sa:
                context_dim_attn2 = context_dim

            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                        heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
            self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.n_heads = n_heads
        self.d_head = d_head
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(self, x, context=None, transformer_options={}):
        # NOTE(review): the default dict is only read here, never mutated.
        extra_options = {}
        block = transformer_options.get("block", None)
        block_index = transformer_options.get("block_index", 0)
        transformer_patches = {}
        transformer_patches_replace = {}

        # Split options into patch tables and "extra" options that are passed to patches.
        for k in transformer_options:
            if k == "patches":
                transformer_patches = transformer_options[k]
            elif k == "patches_replace":
                transformer_patches_replace = transformer_options[k]
            else:
                extra_options[k] = transformer_options[k]

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head
        extra_options["attn_precision"] = self.attn_precision

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        # ---- attn1 (self-attention unless disable_self_attn) ----
        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
        else:
            context_attn1 = None
        value_attn1 = None

        if "attn1_patch" in transformer_patches:
            patch = transformer_patches["attn1_patch"]
            if context_attn1 is None:
                context_attn1 = n
            value_attn1 = context_attn1
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        if block is not None:
            transformer_block = (block[0], block[1], block_index)
        else:
            transformer_block = None
        # Prefer a replacement keyed by the exact (block..., index) tuple, else by `block`.
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
            for p in patch:
                x = p(x, extra_options)

        # ---- attn2 (cross-attention on `context`, or self-attention when switched) ----
        if self.attn2 is not None:
            n = self.norm2(x)
            if self.switch_temporal_ca_to_sa:
                context_attn2 = n
            else:
                context_attn2 = context
            value_attn2 = None
            if "attn2_patch" in transformer_patches:
                patch = transformer_patches["attn2_patch"]
                value_attn2 = context_attn2
                for p in patch:
                    n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
            block_attn2 = transformer_block
            if block_attn2 not in attn2_replace_patch:
                block_attn2 = block

            if block_attn2 in attn2_replace_patch:
                if value_attn2 is None:
                    value_attn2 = context_attn2
                n = self.attn2.to_q(n)
                context_attn2 = self.attn2.to_k(context_attn2)
                value_attn2 = self.attn2.to_v(value_attn2)
                n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
                n = self.attn2.to_out(n)
            else:
                n = self.attn2(n, context=context_attn2, value=value_attn2)

            if "attn2_output_patch" in transformer_patches:
                patch = transformer_patches["attn2_output_patch"]
                for p in patch:
                    n = p(n, extra_options)

            x += n

        # ---- feed-forward (residual only when inner_dim == dim) ----
        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        return x
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs

    Args:
        in_channels: channels of the (B, C, H, W) input and output.
        n_heads, d_head: attention heads and per-head size; inner_dim = n_heads * d_head.
        depth: number of BasicTransformerBlocks.
        context_dim: cross-attention context width. It is either a single value (used
            for every block, None allowed) or a list with one entry per block.
        use_linear: project with Linear layers after flattening instead of 1x1 convs.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        if not isinstance(context_dim, list):
            # One entry per block; None is broadcast too, so context_dim[d] below never fails.
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        if not use_linear:
            self.proj_in = operations.Conv2d(in_channels,
                                             inner_dim,
                                             kernel_size=1,
                                             stride=1,
                                             padding=0, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
                for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(inner_dim, in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0, dtype=dtype, device=device)
        else:
            # Maps inner_dim features back to in_channels, mirroring the conv branch.
            # Same weight shape as before whenever in_channels == inner_dim.
            self.proj_out = operations.Linear(inner_dim, in_channels, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options=None):
        """Apply the blocks to x of shape (B, C, H, W) and return x + residual, same shape.

        ``transformer_options`` is passed to every block. Its "block_index" key is set to
        the current block index, so a caller-supplied dict is updated in place.
        """
        if transformer_options is None:
            # Fresh dict per call: a shared mutable default would leak "block_index" across calls.
            transformer_options = {}
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # (B, C, H, W) -> (B, H*W, C)
        x = x.movedim(1, 3).flatten(1, 2).contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        # (B, H*W, C') -> (B, C', H, W)
        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
2023-11-24 00:41:33 +00:00
class SpatialVideoTransformer(SpatialTransformer):
    """SpatialTransformer with a temporal BasicTransformerBlock after each spatial block.

    After every spatial block the tokens are regrouped from "(b t) s c" to "(b s) t c",
    so the temporal block attends across the ``timesteps`` frames at each spatial
    position. The spatial and temporal results are then blended by ``AlphaBlender``
    (``merge_strategy`` / ``merge_factor``). A learned frame-index embedding
    (``time_pos_embed``) is added before each temporal block.
    """
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        attn_precision=None,
        dtype=None, device=None, operations=ops
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
            attn_precision=attn_precision,
            dtype=dtype, device=device, operations=operations
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        self.time_stack = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    # timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor, merge_strategy=merge_strategy
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        transformer_options=None
    ) -> torch.Tensor:
        """Run interleaved spatial/temporal blocks on x of shape (B*T, C, H, W).

        The batch dim must be frame-major "(b t)" with t = ``timesteps``; ``timesteps``
        must therefore be an int (it is used in torch.arange and in the rearranges).
        ``transformer_options`` goes to the spatial blocks, and its "block_index" key is
        set to the current block index.
        """
        if transformer_options is None:
            # Fresh dict per call: a shared mutable default would leak "block_index" across calls.
            transformer_options = {}
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            if time_context is None:
                time_context = context
            # One context per video (first frame), repeated for every spatial position.
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        # Frame indices 0..T-1 for every video, flattened to match the "(b t)" batch layout.
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_stack)
        ):
            transformer_options["block_index"] = it_
            x = block(
                x,
                context=spatial_context,
                transformer_options=transformer_options,
            )

            x_mix = x
            x_mix = x_mix + emb

            B, S, C = x_mix.shape
            # Regroup so the temporal block attends across frames at each spatial position.
            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
            x_mix = mix_block(x_mix, context=time_context)  # TODO: transformer_options
            x_mix = rearrange(
                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
            )

            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out