# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing, we asked Tencent for permission to use this code under the GPL license used in ComfyUI.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging

import comfy.ops
ops = comfy.ops.disable_weight_init

def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    length = bbox_max - bbox_min
    num_cells = octree_resolution
    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
    return xyz, grid_size, length

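# Usage sketch (illustrative values, not executed at import time):
#   xyz, grid_size, length = generate_dense_grid_points(
#       bbox_min=np.array([-1.01, -1.01, -1.01]),
#       bbox_max=np.array([1.01, 1.01, 1.01]),
#       octree_resolution=256,
#   )
#   # xyz.shape == (257, 257, 257, 3); grid_size == [257, 257, 257];
#   # length == bbox_max - bbox_min == [2.02, 2.02, 2.02]
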
class VanillaVolumeDecoder:
    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()
        return grid_logits

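# Chunking sketch (illustrative): with octree_resolution=256 there are 257**3
# (~17M) query points; at num_chunks=10000 the geo_decoder is called ~1700 times,
# each on a (batch_size, <=10000, 3) slice, keeping peak memory bounded.
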
class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(f_0 * x[..., i]),
            sin(f_1 * x[..., i]),
            ...
            sin(f_{N-1} * x[..., i]),
            cos(f_0 * x[..., i]),
            cos(f_1 * x[..., i]),
            ...
            cos(f_{N-1} * x[..., i]),
            x[..., i]  # only present if include_input is True
        ]
    where f_i is the i-th frequency.

    If logspace is True, the frequencies are powers of two, f_i = 2^i for
    i in [0, num_freqs - 1]; otherwise they are linearly spaced between
    1.0 and 2^(num_freqs - 1). If include_pi is True, every frequency is
    additionally scaled by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, f_i = 2^i, otherwise the frequencies are
            linearly spaced between 1.0 and 2^(num_freqs - 1);
        input_dim (int): the input dimension, default is 3;
        include_input (bool): whether to append the raw input tensor, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequency bank described above;
        out_dim (int): the embedding size; if include_input is True it is
            input_dim * (num_freqs * 2 + 1), otherwise input_dim * num_freqs * 2.
    """
    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """The initialization"""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)
        return out_dim
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """
        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x

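# Usage sketch (illustrative): with the defaults used by ShapeVAE below
# (num_freqs=8, input_dim=3, include_input=True), a query tensor of shape
# (B, P, 3) embeds to (B, P, 51), since 3 * (8 * 2 + 1) == 51.
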
class CrossAttentionProcessor:
    def __call__(self, attn, q, k, v):
        out = F.scaled_dot_product_attention(q, k, v)
        return out

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'

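# Behavior sketch (illustrative): in training mode with drop_prob=0.1, each sample's
# residual branch is zeroed with probability 0.1 and surviving samples are scaled by
# 1/0.9 so the expected activation is unchanged; in eval mode this is an identity op.
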
class MLP(nn.Module):
    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = ops.Linear(width, width * expand_ratio)
        self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))

class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out

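# Shape note (illustrative): the kv argument is the fused output of c_kv in
# MultiheadCrossAttention, so its last dimension is twice the model width; after
# the per-head view, splitting at attn_ch = width // heads // 2 (width here being
# the fused kv width) unpacks K and V, each (bs, n_data, heads, head_dim).
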
class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        self.data = None

    def forward(self, x, data):
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                self.data = self.c_kv(data)
                logging.info('Saving kv cache; this should be called only once per mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x

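# Design note: with kv_cache=True the c_kv projection of the latent set is computed
# once and reused for every query chunk of the same mesh, which keeps the chunked
# volume decoding above cheap; the cache has no reset method here, so self.data
# must be cleared externally before decoding a different mesh.
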
class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x

class QKVMultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out

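# Shape note (illustrative): qkv is the fused output of c_qkv in MultiheadAttention,
# so its last dimension is three times the model width; splitting at
# attn_ch = width // heads // 3 after the per-head view unpacks Q, K and V, each
# (bs, n_ctx, heads, head_dim).
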
class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x

class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x

class CrossAttentionDecoder(nn.Module):
    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)
        self.label_type = label_type
        self.count = 0

    def forward(self, queries=None, query_embeddings=None, latents=None):
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)

        occ = self.output_proj(x)
        return occ

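# Flow sketch (illustrative, with the defaults used by ShapeVAE below): each
# (B, P, 3) chunk of query points is Fourier-embedded to (B, P, 51), projected to
# width, cross-attends to the transformer-refined latents, and is projected to
# out_channels=1 occupancy logits per point.
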
class ShapeVAE(nn.Module):
    def __init__(
        self,
        *,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor
    def decode(self, latents, **kwargs):
        latents = self.post_kl(latents.movedim(-2, -1))
        latents = self.transformer(latents)

        bounds = kwargs.get("bounds", 1.01)
        num_chunks = kwargs.get("num_chunks", 8000)
        octree_resolution = kwargs.get("octree_resolution", 256)
        enable_pbar = kwargs.get("enable_pbar", True)

        grid_logits = self.volume_decoder(
            latents,
            self.geo_decoder,
            bounds=bounds,
            num_chunks=num_chunks,
            octree_resolution=octree_resolution,
            enable_pbar=enable_pbar,
        )
        return grid_logits.movedim(-2, -1)
    def encode(self, x):
        return None

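# Usage sketch (illustrative config and shapes; actual sizes depend on the checkpoint):
#   vae = ShapeVAE(embed_dim=64, width=1024, heads=16, num_decoder_layers=16)
#   latents = torch.randn(1, 64, 4096)           # (B, embed_dim, num_latents)
#   grid = vae.decode(latents, octree_resolution=256)
#   # grid is a (B, 257, 257, 257) volume of occupancy logits, ready for
#   # iso-surface extraction (e.g. marching cubes) outside this module.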