Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-18 18:33:30 +00:00)

commit 11f1b41bab (parent 3b19fc76e3)

Initial Hunyuan3Dv2 implementation.

Supports the multiview, mini, turbo models and VAEs.

comfy/latent_formats.py

@@ -456,3 +456,13 @@ class Wan21(LatentFormat):
         latents_mean = self.latents_mean.to(latent.device, latent.dtype)
         latents_std = self.latents_std.to(latent.device, latent.dtype)
         return latent * latents_std / self.scale_factor + latents_mean
+
+class Hunyuan3Dv2(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 1
+    scale_factor = 0.9990943042622529
+
+class Hunyuan3Dv2mini(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 1
+    scale_factor = 1.0188137142395404
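
The two new latent formats are plain per-model scale factors. As an illustration (not part of the commit, assuming a working ComfyUI install where the base LatentFormat class provides process_in/process_out), round-tripping a latent is an identity up to floating point:

import torch
from comfy.latent_formats import Hunyuan3Dv2

fmt = Hunyuan3Dv2()
x = torch.randn(1, 64, 3072)  # [batch, latent_channels, tokens], matching latent_dimensions = 1
assert torch.allclose(fmt.process_out(fmt.process_in(x)), x)  # process_in scales by scale_factor, process_out undoes it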

comfy/ldm/hunyuan3d/model.py (new file, +135 lines)

import torch
from torch import nn
from comfy.ldm.flux.layers import (
    DoubleStreamBlock,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


class Hunyuan3Dv2(nn.Module):
    def __init__(
        self,
        in_channels=64,
        context_in_dim=1536,
        hidden_size=1024,
        mlp_ratio=4.0,
        num_heads=16,
        depth=16,
        depth_single_blocks=32,
        qkv_bias=True,
        guidance_embed=False,
        image_model=None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.dtype = dtype

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )

        self.max_period = 1000  # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
        self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
        )
        self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)

    def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
        x = x.movedim(-1, -2)
        timestep = 1.0 - timestep
        txt = context
        img = self.latent_in(x)

        vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
        if self.guidance_in is not None:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))

        txt = self.cond_in(txt)
        pe = None
        attn_mask = None

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"],
                                                   txt=args["txt"],
                                                   vec=args["vec"],
                                                   pe=args["pe"],
                                                   attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("double_block", i)]({"img": img,
                                                           "txt": txt,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 attn_mask=attn_mask)

        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("single_block", i)]({"img": img,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

        img = img[:, txt.shape[1]:, ...]
        img = self.final_layer(img, vec)
        return img.movedim(-2, -1) * (-1.0)
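
As a side note on the max_period comment in __init__ above, a minimal sketch (not part of the commit, assuming ComfyUI is importable) of the embedding call this forward pass makes, with 1000 passed as max_period:

import torch
from comfy.ldm.flux.layers import timestep_embedding

t = torch.tensor([0.25, 0.5, 1.0])      # flow timesteps after the 1.0 - timestep flip
emb = timestep_embedding(t, 256, 1000)  # third positional argument is max_period, as in Hunyuan3Dv2.forward
print(emb.shape)                        # torch.Size([3, 256]) -> fed to self.time_in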

comfy/ldm/hunyuan3d/vae.py (new file, +587 lines)

# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.

import torch
import torch.nn as nn
import torch.nn.functional as F


from typing import Union, Tuple, List, Callable, Optional

import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging

import comfy.ops
ops = comfy.ops.disable_weight_init

def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    length = bbox_max - bbox_min
    num_cells = octree_resolution

    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]

    return xyz, grid_size, length


class VanillaVolumeDecoder:
    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits


class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]     # only present if include_input is True.
        ], here f_i is the frequency.

    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True.

    Attributes:
        frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
                otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.

    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:

        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x


class CrossAttentionProcessor:
    def __call__(self, attn, q, k, v):
        out = F.scaled_dot_product_attention(q, k, v)
        return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.

        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'


class MLP(nn.Module):
    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = ops.Linear(width, width * expand_ratio)
        self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        self.data = None

    def forward(self, x, data):
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                self.data = self.c_kv(data)
                logging.info('Save kv cache,this should be called only once for one mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class QKVMultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class CrossAttentionDecoder(nn.Module):

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        if self.enable_ln_post == False:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)
        self.label_type = label_type
        self.count = 0

    def forward(self, queries=None, query_embeddings=None, latents=None):
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ


class ShapeVAE(nn.Module):
    def __init__(
        self,
        *,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor

    def decode(self, latents, **kwargs):
        latents = self.post_kl(latents.movedim(-2, -1))
        latents = self.transformer(latents)

        bounds = kwargs.get("bounds", 1.01)
        num_chunks = kwargs.get("num_chunks", 8000)
        octree_resolution = kwargs.get("octree_resolution", 256)
        enable_pbar = kwargs.get("enable_pbar", True)

        grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
        return grid_logits

    def encode(self, x):
        return None
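
For orientation (illustrative only, not part of the commit, assuming ComfyUI is importable): the volume decoder queries a dense (octree_resolution + 1)^3 grid of points, and each query point is lifted to a Fourier feature vector before the cross-attention geo decoder scores it. The ShapeVAE configuration built by the VAE loader below uses num_freqs=8 and include_pi=False:

import numpy as np
import torch
from comfy.ldm.hunyuan3d.vae import FourierEmbedder, generate_dense_grid_points

xyz, grid_size, length = generate_dense_grid_points(
    bbox_min=np.array([-1.01, -1.01, -1.01]),
    bbox_max=np.array([1.01, 1.01, 1.01]),
    octree_resolution=256,
)
print(grid_size)                       # [257, 257, 257] -> shape of the decoded logits volume per batch item

emb = FourierEmbedder(num_freqs=8, include_pi=False)
print(emb.out_dim)                     # 3 * (8 * 2 + 1) = 51 features per query point
print(emb(torch.rand(1, 8, 3)).shape)  # torch.Size([1, 8, 51])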

comfy/model_base.py

@@ -36,6 +36,7 @@ import comfy.ldm.hunyuan_video.model
 import comfy.ldm.cosmos.model
 import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
+import comfy.ldm.hunyuan3d.model

 import comfy.model_management
 import comfy.patcher_extension
@@ -1013,3 +1014,18 @@ class WAN21(BaseModel):
         if clip_vision_output is not None:
             out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out
+
+class Hunyuan3Dv2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        guidance = kwargs.get("guidance", 5.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+        return out

comfy/model_detection.py

@@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = len(guidance_keys) > 0
         return dit_config

-    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
+    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
         dit_config["in_channels"] = 16
@@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["model_type"] = "t2v"
         return dit_config

+    if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
+        in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
+        dit_config = {}
+        dit_config["image_model"] = "hunyuan3d2"
+        dit_config["in_channels"] = in_shape[1]
+        dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
+        dit_config["hidden_size"] = in_shape[0]
+        dit_config["mlp_ratio"] = 4.0
+        dit_config["num_heads"] = 16
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
+        dit_config["qkv_bias"] = True
+        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None

comfy/sd.py (15 lines changed)

@@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
 import comfy.ldm.lightricks.vae.causal_video_autoencoder
 import comfy.ldm.cosmos.vae
 import comfy.ldm.wan.vae
+import comfy.ldm.hunyuan3d.vae
 import yaml
 import math

@@ -412,6 +413,16 @@ class VAE:
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+        elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
+            self.latent_dim = 1
+            ln_post = "geo_decoder.ln_post.weight" in sd
+            inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
+            downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
+            mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
+            self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
+            ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
+            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
             self.first_stage_model = None
@@ -498,7 +509,7 @@ class VAE:
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

-    def decode(self, samples_in):
+    def decode(self, samples_in, vae_options={}):
         self.throw_exception_if_invalid()
         pixel_samples = None
         try:
@@ -510,7 +521,7 @@ class VAE:
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                 if pixel_samples is None:
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out

comfy/supported_models.py

@@ -959,6 +959,42 @@ class WAN21_I2V(WAN21_T2V):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out

+class Hunyuan3Dv2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "hunyuan3d2",
+    }
+
+    unet_extra_config = {}
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 1.0,
+    }
+
+    clip_vision_prefix = "conditioner.main_image_encoder.model."
+    vae_key_prefix = ["vae."]
+
+    latent_format = latent_formats.Hunyuan3Dv2
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model."}
+        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Hunyuan3Dv2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return None
+
+class Hunyuan3Dv2mini(Hunyuan3Dv2):
+    unet_config = {
+        "image_model": "hunyuan3d2",
+        "depth": 8,
+    }
+
+    latent_format = latent_formats.Hunyuan3Dv2mini
+
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]

 models += [SVD_img2vid]

comfy_extras/nodes_hunyuan3d.py (new file, +410 lines)

import torch
import os
import json
import struct
import numpy as np
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
import folder_paths
import comfy.model_management
from comfy.cli_args import args


class EmptyLatentHunyuan3Dv2:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
                             }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/3d"

    def generate(self, resolution, batch_size):
        latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
        return ({"samples": latent, "type": "hunyuan3dv2"}, )


class Hunyuan3Dv2Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")

    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, clip_vision_output):
        embeds = clip_vision_output.last_hidden_state
        positive = [[embeds, {}]]
        negative = [[torch.zeros_like(embeds), {}]]
        return (positive, negative)


class Hunyuan3Dv2ConditioningMultiView:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {},
                "optional": {"front": ("CLIP_VISION_OUTPUT",),
                             "left": ("CLIP_VISION_OUTPUT",),
                             "back": ("CLIP_VISION_OUTPUT",),
                             "right": ("CLIP_VISION_OUTPUT",), }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")

    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, front=None, left=None, back=None, right=None):
        all_embeds = [front, left, back, right]
        out = []
        pos_embeds = None
        for i, e in enumerate(all_embeds):
            if e is not None:
                if pos_embeds is None:
                    pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
                out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))

        embeds = torch.cat(out, dim=1)
        positive = [[embeds, {}]]
        negative = [[torch.zeros_like(embeds), {}]]
        return (positive, negative)


class VOXEL:
    def __init__(self, data):
        self.data = data


class VAEDecodeHunyuan3D:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT", ),
                             "vae": ("VAE", ),
                             "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
                             "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
                             }}
    RETURN_TYPES = ("VOXEL",)
    FUNCTION = "decode"

    CATEGORY = "latent/3d"

    def decode(self, vae, samples, num_chunks, octree_resolution):
        voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
        return (voxels, )


def voxel_to_mesh(voxels, threshold=0.5, device=None):
    if device is None:
        device = torch.device("cpu")
    voxels = voxels.to(device)

    binary = (voxels > threshold).float()
    padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)

    D, H, W = binary.shape

    neighbors = torch.tensor([
        [0, 0, 1],
        [0, 0, -1],
        [0, 1, 0],
        [0, -1, 0],
        [1, 0, 0],
        [-1, 0, 0]
    ], device=device)

    z, y, x = torch.meshgrid(
        torch.arange(D, device=device),
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )
    voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)

    solid_mask = binary.flatten() > 0
    solid_indices = voxel_indices[solid_mask]

    corner_offsets = [
        torch.tensor([
            [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
        ], device=device),
        torch.tensor([
            [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
        ], device=device),
        torch.tensor([
            [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
        ], device=device),
        torch.tensor([
            [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
        ], device=device),
        torch.tensor([
            [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
        ], device=device),
        torch.tensor([
            [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
        ], device=device)
    ]

    all_vertices = []
    all_indices = []

    vertex_count = 0

    for face_idx, offset in enumerate(neighbors):
        neighbor_indices = solid_indices + offset

        padded_indices = neighbor_indices + 1

        is_exposed = padded[
            padded_indices[:, 0],
            padded_indices[:, 1],
            padded_indices[:, 2]
        ] == 0

        if not is_exposed.any():
            continue

        exposed_indices = solid_indices[is_exposed]

        corners = corner_offsets[face_idx].unsqueeze(0)

        face_vertices = exposed_indices.unsqueeze(1) + corners

        all_vertices.append(face_vertices.reshape(-1, 3))

        num_faces = exposed_indices.shape[0]
        face_indices = torch.arange(
            vertex_count,
            vertex_count + 4 * num_faces,
            device=device
        ).reshape(-1, 4)

        all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 1]], dim=1))
        all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 3], face_indices[:, 2]], dim=1))

        vertex_count += 4 * num_faces

    vertices = torch.cat(all_vertices, dim=0)
    faces = torch.cat(all_indices, dim=0)

    v_min = 0
    v_max = max(voxels.shape)

    vertices = vertices - (v_min + v_max) / 2

    scale = (v_max - v_min) / 2
    if scale > 0:
        vertices = vertices / scale

    return vertices, faces


class MESH:
    def __init__(self, vertices, faces):
        self.vertices = vertices
        self.faces = faces


class VoxelToMeshBasic:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"voxel": ("VOXEL", ),
                             "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MESH",)
    FUNCTION = "decode"

    CATEGORY = "3d"

    def decode(self, voxel, threshold):
        vertices = []
        faces = []
        for x in voxel.data:
            v, f = voxel_to_mesh(x, threshold=threshold, device=None)
            vertices.append(v)
            faces.append(f)

        return (MESH(torch.stack(vertices), torch.stack(faces)), )


def save_glb(vertices, faces, filepath, metadata=None):
    """
    Save PyTorch tensor vertices and faces as a GLB file without external dependencies.

    Parameters:
    vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
    faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
    filepath: str - Output filepath (should end with .glb)
    """

    # Convert tensors to numpy arrays
    vertices_np = vertices.cpu().numpy().astype(np.float32)
    faces_np = faces.cpu().numpy().astype(np.uint32)

    vertices_buffer = vertices_np.tobytes()
    indices_buffer = faces_np.tobytes()

    def pad_to_4_bytes(buffer):
        padding_length = (4 - (len(buffer) % 4)) % 4
        return buffer + b'\x00' * padding_length

    vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
    indices_buffer_padded = pad_to_4_bytes(indices_buffer)

    buffer_data = vertices_buffer_padded + indices_buffer_padded

    vertices_byte_length = len(vertices_buffer)
    vertices_byte_offset = 0
    indices_byte_length = len(indices_buffer)
    indices_byte_offset = len(vertices_buffer_padded)

    gltf = {
        "asset": {"version": "2.0", "generator": "ComfyUI"},
        "buffers": [
            {
                "byteLength": len(buffer_data)
            }
        ],
        "bufferViews": [
            {
                "buffer": 0,
                "byteOffset": vertices_byte_offset,
                "byteLength": vertices_byte_length,
                "target": 34962  # ARRAY_BUFFER
            },
            {
                "buffer": 0,
                "byteOffset": indices_byte_offset,
                "byteLength": indices_byte_length,
                "target": 34963  # ELEMENT_ARRAY_BUFFER
            }
        ],
        "accessors": [
            {
                "bufferView": 0,
                "byteOffset": 0,
                "componentType": 5126,  # FLOAT
                "count": len(vertices_np),
                "type": "VEC3",
                "max": vertices_np.max(axis=0).tolist(),
                "min": vertices_np.min(axis=0).tolist()
            },
            {
                "bufferView": 1,
                "byteOffset": 0,
                "componentType": 5125,  # UNSIGNED_INT
                "count": faces_np.size,
                "type": "SCALAR"
            }
        ],
        "meshes": [
            {
                "primitives": [
                    {
                        "attributes": {
                            "POSITION": 0
                        },
                        "indices": 1,
                        "mode": 4  # TRIANGLES
                    }
                ]
            }
        ],
        "nodes": [
            {
                "mesh": 0
            }
        ],
        "scenes": [
            {
                "nodes": [0]
            }
        ],
        "scene": 0
    }

    if metadata is not None:
        gltf["asset"]["extras"] = metadata

    # Convert the JSON to bytes
    gltf_json = json.dumps(gltf).encode('utf8')

    def pad_json_to_4_bytes(buffer):
        padding_length = (4 - (len(buffer) % 4)) % 4
        return buffer + b' ' * padding_length

    gltf_json_padded = pad_json_to_4_bytes(gltf_json)

    # Create the GLB header
    # Magic glTF
    glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))

    # Create JSON chunk header (chunk type 0)
    json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A)  # "JSON" in little endian

    # Create BIN chunk header (chunk type 1)
    bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942)  # "BIN\0" in little endian

    # Write the GLB file
    with open(filepath, 'wb') as f:
        f.write(glb_header)
        f.write(json_chunk_header)
        f.write(gltf_json_padded)
        f.write(bin_chunk_header)
        f.write(buffer_data)

    return filepath


class SaveGLB:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"mesh": ("MESH", ),
                             "filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }

    RETURN_TYPES = ()
    FUNCTION = "save"

    OUTPUT_NODE = True

    CATEGORY = "3d"

    def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
        results = []

        metadata = {}
        if not args.disable_metadata:
            if prompt is not None:
                metadata["prompt"] = json.dumps(prompt)
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata[x] = json.dumps(extra_pnginfo[x])

        for i in range(mesh.vertices.shape[0]):
            f = f"{filename}_{counter:05}_.glb"
            save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
            results.append({
                "filename": f,
                "subfolder": subfolder,
                "type": "output"
            })
            counter += 1
        return {"ui": {"3d": results}}


NODE_CLASS_MAPPINGS = {
    "EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
    "Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
    "Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
    "VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
    "VoxelToMeshBasic": VoxelToMeshBasic,
    "SaveGLB": SaveGLB,
}
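
A minimal end-to-end sketch of the two pure-Python helpers added in this file (illustrative only, not part of the commit, assuming ComfyUI's repository root is on sys.path so the folder_paths and comfy imports resolve): mesh a single occupied voxel and write the result as a GLB.

import torch
from comfy_extras.nodes_hunyuan3d import voxel_to_mesh, save_glb

voxels = torch.zeros(3, 3, 3)
voxels[1, 1, 1] = 1.0                        # one solid cell
vertices, faces = voxel_to_mesh(voxels, threshold=0.5)
print(vertices.shape, faces.shape)           # torch.Size([24, 3]) torch.Size([12, 3]) -> the 6 quads of a cube, split into triangles
save_glb(vertices, faces, "voxel_cube.glb")  # writes a self-contained binary glTF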