mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
568 lines
20 KiB
Python
568 lines
20 KiB
Python
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
|
||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from einops import rearrange
|
||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||
|
||
import comfy.ops
|
||
# Layer factory used for the Conv2d/Conv3d layers in this file (comfy's op classes
# with weight initialisation disabled -- presumably because weights are always
# loaded from a checkpoint; confirm against comfy.ops).
ops = comfy.ops.disable_weight_init

# Number of trailing frames each causal conv keeps (x[:, :, -CACHE_T:]) in the
# feature cache so the next temporal chunk can see its predecessor's context.
CACHE_T = 2
|
||
|
||
|
||
class CausalConv3d(ops.Conv3d):
    """
    Causal 3D convolution.

    The temporal padding requested at construction is doubled and applied
    only at the front (past side) of the time axis, so an output frame never
    depends on later input frames. Spatial padding stays symmetric. The
    built-in convolution padding is disabled and applied manually in
    ``forward`` instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pad_t, pad_h, pad_w = self.padding[0], self.padding[1], self.padding[2]
        # F.pad order for 5-D input: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        """Convolve ``x``, optionally prefixing cached frames from the previous chunk.

        When ``cache_x`` is given and this conv has front temporal padding, the
        cached frames are concatenated before ``x`` along dim 2 and take the
        place of the same number of zero-padding frames.
        """
        pad = list(self._padding)
        has_front_pad = self._padding[4] > 0
        if has_front_pad and cache_x is not None:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            pad[4] = pad[4] - cache_x.shape[2]
        padded = F.pad(x, pad)
        return super().forward(padded)
|
||
|
||
|
||
class RMS_norm(nn.Module):
    """RMS normalisation implemented via L2 normalisation.

    ``F.normalize`` divides by the L2 norm along the channel axis; multiplying
    by ``sqrt(dim)`` turns that into division by the root-mean-square. The
    result is scaled by a learnable per-channel ``gamma`` and, when
    ``bias=True``, shifted by a learnable ``bias``.

    Args:
        dim: number of channels.
        channel_first: normalise over dim 1 (True) or the last dim (False).
        images: when channel_first, parameters broadcast over two trailing
            spatial dims (4-D input) if True, or three (5-D input) if False.
        bias: whether to add a learnable bias.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            shape = (dim,) + trailing
        else:
            shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else None

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        out = F.normalize(x, dim=axis) * self.scale * self.gamma.to(x)
        if self.bias is not None:
            out = out + self.bias.to(x)
        return out
|
||
|
||
|
||
class Upsample(nn.Upsample):
    """``nn.Upsample`` that always interpolates in float32.

    The input is upcast to float32, interpolated, and cast back to the input's
    original dtype. This works around missing bfloat16 support for
    nearest-neighbour interpolation.
    """

    def forward(self, x):
        as_float = x.float()
        upsampled = super().forward(as_float)
        return upsampled.type_as(x)
|
||
|
||
|
||
class Resample(nn.Module):
    """Spatial (and optionally temporal) resampling block.

    Modes:
      * ``'none'``         -- identity.
      * ``'upsample2d'``   -- 2x nearest-exact spatial upsample, then a 3x3 conv
                              that halves the channel count.
      * ``'upsample3d'``   -- as ``upsample2d``; in addition, when a feature
                              cache is supplied and this is not the first chunk,
                              a causal temporal conv emits ``2 * dim`` channels
                              that are interleaved into twice as many frames.
      * ``'downsample2d'`` -- right/bottom zero pad + stride-2 3x3 conv.
      * ``'downsample3d'`` -- as ``downsample2d``; in addition, when a feature
                              cache is supplied and this is not the first chunk,
                              a stride-2 temporal conv runs over the last cached
                              frame followed by the current frames.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
            # Emits 2*dim channels that forward() interleaves into 2x frames.
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    @staticmethod
    def _is_rep(value):
        """True when a cache slot holds the ``'Rep'`` marker rather than a tensor.

        Checked with ``isinstance`` so tensors are never compared against a
        string (which would rely on torch's fallback for ``tensor == str``).
        """
        return isinstance(value, str) and value == 'Rep'

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Resample ``x`` (``[b, c, t, h, w]``, per the unpacking below).

        Args:
            x: input features.
            feat_cache: optional list of per-conv cache slots; enables the
                temporal convs for the ``*3d`` modes.
            feat_idx: one-element list holding the next cache slot index;
                advanced in place when this layer consumes a slot.
        """
        if feat_idx is None:
            # Fresh counter per call; a shared mutable default would leak state.
            feat_idx = [0]
        b, c, t, h, w = x.size()

        if self.mode == 'upsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            cached = feat_cache[idx]
            if cached is None:
                # First chunk: skip temporal upsampling and just mark the slot.
                feat_cache[idx] = 'Rep'
                feat_idx[0] += 1
            else:
                is_rep = self._is_rep(cached)
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2:
                    if is_rep:
                        # Previous chunk left no real frames: pad with zeros.
                        prev = torch.zeros_like(cache_x)
                    else:
                        # cache last frame of last two chunk
                        prev = cached[:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device)
                    cache_x = torch.cat([prev, cache_x], dim=2)
                if is_rep:
                    x = self.time_conv(x)
                else:
                    x = self.time_conv(x, cached)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1

                # Split the 2*c channels into two halves and interleave them in time.
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                3)
                x = x.reshape(b, c, t * 2, h, w)

        # Spatial resampling is applied frame-by-frame.
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            cached = feat_cache[idx]
            if cached is None:
                # First chunk: remember it, no temporal downsampling yet.
                feat_cache[idx] = x.clone()
                feat_idx[0] += 1
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                prev = cached[:, :, -1:, :, :].to(x.device)
                x = self.time_conv(torch.cat([prev, x], 2))
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        """Zero ``conv`` (5-D weight, with bias) except an identity on temporal tap 1.

        Sets ``weight[:, :, 1, 0, 0] = eye(out, in)`` and zeros the bias.
        """
        weight = conv.weight
        c1, c2, _, _, _ = weight.size()
        with torch.no_grad():
            weight.zero_()
            weight[:, :, 1, 0, 0] = torch.eye(c1, c2)
            conv.bias.zero_()

    def init_weight2(self, conv):
        """Zero ``conv`` except identities on the last temporal tap for both output halves.

        Both ``weight[:c1//2, :, -1, 0, 0]`` and ``weight[c1//2:, :, -1, 0, 0]``
        are set to ``eye(c1 // 2, c2)``; the bias is zeroed. Assumes an even
        output channel count ``c1``.
        """
        weight = conv.weight
        c1, c2, _, _, _ = weight.size()
        half = c1 // 2
        init_matrix = torch.eye(half, c2)
        with torch.no_grad():
            weight.zero_()
            weight[:half, :, -1, 0, 0] = init_matrix
            weight[half:, :, -1, 0, 0] = init_matrix
            conv.bias.zero_()
