Add PixArt model support (#6055)

* PixArt initial version

* PixArt Diffusers convert logic

* pos_emb and interpolation logic

* Reduce  duplicate code

* Formatting

* Use optimized attention

* Edit empty token logic

* Basic PixArt LoRA support

* Fix aspect ratio logic

* PixArtAlpha text encode with conds

* Use same detection key logic for PixArt diffusers
This commit is contained in:
City 2024-12-20 21:25:00 +01:00 committed by GitHub
parent 418eb7062d
commit bddb02660c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1079 additions and 5 deletions

382
comfy/ldm/pixart/blocks.py Normal file
View File

@ -0,0 +1,382 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy import model_management
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
from comfy.ldm.modules.attention import optimized_attention
if model_management.xformers_enabled():
import xformers.ops
if int((xformers.__version__).split(".")[2]) >= 28:
block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
else:
block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
super(MultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query/value: img tokens; key: condition; mask: if padding tokens
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
# TODO: xformers needs separate mask logic here
if model_management.xformers_enabled():
attn_bias = None
if mask is not None:
attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
else:
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
attn_mask = None
if mask is not None and len(mask) > 1:
# Create equivalent of xformer diagonal block mask, still only correct for square masks
# But depth doesn't matter as tensors can expand in that dimension
attn_mask_template = torch.ones(
[q.shape[2] // B, mask[0]],
dtype=torch.bool,
device=q.device
)
attn_mask = torch.block_diag(attn_mask_template)
# create a mask on the diagonal for each mask in the batch
for _ in range(B - 1):
attn_mask = torch.block_diag(attn_mask, attn_mask_template)
x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionKVCompress(nn.Module):
"""Multi-head Attention block with KV token compression and qk norm."""
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
self.sr_ratio = sr_ratio
if sr_ratio > 1 and sampling == 'conv':
# Avg Conv Init.
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
# self.sr.weight.data.fill_(1/sr_ratio**2)
# self.sr.bias.data.zero_()
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
if qk_norm:
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
if sampling is None or scale_factor == 1:
return tensor
B, N, C = tensor.shape
if sampling == 'uniform_every':
return tensor[:, ::scale_factor], int(N // scale_factor)
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
new_N = new_H * new_W
if sampling == 'ave':
tensor = F.interpolate(
tensor, scale_factor=1 / scale_factor, mode='nearest'
).permute(0, 2, 3, 1)
elif sampling == 'uniform':
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
elif sampling == 'conv':
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
tensor = self.norm(tensor)
else:
raise ValueError
return tensor.reshape(B, new_N, C).contiguous(), new_N
def forward(self, x, mask=None, HW=None, block_id=None):
B, N, C = x.shape # 2 4096 1152
new_N = N
if HW is None:
H = W = int(N ** 0.5)
else:
H, W = HW
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2)
dtype = q.dtype
q = self.q_norm(q)
k = self.k_norm(k)
# KV compression
if self.sr_ratio > 1:
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
if mask is not None:
raise NotImplementedError("Attn mask logic not added for self attention")
# This is never called at the moment
# attn_bias = None
# if mask is not None:
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
# attention 2
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
x = x.view(B, N, C)
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
self.out_channels = out_channels
def forward(self, x, t):
dtype = x.dtype
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x.to(dtype))
return x
class MaskFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DecoderLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x)
return x
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs//s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
class CaptionEmbedderDoubleBr(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
self.uncond_prob = uncond_prob
def token_drop(self, global_caption, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return global_caption, caption
def forward(self, caption, train, force_drop_ids=None):
assert caption.shape[2: ] == self.y_embedding.shape
global_caption = caption.mean(dim=2).squeeze()
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
y_embed = self.proj(global_caption)
return y_embed, caption

201
comfy/ldm/pixart/pixart.py Normal file
View File

@ -0,0 +1,201 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
)
from comfy.ldm.modules.diffusionmodules.mmdit import PatchEmbed, TimestepEmbedder, Mlp, get_1d_sincos_pos_embed_from_grid_torch
class PixArtBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = nn.Identity() #DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
self.sampling = sampling
self.sr_ratio = sr_ratio
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArt(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=1.0,
pe_precision=None,
config=None,
model_max_length=120,
qk_norm=False,
kv_compress_config=None,
**kwargs,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.depth = depth
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size // self.patch_size
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.kv_compress_config = kv_compress_config
if kv_compress_config is None:
self.kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(input_size // patch_size, input_size // patch_size),
sampling=self.kv_compress_config['sampling'],
sr_ratio=int(
self.kv_compress_config['scale_factor']
) if i in self.kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
def forward_raw(self, x, t, y, mask=None, data_info=None):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = t.to(self.dtype)
y = y.to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, y=None, **kwargs):
"""
Forward pass that adapts comfy input to original forward function
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
timesteps: (N,) tensor of diffusion timesteps
context: (N, 1, 120, C) conditioning
y: extra conditioning.
"""
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_raw(
x = x.to(self.dtype),
t = timesteps.to(self.dtype),
y = context.to(self.dtype),
)
## only return EPS
out = out.to(torch.float)
eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
return eps
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return imgs
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb

View File

@ -0,0 +1,246 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp
from .pixart import PixArt, get_2d_sincos_pos_embed_torch
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(PixArt):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs

View File

@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.PixArt):
diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
key_map[key_lora] = to
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
key_map[key_lora] = to
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
key_map[key_lora] = to
if isinstance(model, comfy.model_base.HunyuanDiT):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):

View File

@ -26,6 +26,7 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.genmo.joint_model.asymm_models_joint
import comfy.ldm.aura.mmdit
import comfy.ldm.pixart.pixartms
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
@ -718,6 +719,21 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
width = kwargs.get("width", None)
height = kwargs.get("height", None)
if width is not None and height is not None:
out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]]))
out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]]))
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)

