Basic support for hidream i1 model.

This commit is contained in:
comfyanonymous 2025-04-15 17:35:05 -04:00
parent 6fc5dbd52a
commit 9ad792f927
9 changed files with 1087 additions and 2 deletions

828
comfy/ldm/hidream/model.py Normal file
View File

@ -0,0 +1,828 @@
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import einops
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
import torch.nn.functional as F
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
dtype=None, device=None, operations=None
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, latent):
latent = self.proj(latent)
return latent
class PooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, pooled_embed):
return self.pooled_embedder(pooled_embed)
class TimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, timesteps, wdtype):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
class OutEmbed(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, adaln_input):
shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = self.linear(x)
return x
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
class HiDreamAttnProcessor_flashattn:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
dtype = image_tokens.dtype
batch_size = image_tokens.shape[0]
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
value_i = attn.to_v(image_tokens)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if image_tokens_masks is not None:
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
value_t = attn.to_v_t(text_tokens)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == rope.shape[-3] * 2:
query, key = apply_rope(query, key, rope)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, rope)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states
class HiDreamAttention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor = None,
out_dim: int = None,
single: bool = False,
dtype=None, device=None, operations=None
):
# super(Attention, self).__init__()
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
linear_cls = operations.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
if not single:
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.processor = processor
def forward(
self,
norm_image_tokens: torch.FloatTensor,
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.Tensor:
return self.processor(
self,
image_tokens = norm_image_tokens,
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
)
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
dtype=None, device=None, operations=None
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * (
(hidden_dim + multiple_of - 1) // multiple_of
)
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = 'softmax'
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
self.reset_parameters()
def reset_parameters(self) -> None:
pass
# import torch.nn.init as init
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
dtype=None, device=None, operations=None
):
super().__init__()
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
self.gate = MoEGate(
embed_dim = dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if True: # self.training: # TODO: check which branch performs faster
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
#y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
class BlockType:
TransformerBlock = 1
SingleTransformerBlock = 2
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = True,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_image_tokens,
image_tokens_masks,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
image_tokens = ff_output_i + image_tokens
return image_tokens
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
)
# nn.init.zeros_(self.adaLN_modulation[1].weight)
# nn.init.zeros_(self.adaLN_modulation[1].bias)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = False,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_image_tokens,
image_tokens_masks,
norm_text_tokens,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
text_tokens = gate_msa_t * attn_output_t + text_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
image_tokens = ff_output_i + image_tokens
text_tokens = ff_output_t + text_tokens
return image_tokens, text_tokens
class HiDreamImageBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
block_type: BlockType = BlockType.TransformerBlock,
dtype=None, device=None, operations=None
):
super().__init__()
block_classes = {
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
}
self.block = block_classes[block_type](
dim,
num_attention_heads,
attention_head_dim,
num_routed_experts,
num_activated_experts,
dtype=dtype, device=device, operations=operations
)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
return self.block(
image_tokens,
image_tokens_masks,
text_tokens,
adaln_input,
rope,
)
class HiDreamImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
image_model=None,
dtype=None, device=None, operations=None
):
self.patch_size = patch_size
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.num_layers = num_layers
self.num_single_layers = num_single_layers
self.gradient_checkpointing = False
super().__init__()
self.dtype = dtype
self.out_channels = out_channels or in_channels
self.inner_dim = self.num_attention_heads * self.attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.x_embedder = PatchEmbed(
patch_size = patch_size,
in_channels = in_channels,
out_channels = self.inner_dim,
dtype=dtype, device=device, operations=operations
)
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.TransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.SingleTransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_single_layers)
]
)
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.patch_size, p2=self.patch_size)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.patch_size * self.patch_size
if isinstance(x, torch.Tensor):
B = x.shape[0]
device = x.device
dtype = x.dtype
else:
B = len(x)
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0:img_size[0] * img_size[1]] = 1
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
elif isinstance(x, torch.Tensor):
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
img_sizes = [[pH, pW]] * B
x_masks = None
else:
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
control = None,
transformer_options = {},
) -> torch.Tensor:
hidden_states = x
timesteps = t
pooled_embeds = y
T5_encoder_hidden_states = context
img_sizes = None
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
adaln_input = timesteps + p_embedder
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if image_tokens_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
hidden_states = self.x_embedder(hidden_states)
# T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
3,
device=img_ids.device, dtype=img_ids.dtype
)
ids = torch.cat((img_ids, txt_ids), dim=1)
rope = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
image_tokens_masks = image_tokens_masks,
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if image_tokens_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
)
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states = block(
image_tokens=hidden_states,
image_tokens_masks=image_tokens_masks,
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, adaln_input)
output = self.unpatchify(output, img_sizes)
return -output

View File

@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.model_management
import comfy.patcher_extension
@ -1056,3 +1057,20 @@ class Hunyuan3Dv2(BaseModel):
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
if conditioning_llama3 is not None:
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
return out

View File

@ -338,6 +338,25 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
dit_config["attention_head_dim"] = 128
dit_config["axes_dims_rope"] = [64, 32, 32]
dit_config["caption_channels"] = [4096, 4096]
dit_config["max_resolution"] = [128, 128]
dit_config["in_channels"] = 16
dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
dit_config["num_attention_heads"] = 20
dit_config["num_routed_experts"] = 4
dit_config["num_activated_experts"] = 2
dit_config["num_layers"] = 16
dit_config["num_single_layers"] = 32
dit_config["out_channels"] = 16
dit_config["patch_size"] = 2
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@ -263,6 +263,9 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
class RMSNorm(disable_weight_init.RMSNorm):
comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True

View File

@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.model_patcher
import comfy.lora
@ -853,6 +854,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif len(clip_data) == 4:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
parameters = 0
for c in clip_data:

View File

@ -1025,6 +1025,36 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
}
sampling_settings = {
"shift": 3.0,
}
sampling_settings = {
}
# memory_usage_factor = 1.2 # TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HiDream(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
models += [SVD_img2vid]

View File

@ -0,0 +1,150 @@
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging
class HiDreamTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HiDreamTEModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
if llama:
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
if "vocab_size" not in model_options:
model_options["vocab_size"] = 128256
self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
self.dtypes.add(dtype_llama)
else:
self.llama = None
logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
def set_clip_options(self, options):
if self.clip_l is not None:
self.clip_l.set_clip_options(options)
if self.clip_g is not None:
self.clip_g.set_clip_options(options)
if self.t5xxl is not None:
self.t5xxl.set_clip_options(options)
if self.llama is not None:
self.llama.set_clip_options(options)
def reset_clip_options(self):
if self.clip_l is not None:
self.clip_l.reset_clip_options()
if self.clip_g is not None:
self.clip_g.reset_clip_options()
if self.t5xxl is not None:
self.t5xxl.reset_clip_options()
if self.llama is not None:
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_llama = token_weight_pairs["llama"]
lg_out = None
pooled = None
extra = {}
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
else:
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
if self.llama is not None:
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
ll_out, ll_pooled = ll_output[:2]
ll_out = ll_out[:, 1:]
if t5_out is None:
t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
if ll_out is None:
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
extra["conditioning_llama3"] = ll_out
return t5_out, pooled, extra
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
return self.t5xxl.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -0,0 +1,32 @@
import folder_paths
import comfy.sd
import comfy.model_management
class QuadrupleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
}

View File

@ -2280,7 +2280,8 @@ def init_builtin_extra_nodes():
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py"
"nodes_optimalsteps.py",
"nodes_hidream.py"
]
import_failed = []