|
||
|
||
|
||
class ResidualBlock(nn.Module):
    """Pre-norm residual block of two causal 3x3x3 convolutions.

    ``residual``: RMS_norm -> SiLU -> CausalConv3d -> RMS_norm -> SiLU ->
    Dropout -> CausalConv3d. ``shortcut`` is a 1x1x1 CausalConv3d when the
    channel count changes, otherwise identity. The shortcut is called without
    a cache; with a kernel of 1 it has no temporal padding to fill.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the block.

        Args:
            x: input features (5-D, time on dim 2).
            feat_cache: optional list of per-conv cache slots. Each causal conv
                in ``residual`` reads the previous chunk's frames from its slot
                and stores its own last ``CACHE_T`` input frames back.
            feat_idx: one-element list holding the next cache slot index;
                advanced in place once per cached conv.
        """
        if feat_idx is None:
            # Fresh counter per call; a shared mutable default would leak state.
            feat_idx = [0]
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
|
||
|
||
|
||
class AttentionBlock(nn.Module):
    """
    Single-head self-attention applied to each frame independently.

    Time is folded into the batch dimension, so attention never mixes frames.
    q/k/v come from one 1x1 conv with ``3 * dim`` outputs. The attention kernel
    is ``vae_attention()``, which presumably attends over the h*w positions of
    each image; confirm against its implementation. The block adds a residual
    connection around the whole computation.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
        self.proj = ops.Conv2d(dim, dim, 1)
        self.optimized_attention = vae_attention()

    def forward(self, x):
        residual = x
        _, _, frames, _, _ = x.size()
        y = rearrange(x, 'b c t h w -> (b t) c h w')
        y = self.norm(y)
        # compute query, key, value
        qkv = self.to_qkv(y)
        q, k, v = qkv.chunk(3, dim=1)
        y = self.optimized_attention(q, k, v)
        # output
        y = self.proj(y)
        y = rearrange(y, '(b t) c h w-> b c t h w', t=frames)
        return y + residual
|
||
|
||
|
||
class Encoder3d(nn.Module):
    """Causal 3D convolutional encoder.

    Channel widths are ``dim * m`` for ``m`` in ``[1] + dim_mult``. Each stage
    has ``num_res_blocks`` residual blocks, plus an attention block after each
    when the current ``scale`` is in ``attn_scales``. Every stage except the
    last ends with a Resample that downsamples 2x spatially, and also
    temporally when ``temperal_downsample[i]`` is true. ``scale`` starts at
    1.0 and halves after each Resample. The head maps to ``z_dim`` channels.

    When ``feat_cache`` is given, every causal conv reads previous-chunk
    context from ``feat_cache[feat_idx[0]]`` and writes its own context back,
    so a long video can be processed chunk by chunk.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run causal ``conv`` on ``x`` using and then updating cache slot ``feat_idx[0]``."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                                dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode ``x`` (3 input channels, per ``conv1``; time on dim 2).

        Args:
            x: input video chunk.
            feat_cache: optional list of per-conv cache slots (``None`` disables caching).
            feat_idx: one-element list holding the next cache slot index; advanced in place.
        """
        if feat_idx is None:
            # Fresh counter per call; a shared mutable default would leak state.
            feat_idx = [0]
        use_cache = feat_cache is not None

        ## conv1
        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            # AttentionBlock.forward accepts no cache arguments, so only layers
            # that take them receive them (same rule as the middle blocks).
            if use_cache and isinstance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if use_cache and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and isinstance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
class Decoder3d(nn.Module):
    """Causal 3D convolutional decoder (latent -> 3-channel video).

    Channel widths are ``dim * m`` for ``m`` in
    ``[dim_mult[-1]] + reversed(dim_mult)``. After ``conv1`` and the middle
    blocks, each stage has ``num_res_blocks + 1`` residual blocks, plus an
    attention block after each when ``scale`` is in ``attn_scales``. Every
    stage except the last ends with a Resample that upsamples 2x spatially,
    and also temporally when ``temperal_upsample[i]`` is true. Because that
    Resample halves the channel count, every stage after the first starts
    from ``in_dim // 2``.

    When ``feat_cache`` is given, every causal conv reads previous-chunk
    context from ``feat_cache[feat_idx[0]]`` and writes its own context back.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # Every stage after the first follows a Resample whose conv outputs
            # dim // 2 channels (equivalent to the old i in (1, 2, 3) check for
            # the default four-stage config, and correct for any depth).
            if i > 0:
                in_dim = in_dim // 2
            # residual (+attention) blocks
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run causal ``conv`` on ``x`` using and then updating cache slot ``feat_idx[0]``."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                                dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Decode ``x`` (``z_dim`` channels, per ``conv1``; time on dim 2).

        Args:
            x: latent chunk.
            feat_cache: optional list of per-conv cache slots (``None`` disables caching).
            feat_idx: one-element list holding the next cache slot index; advanced in place.
        """
        if feat_idx is None:
            # Fresh counter per call; a shared mutable default would leak state.
            feat_idx = [0]
        use_cache = feat_cache is not None

        ## conv1
        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if use_cache and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            # AttentionBlock.forward accepts no cache arguments, so only layers
            # that take them receive them (same rule as the middle blocks).
            if use_cache and isinstance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and isinstance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
def count_conv3d(model):
    """Return the number of ``CausalConv3d`` modules in ``model``.

    The search is recursive via ``model.modules()``, which includes ``model``
    itself. ``WanVAE.clear_cache`` uses this count to size its feature caches.
    """
    return sum(1 for module in model.modules() if isinstance(module, CausalConv3d))
|
||
|
||
|
||
class WanVAE(nn.Module):
    """Causal 3D video VAE.

    The encoder produces ``2 * z_dim`` channels, which ``conv1`` maps to the
    concatenated ``(mu, log_var)``. ``conv2`` pre-processes latents before the
    decoder. Encoding and decoding run chunk by chunk along time, carrying
    causal-conv context in per-conv feature caches.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        """Encode, sample via the reparameterisation trick, decode.

        Returns:
            ``(x_recon, mu, log_var)``.
        """
        mu, log_var = self._encode_moments(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def _encode_moments(self, x):
        """Encode ``x`` (time on dim 2) and return ``(mu, log_var)``.

        Raises:
            ValueError: if ``x`` has no frames along dim 2.
        """
        t = x.shape[2]
        if t < 1:
            raise ValueError("WanVAE.encode expects at least one frame along dim 2")
        self.clear_cache()
        ## cache
        # Split the encoder input x along time into chunks of 1, 4, 4, 4, ...
        # frames. Frames past the last complete group of four are not encoded.
        iter_ = 1 + (t - 1) // 4
        chunks = []
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                x_chunk = x[:, :, :1, :, :]
            else:
                x_chunk = x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            chunks.append(self.encoder(
                x_chunk,
                feat_cache=self._enc_feat_map,
                feat_idx=self._enc_conv_idx))
        # A single concatenation avoids re-copying the growing output every chunk.
        out = torch.cat(chunks, 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        self.clear_cache()
        return mu, log_var

    def encode(self, x):
        """Encode ``x`` and return only the posterior mean ``mu``."""
        mu, _ = self._encode_moments(x)
        return mu

    def decode(self, z):
        """Decode latent ``z`` (``[b, c, t, h, w]``) one latent frame at a time.

        Raises:
            ValueError: if ``z`` has no frames along dim 2.
        """
        # z: [b,c,t,h,w]
        iter_ = z.shape[2]
        if iter_ < 1:
            raise ValueError("WanVAE.decode expects at least one latent frame along dim 2")
        self.clear_cache()
        x = self.conv2(z)
        chunks = []
        for i in range(iter_):
            self._conv_idx = [0]
            chunks.append(self.decoder(
                x[:, :, i:i + 1, :, :],
                feat_cache=self._feat_map,
                feat_idx=self._conv_idx))
        out = torch.cat(chunks, 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        """Encode ``imgs`` and return ``mu`` (deterministic) or a posterior sample.

        ``log_var`` is clamped to ``[-30, 20]`` before computing the std.
        """
        mu, log_var = self._encode_moments(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Reset decoder and encoder feature caches to one empty slot per CausalConv3d."""
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        #cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
|