mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-12 02:14:19 +00:00
Initial Hunyuan3Dv2 implementation.
Supports the multiview, mini, turbo models and VAEs.
This commit is contained in:
parent
3b19fc76e3
commit
11f1b41bab
@ -456,3 +456,13 @@ class Wan21(LatentFormat):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 1.0188137142395404
|
||||
|
135
comfy/ldm/hunyuan3d/model.py
Normal file
135
comfy/ldm/hunyuan3d/model.py
Normal file
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=64,
|
||||
context_in_dim=1536,
|
||||
hidden_size=1024,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=16,
|
||||
depth=16,
|
||||
depth_single_blocks=32,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
||||
)
|
||||
|
||||
self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
|
||||
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
|
||||
)
|
||||
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(depth_single_blocks)
|
||||
]
|
||||
)
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
img = self.latent_in(x)
|
||||
|
||||
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
|
||||
if self.guidance_in is not None:
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
|
||||
|
||||
txt = self.cond_in(txt)
|
||||
pe = None
|
||||
attn_mask = None
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
return img.movedim(-2, -1) * (-1.0)
|
587
comfy/ldm/hunyuan3d/vae.py
Normal file
587
comfy/ldm/hunyuan3d/vae.py
Normal file
@ -0,0 +1,587 @@
|
||||
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
|
||||
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
|
||||
return xyz, grid_size, length
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
[
|
||||
sin(x[..., i]),
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]),
|
||||
cos(x[..., i]),
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i] # only present if include_input is True.
|
||||
], here f_i is the frequency.
|
||||
|
||||
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
|
||||
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
|
||||
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
|
||||
|
||||
Args:
|
||||
num_freqs (int): the number of frequencies, default is 6;
|
||||
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
||||
input_dim (int): the input dimension, default is 3;
|
||||
include_input (bool): include the input tensor or not, default is True.
|
||||
|
||||
Attributes:
|
||||
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
|
||||
|
||||
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
||||
otherwise, it is input_dim * num_freqs * 2.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_freqs: int = 6,
|
||||
logspace: bool = True,
|
||||
input_dim: int = 3,
|
||||
include_input: bool = True,
|
||||
include_pi: bool = True) -> None:
|
||||
|
||||
"""The initialization"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if logspace:
|
||||
frequencies = 2.0 ** torch.arange(
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
frequencies = torch.linspace(
|
||||
1.0,
|
||||
2.0 ** (num_freqs - 1),
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
if include_pi:
|
||||
frequencies *= torch.pi
|
||||
|
||||
self.register_buffer("frequencies", frequencies, persistent=False)
|
||||
self.include_input = include_input
|
||||
self.num_freqs = num_freqs
|
||||
|
||||
self.out_dim = self.get_dims(input_dim)
|
||||
|
||||
def get_dims(self, input_dim):
|
||||
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
||||
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
||||
|
||||
return out_dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
""" Forward process.
|
||||
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
|
||||
Returns:
|
||||
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
||||
where temp is 1 if include_input is True and 0 otherwise.
|
||||
"""
|
||||
|
||||
if self.num_freqs > 0:
|
||||
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
|
||||
if self.include_input:
|
||||
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if self.drop_prob == 0. or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, *,
|
||||
width: int,
|
||||
expand_ratio: int = 4,
|
||||
output_width: int = None,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.c_fc = ops.Linear(width, width * expand_ratio)
|
||||
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
|
||||
self.gelu = nn.GELU()
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
data_width: Optional[int] = None,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
kv_cache: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.data_width = width if data_width is None else data_width
|
||||
self.c_q = ops.Linear(width, width, bias=qkv_bias)
|
||||
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadCrossAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.kv_cache = kv_cache
|
||||
self.data = None
|
||||
|
||||
def forward(self, x, data):
|
||||
x = self.c_q(x)
|
||||
if self.kv_cache:
|
||||
if self.data is None:
|
||||
self.data = self.c_kv(data)
|
||||
logging.info('Save kv cache,this should be called only once for one mesh')
|
||||
data = self.data
|
||||
else:
|
||||
data = self.c_kv(data)
|
||||
x = self.attention(x, data)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
data_width: Optional[int] = None,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_width is None:
|
||||
data_width = width
|
||||
|
||||
self.attn = MultiheadCrossAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
data_width=data_width,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
|
||||
|
||||
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
||||
x = x + self.mlp(self.ln_3(x))
|
||||
return x
|
||||
|
||||
|
||||
class QKVMultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, qkv):
|
||||
bs, n_ctx, width = qkv.shape
|
||||
attn_ch = width // self.heads // 3
|
||||
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
||||
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
heads=heads,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_qkv(x)
|
||||
x = self.attention(x)
|
||||
x = self.drop_path(self.c_proj(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = MultiheadAttention(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
|
||||
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=ops.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
for block in self.resblocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
downsample_ratio: int = 1,
|
||||
enable_ln_post: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.enable_ln_post = enable_ln_post
|
||||
self.fourier_embedder = fourier_embedder
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
mlp_expand_ratio=mlp_expand_ratio,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
|
||||
if self.enable_ln_post:
|
||||
self.ln_post = ops.LayerNorm(width)
|
||||
self.output_proj = ops.Linear(width, out_channels)
|
||||
self.label_type = label_type
|
||||
self.count = 0
|
||||
|
||||
def forward(self, queries=None, query_embeddings=None, latents=None):
|
||||
if query_embeddings is None:
|
||||
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
|
||||
self.count += query_embeddings.shape[1]
|
||||
if self.downsample_ratio != 1:
|
||||
latents = self.latents_proj(latents)
|
||||
x = self.cross_attn_decoder(query_embeddings, latents)
|
||||
if self.enable_ln_post:
|
||||
x = self.ln_post(x)
|
||||
occ = self.output_proj(x)
|
||||
return occ
|
||||
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=width,
|
||||
layers=num_decoder_layers,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
|
||||
self.geo_decoder = CrossAttentionDecoder(
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
out_channels=1,
|
||||
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
|
||||
downsample_ratio=geo_decoder_downsample_ratio,
|
||||
enable_ln_post=self.geo_decoder_ln_post,
|
||||
width=width // geo_decoder_downsample_ratio,
|
||||
heads=heads // geo_decoder_downsample_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
label_type=label_type,
|
||||
)
|
||||
|
||||
self.volume_decoder = VanillaVolumeDecoder()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def decode(self, latents, **kwargs):
|
||||
latents = self.post_kl(latents.movedim(-2, -1))
|
||||
latents = self.transformer(latents)
|
||||
|
||||
bounds = kwargs.get("bounds", 1.01)
|
||||
num_chunks = kwargs.get("num_chunks", 8000)
|
||||
octree_resolution = kwargs.get("octree_resolution", 256)
|
||||
enable_pbar = kwargs.get("enable_pbar", True)
|
||||
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
@ -36,6 +36,7 @@ import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1013,3 +1014,18 @@ class WAN21(BaseModel):
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guidance = kwargs.get("guidance", 5.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "t2v"
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2"
|
||||
dit_config["in_channels"] = in_shape[1]
|
||||
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["hidden_size"] = in_shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 16
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
15
comfy/sd.py
15
comfy/sd.py
@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import yaml
|
||||
import math
|
||||
|
||||
@ -412,6 +413,16 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
self.latent_dim = 1
|
||||
ln_post = "geo_decoder.ln_post.weight" in sd
|
||||
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
|
||||
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
|
||||
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -498,7 +509,7 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
try:
|
||||
@ -510,7 +521,7 @@ class VAE:
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
|
@ -959,6 +959,42 @@ class WAN21_I2V(WAN21_T2V):
|
||||
out = model_base.WAN21(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "conditioner.main_image_encoder.model."
|
||||
vae_key_prefix = ["vae."]
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Hunyuan3Dv2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
"depth": 8,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
410
comfy_extras/nodes_hunyuan3d.py
Normal file
410
comfy_extras/nodes_hunyuan3d.py
Normal file
@ -0,0 +1,410 @@
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class EmptyLatentHunyuan3Dv2:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def generate(self, resolution, batch_size):
|
||||
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan3dv2"}, )
|
||||
|
||||
|
||||
class Hunyuan3Dv2Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision_output):
|
||||
embeds = clip_vision_output.last_hidden_state
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class Hunyuan3Dv2ConditioningMultiView:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {},
|
||||
"optional": {"front": ("CLIP_VISION_OUTPUT",),
|
||||
"left": ("CLIP_VISION_OUTPUT",),
|
||||
"back": ("CLIP_VISION_OUTPUT",),
|
||||
"right": ("CLIP_VISION_OUTPUT",), }}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, front=None, left=None, back=None, right=None):
|
||||
all_embeds = [front, left, back, right]
|
||||
out = []
|
||||
pos_embeds = None
|
||||
for i, e in enumerate(all_embeds):
|
||||
if e is not None:
|
||||
if pos_embeds is None:
|
||||
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
|
||||
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
|
||||
|
||||
embeds = torch.cat(out, dim=1)
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
|
||||
class VAEDecodeHunyuan3D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ),
|
||||
"vae": ("VAE", ),
|
||||
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
|
||||
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
|
||||
}}
|
||||
RETURN_TYPES = ("VOXEL",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def decode(self, vae, samples, num_chunks, octree_resolution):
|
||||
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return (voxels, )
|
||||
|
||||
|
||||
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
voxels = voxels.to(device)
|
||||
|
||||
binary = (voxels > threshold).float()
|
||||
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
||||
|
||||
D, H, W = binary.shape
|
||||
|
||||
neighbors = torch.tensor([
|
||||
[0, 0, 1],
|
||||
[0, 0, -1],
|
||||
[0, 1, 0],
|
||||
[0, -1, 0],
|
||||
[1, 0, 0],
|
||||
[-1, 0, 0]
|
||||
], device=device)
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(D, device=device),
|
||||
torch.arange(H, device=device),
|
||||
torch.arange(W, device=device),
|
||||
indexing='ij'
|
||||
)
|
||||
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
||||
|
||||
solid_mask = binary.flatten() > 0
|
||||
solid_indices = voxel_indices[solid_mask]
|
||||
|
||||
corner_offsets = [
|
||||
torch.tensor([
|
||||
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
|
||||
], device=device),
|
||||
torch.tensor([
|
||||
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
|
||||
], device=device)
|
||||
]
|
||||
|
||||
all_vertices = []
|
||||
all_indices = []
|
||||
|
||||
vertex_count = 0
|
||||
|
||||
for face_idx, offset in enumerate(neighbors):
|
||||
neighbor_indices = solid_indices + offset
|
||||
|
||||
padded_indices = neighbor_indices + 1
|
||||
|
||||
is_exposed = padded[
|
||||
padded_indices[:, 0],
|
||||
padded_indices[:, 1],
|
||||
padded_indices[:, 2]
|
||||
] == 0
|
||||
|
||||
if not is_exposed.any():
|
||||
continue
|
||||
|
||||
exposed_indices = solid_indices[is_exposed]
|
||||
|
||||
corners = corner_offsets[face_idx].unsqueeze(0)
|
||||
|
||||
face_vertices = exposed_indices.unsqueeze(1) + corners
|
||||
|
||||
all_vertices.append(face_vertices.reshape(-1, 3))
|
||||
|
||||
num_faces = exposed_indices.shape[0]
|
||||
face_indices = torch.arange(
|
||||
vertex_count,
|
||||
vertex_count + 4 * num_faces,
|
||||
device=device
|
||||
).reshape(-1, 4)
|
||||
|
||||
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 1]], dim=1))
|
||||
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 3], face_indices[:, 2]], dim=1))
|
||||
|
||||
vertex_count += 4 * num_faces
|
||||
|
||||
vertices = torch.cat(all_vertices, dim=0)
|
||||
faces = torch.cat(all_indices, dim=0)
|
||||
|
||||
v_min = 0
|
||||
v_max = max(voxels.shape)
|
||||
|
||||
vertices = vertices - (v_min + v_max) / 2
|
||||
|
||||
scale = (v_max - v_min) / 2
|
||||
if scale > 0:
|
||||
vertices = vertices / scale
|
||||
|
||||
return vertices, faces
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices, faces):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class VoxelToMeshBasic:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, threshold):
|
||||
vertices = []
|
||||
faces = []
|
||||
for x in voxel.data:
|
||||
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
Parameters:
|
||||
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
|
||||
faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
|
||||
filepath: str - Output filepath (should end with .glb)
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
vertices_np = vertices.cpu().numpy().astype(np.float32)
|
||||
faces_np = faces.cpu().numpy().astype(np.uint32)
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b'\x00' * padding_length
|
||||
|
||||
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
||||
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
||||
|
||||
buffer_data = vertices_buffer_padded + indices_buffer_padded
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
vertices_byte_offset = 0
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
"buffers": [
|
||||
{
|
||||
"byteLength": len(buffer_data)
|
||||
}
|
||||
],
|
||||
"bufferViews": [
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": vertices_byte_offset,
|
||||
"byteLength": vertices_byte_length,
|
||||
"target": 34962 # ARRAY_BUFFER
|
||||
},
|
||||
{
|
||||
"buffer": 0,
|
||||
"byteOffset": indices_byte_offset,
|
||||
"byteLength": indices_byte_length,
|
||||
"target": 34963 # ELEMENT_ARRAY_BUFFER
|
||||
}
|
||||
],
|
||||
"accessors": [
|
||||
{
|
||||
"bufferView": 0,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126, # FLOAT
|
||||
"count": len(vertices_np),
|
||||
"type": "VEC3",
|
||||
"max": vertices_np.max(axis=0).tolist(),
|
||||
"min": vertices_np.min(axis=0).tolist()
|
||||
},
|
||||
{
|
||||
"bufferView": 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5125, # UNSIGNED_INT
|
||||
"count": faces_np.size,
|
||||
"type": "SCALAR"
|
||||
}
|
||||
],
|
||||
"meshes": [
|
||||
{
|
||||
"primitives": [
|
||||
{
|
||||
"attributes": {
|
||||
"POSITION": 0
|
||||
},
|
||||
"indices": 1,
|
||||
"mode": 4 # TRIANGLES
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"mesh": 0
|
||||
}
|
||||
],
|
||||
"scenes": [
|
||||
{
|
||||
"nodes": [0]
|
||||
}
|
||||
],
|
||||
"scene": 0
|
||||
}
|
||||
|
||||
if metadata is not None:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
|
||||
# Convert the JSON to bytes
|
||||
gltf_json = json.dumps(gltf).encode('utf8')
|
||||
|
||||
def pad_json_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b' ' * padding_length
|
||||
|
||||
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
|
||||
|
||||
# Create the GLB header
|
||||
# Magic glTF
|
||||
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
|
||||
|
||||
# Create JSON chunk header (chunk type 0)
|
||||
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
|
||||
|
||||
# Create BIN chunk header (chunk type 1)
|
||||
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
|
||||
|
||||
# Write the GLB file
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(glb_header)
|
||||
f.write(json_chunk_header)
|
||||
f.write(gltf_json_padded)
|
||||
f.write(bin_chunk_header)
|
||||
f.write(buffer_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
class SaveGLB:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"mesh": ("MESH", ),
|
||||
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return {"ui": {"3d": results}}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
|
||||
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
|
||||
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
|
||||
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
|
||||
"VoxelToMeshBasic": VoxelToMeshBasic,
|
||||
"SaveGLB": SaveGLB,
|
||||
}
|
Loading…
Reference in New Issue
Block a user