View File

@ -203,11 +203,42 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
# PixArt diffusers
return None
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}
dit_config["num_heads"] = 16
dit_config["patch_size"] = patch_size
dit_config["hidden_size"] = 1152
dit_config["in_channels"] = 4
dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
dit_config["model_max_length"] = state_dict[y_key].shape[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
if ar_key in state_dict_keys:
dit_config["image_model"] = "pixart_alpha"
dit_config["micro_condition"] = True
else:
dit_config["image_model"] = "pixart_sigma"
dit_config["micro_condition"] = False
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -573,6 +604,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')

View File

@ -27,6 +27,7 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
@ -604,6 +605,8 @@ class CLIPType(Enum):
MOCHI = 7
LTXV = 8
HUNYUAN_VIDEO = 9
PIXART = 10
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@ -696,6 +699,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer

View File

@ -37,8 +37,12 @@ class ClipTokenWeightEncoder:
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
if hasattr(self, "gen_empty_tokens"):
to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
else:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
print(to_encode)
o = self.encode(to_encode)
out, pooled = o[:2]

View File

@ -8,6 +8,7 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
@ -592,6 +593,37 @@ class AuraFlow(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
class PixArtAlpha(supported_models_base.BASE):
unet_config = {
"image_model": "pixart_alpha",
}
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
unet_extra_config = {}
latent_format = latent_formats.SD15
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.PixArt(self, device=device)
return out.eval()
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
class PixArtSigma(PixArtAlpha):
unet_config = {
"image_model": "pixart_sigma",
}
latent_format = latent_formats.SDXL
class HunyuanDiT(supported_models_base.BASE):
unet_config = {
"image_model": "hydit",
@ -787,6 +819,6 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
models += [SVD_img2vid]

View File

@ -0,0 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

View File

@ -386,6 +386,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
return key_map
PIXART_MAP_BASIC = {
("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
("x_embedder.proj.weight", "pos_embed.proj.weight"),
("x_embedder.proj.bias", "pos_embed.proj.bias"),
("y_embedder.y_embedding", "caption_projection.y_embedding"),
("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
("t_block.1.weight", "adaln_single.linear.weight"),
("t_block.1.bias", "adaln_single.linear.bias"),
("final_layer.linear.weight", "proj_out.weight"),
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.scale_shift_table", "scale_shift_table"),
}
PIXART_MAP_BLOCK = {
("scale_shift_table", "scale_shift_table"),
("attn.proj.weight", "attn1.to_out.0.weight"),
("attn.proj.bias", "attn1.to_out.0.bias"),
("mlp.fc1.weight", "ff.net.0.proj.weight"),
("mlp.fc1.bias", "ff.net.0.proj.bias"),
("mlp.fc2.weight", "ff.net.2.weight"),
("mlp.fc2.bias", "ff.net.2.bias"),
("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
}
def pixart_to_diffusers(mmdit_config, output_prefix=""):
key_map = {}
depth = mmdit_config.get("depth", 0)
offset = mmdit_config.get("hidden_size", 1152)
for i in range(depth):
block_from = "transformer_blocks.{}".format(i)
block_to = "{}blocks.{}".format(output_prefix, i)
for end in ("weight", "bias"):
s = "{}.attn1.".format(block_from)
qkv = "{}.attn.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))
s = "{}.attn2.".format(block_from)
q = "{}.cross_attn.q_linear.{}".format(block_to, end)
kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = q
key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))
for k in PIXART_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
for k in PIXART_MAP_BASIC:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("n_double_layers", 0)

View File

@ -0,0 +1,24 @@
from nodes import MAX_RESOLUTION
class CLIPTextEncodePixArtAlpha:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
}

View File

@ -898,7 +898,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -918,6 +918,8 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
@ -2164,6 +2166,7 @@ def init_builtin_extra_nodes():
"nodes_stable3d.py",
"nodes_sdupscale.py",
"nodes_photomaker.py",
"nodes_pixart.py",
"nodes_cond.py",
"nodes_morphology.py",
"nodes_stable_cascade.py",