mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Mochi VAE encoder.
This commit is contained in:
parent
cc9cf6d1bd
commit
fabf449feb
@ -2,12 +2,16 @@
|
|||||||
#adapted to ComfyUI
|
#adapted to ComfyUI
|
||||||
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
from functools import partial
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
@ -158,8 +162,10 @@ class ResBlock(nn.Module):
|
|||||||
*,
|
*,
|
||||||
affine: bool = True,
|
affine: bool = True,
|
||||||
attn_block: Optional[nn.Module] = None,
|
attn_block: Optional[nn.Module] = None,
|
||||||
padding_mode: str = "replicate",
|
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
|
prune_bottleneck: bool = False,
|
||||||
|
padding_mode: str,
|
||||||
|
bias: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
@ -170,23 +176,23 @@ class ResBlock(nn.Module):
|
|||||||
nn.SiLU(inplace=True),
|
nn.SiLU(inplace=True),
|
||||||
PConv3d(
|
PConv3d(
|
||||||
in_channels=channels,
|
in_channels=channels,
|
||||||
out_channels=channels,
|
out_channels=channels // 2 if prune_bottleneck else channels,
|
||||||
kernel_size=(3, 3, 3),
|
kernel_size=(3, 3, 3),
|
||||||
stride=(1, 1, 1),
|
stride=(1, 1, 1),
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
bias=True,
|
bias=bias,
|
||||||
# causal=causal,
|
causal=causal,
|
||||||
),
|
),
|
||||||
norm_fn(channels, affine=affine),
|
norm_fn(channels, affine=affine),
|
||||||
nn.SiLU(inplace=True),
|
nn.SiLU(inplace=True),
|
||||||
PConv3d(
|
PConv3d(
|
||||||
in_channels=channels,
|
in_channels=channels // 2 if prune_bottleneck else channels,
|
||||||
out_channels=channels,
|
out_channels=channels,
|
||||||
kernel_size=(3, 3, 3),
|
kernel_size=(3, 3, 3),
|
||||||
stride=(1, 1, 1),
|
stride=(1, 1, 1),
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
bias=True,
|
bias=bias,
|
||||||
# causal=causal,
|
causal=causal,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -206,6 +212,81 @@ class ResBlock(nn.Module):
|
|||||||
return self.attn_block(x)
|
return self.attn_block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Multi-head self-attention computed along the time axis of a video tensor.

    Every spatial location is treated as an independent sequence of frames.
    Queries and keys can optionally be L2-normalized per head before attention.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        out_bias: bool = True,
        qk_norm: bool = True,
    ) -> None:
        """Create the fused QKV projection and the output projection.

        Args:
            dim: Channel count of the input and output.
            head_dim: Channels per attention head; ``dim // head_dim`` heads are used.
            qkv_bias: Whether the fused QKV projection has a bias.
            out_bias: Whether the output projection has a bias.
            qk_norm: Whether to L2-normalize queries and keys along the head dimension.
        """
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.qk_norm = qk_norm

        # A single linear layer yields queries, keys and values side by side.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.out = nn.Linear(dim, dim, bias=out_bias)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute temporal self-attention.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Output tensor. Shape: [B, C, T, H, W].
        """
        B, _, T, H, W = x.shape

        if T == 1:
            # A single frame has nothing to attend over: only the value slice of
            # the fused projection is kept, queries and keys are discarded.
            channels_last = x.movedim(1, -1)  # [B, C, T, H, W] -> [B, T, H, W, C]
            value = self.qkv(channels_last).chunk(3, dim=-1)[2]
            return self.out(value).movedim(-1, 1)  # [B, T, H, W, C] -> [B, C, T, H, W]

        # Fold the spatial grid into the batch so attention runs over time only.
        tokens = rearrange(x, "B C t h w -> (B h w) t C")
        fused = self.qkv(tokens)

        # fused:   [B', t, 3 * num_heads * head_dim]
        # q, k, v: [B', num_heads, t, head_dim] each
        num_seqs, num_frames = fused.shape[0], fused.shape[1]
        per_head = fused.view(num_seqs, num_frames, 3, self.num_heads, self.head_dim)
        q, k, v = per_head.transpose(1, 3).unbind(2)

        if self.qk_norm:
            q = F.normalize(q, p=2, dim=-1)
            k = F.normalize(k, p=2, dim=-1)

        attended = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        assert attended.size(0) == q.size(0)

        projected = self.out(attended)
        return rearrange(projected, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
    """Residual wrapper that normalizes its input, applies Attention and adds the result back."""

    def __init__(self, dim: int, **attn_kwargs) -> None:
        """Build the norm and attention sub-modules.

        Args:
            dim: Channel count passed to both ``norm_fn`` and ``Attention``.
            **attn_kwargs: Extra keyword arguments forwarded to ``Attention``.
        """
        super().__init__()
        self.norm = norm_fn(dim)
        self.attn = Attention(dim, **attn_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the attention output of the normalized ``x``."""
        normed = self.norm(x)
        return x + self.attn(normed)
|
||||||
|
|
||||||
|
|
||||||
class CausalUpsampleBlock(nn.Module):
|
class CausalUpsampleBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
    """Build a ResBlock, attaching an AttentionBlock when requested.

    Args:
        channels: Channel count for the ResBlock (and the AttentionBlock, if any).
        affine: Forwarded to ResBlock as ``affine``.
        has_attention: When true, an ``AttentionBlock(channels)`` is passed as ``attn_block``.
        **block_kwargs: Remaining keyword arguments forwarded to ResBlock.

    Returns:
        The constructed ResBlock.
    """
    if has_attention:
        attn_block = AttentionBlock(channels)
    else:
        attn_block = None
    return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class DownsampleBlock(nn.Module):
|
class DownsampleBlock(nn.Module):
|
||||||
@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
|
|||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||||
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||||
|
# First layer in each block always uses replicate padding
|
||||||
padding_mode="replicate",
|
padding_mode="replicate",
|
||||||
bias=True,
|
bias=block_kwargs["bias"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -382,7 +459,7 @@ class Decoder(nn.Module):
|
|||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
first_block = [
|
first_block = [
|
||||||
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||||
] # Input layer.
|
] # Input layer.
|
||||||
# First set of blocks preserve channel count.
|
# First set of blocks preserve channel count.
|
||||||
for _ in range(num_res_blocks[-1]):
|
for _ in range(num_res_blocks[-1]):
|
||||||
@ -452,11 +529,165 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
return self.output_proj(x).contiguous()
|
return self.output_proj(x).contiguous()
|
||||||
|
|
||||||
|
class LatentDistribution:
    """Diagonal Gaussian over latents, parameterized by its mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
        """
        assert mean.shape == logvar.shape
        self.mean = mean
        self.logvar = logvar

    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
        """Draw a sample from the distribution.

        Args:
            temperature: Only 0.0 (return the mean) and 1.0 (plain Gaussian sample) are supported.
            generator: Optional RNG used when ``noise`` is not supplied.
            noise: Optional pre-drawn standard-normal noise on the same device as the mean;
                it is cast to the mean's dtype.

        Returns:
            The mean itself when ``temperature`` is 0.0, otherwise ``noise * std + mean``.

        Raises:
            NotImplementedError: If ``temperature`` is neither 0.0 nor 1.0.
        """
        if temperature == 0.0:
            return self.mean

        if noise is not None:
            assert noise.device == self.mean.device
            noise = noise.to(self.mean.dtype)
        else:
            noise = torch.randn(
                self.mean.shape,
                device=self.mean.device,
                dtype=self.mean.dtype,
                generator=generator,
            )

        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        # Standard deviation is exp(logvar / 2); the variance is not rescaled.
        std = torch.exp(self.logvar * 0.5)
        return noise * std + self.mean

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
    """Video encoder producing the mean and log-variance of a latent distribution.

    Pipeline: Fourier features -> input projection -> residual blocks ->
    one DownsampleBlock per resolution change -> residual blocks ->
    norm / SiLU / 1x1 projection to ``2 * latent_dim`` channels.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        latent_dim: int,
        temporal_reductions: List[int],
        spatial_reductions: List[int],
        prune_bottlenecks: List[bool],
        has_attentions: List[bool],
        affine: bool = True,
        bias: bool = True,
        input_is_conv_1x1: bool = False,
        padding_mode: str,
    ):
        """Assemble the encoder stack.

        With ``n = len(channel_multipliers) - 1`` downsample stages, ``num_res_blocks``,
        ``prune_bottlenecks`` and ``has_attentions`` must each have ``n + 2`` entries
        (input stage, one per downsample stage, final stage), and both reduction
        lists must have ``n`` entries.

        Args:
            in_channels: Channel count fed to the input projection.
            base_channels: Width that each channel multiplier scales.
            channel_multipliers: Per-stage width multipliers.
            num_res_blocks: Number of residual blocks per stage.
            latent_dim: Latent channels; the output has twice as many (mean and logvar).
            temporal_reductions: Temporal stride of each downsample stage.
            spatial_reductions: Spatial stride of each downsample stage.
            prune_bottlenecks: Per-stage ``prune_bottleneck`` flag for the residual blocks.
            has_attentions: Per-stage flag enabling attention in the residual blocks.
            affine: Forwarded to the blocks as ``affine``.
            bias: Forwarded to the blocks as ``bias``.
            input_is_conv_1x1: Use ``Conv1x1`` instead of a 1x1x1 ``Conv3d`` for the input layer.
            padding_mode: Padding mode forwarded to the blocks.
        """
        super().__init__()
        self.temporal_reductions = temporal_reductions
        self.spatial_reductions = spatial_reductions
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.latent_dim = latent_dim

        self.fourier_features = FourierFeatures()
        widths = [mult * base_channels for mult in channel_multipliers]
        num_down_blocks = len(widths) - 1
        assert len(num_res_blocks) == num_down_blocks + 2

        # Input projection.
        if input_is_conv_1x1:
            stem = Conv1x1(in_channels, widths[0])
        else:
            stem = ops.Conv3d(in_channels, widths[0], kernel_size=(1, 1, 1), bias=True)
        layers = [stem]

        assert len(prune_bottlenecks) == num_down_blocks + 2
        assert len(has_attentions) == num_down_blocks + 2
        make_block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)

        # Residual blocks at the input resolution.
        layers.extend(
            make_block(widths[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0])
            for _ in range(num_res_blocks[0])
        )

        # The first entry of each flag list was consumed by the input stage above.
        later_prune = prune_bottlenecks[1:]
        later_attn = has_attentions[1:]

        assert len(temporal_reductions) == len(spatial_reductions) == len(widths) - 1
        for stage, (width_in, width_out) in enumerate(zip(widths, widths[1:])):
            down = DownsampleBlock(
                width_in,
                width_out,
                num_res_blocks=num_res_blocks[stage + 1],
                temporal_reduction=temporal_reductions[stage],
                spatial_reduction=spatial_reductions[stage],
                prune_bottleneck=later_prune[stage],
                has_attention=later_attn[stage],
                affine=affine,
                bias=bias,
                padding_mode=padding_mode,
            )
            layers.append(down)

        # Additional blocks at the final resolution.
        layers.extend(
            make_block(widths[-1], has_attention=later_attn[-1], prune_bottleneck=later_prune[-1])
            for _ in range(num_res_blocks[-1])
        )

        self.layers = nn.Sequential(*layers)

        # Output layers.
        self.output_norm = norm_fn(widths[-1])
        self.output_proj = Conv1x1(widths[-1], 2 * latent_dim, bias=False)

    @property
    def temporal_downsample(self):
        """Product of all temporal reduction factors."""
        return math.prod(self.temporal_reductions)

    @property
    def spatial_downsample(self):
        """Product of all spatial reduction factors."""
        return math.prod(self.spatial_reductions)

    def forward(self, x) -> LatentDistribution:
        """Forward pass.

        Args:
            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]

        Returns:
            A LatentDistribution holding:
            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
                   (for the reductions configured by VideoVAE in this file).
            logvar: Shape: [B, latent_dim, t, h, w].
        """
        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"

        features = self.layers(self.fourier_features(x))
        features = F.silu(self.output_norm(features), inplace=True)
        moments = self.output_proj(features)

        # First half of the channels is the mean, second half the log-variance.
        means, logvar = torch.chunk(moments, 2, dim=1)

        assert means.ndim == 5
        assert logvar.shape == means.shape
        assert means.size(1) == self.latent_dim

        return LatentDistribution(means, logvar)
|
||||||
|
|
||||||
|
|
||||||
class VideoVAE(nn.Module):
|
class VideoVAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = None #TODO once the model releases
|
self.encoder = Encoder(
|
||||||
|
in_channels=15,
|
||||||
|
base_channels=64,
|
||||||
|
channel_multipliers=[1, 2, 4, 6],
|
||||||
|
num_res_blocks=[3, 3, 4, 6, 3],
|
||||||
|
latent_dim=12,
|
||||||
|
temporal_reductions=[1, 2, 3],
|
||||||
|
spatial_reductions=[2, 2, 2],
|
||||||
|
prune_bottlenecks=[False, False, False, False, False],
|
||||||
|
has_attentions=[False, True, True, True, True],
|
||||||
|
affine=True,
|
||||||
|
bias=True,
|
||||||
|
input_is_conv_1x1=True,
|
||||||
|
padding_mode="replicate"
|
||||||
|
)
|
||||||
self.decoder = Decoder(
|
self.decoder = Decoder(
|
||||||
out_channels=3,
|
out_channels=3,
|
||||||
base_channels=128,
|
base_channels=128,
|
||||||
@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
return self.encoder(x)
|
return self.encoder(x).mode()
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x):
|
||||||
return self.decoder(x)
|
return self.decoder(x)
|
||||||
|
@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
# Largest batch handed to scaled_dot_product_attention in one call; attention_pytorch
# splits bigger batches into chunks of this size.
# pytorch 2.3 and up seem to have this issue on NVIDIA, hence the lower limit there.
# TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**15 if model_management.is_nvidia() else 2**31
|
||||||
|
|
||||||
|
|
||||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SDP_BATCH_LIMIT >= q.shape[0]:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||||
|
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
|
||||||
|
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
22
comfy/sd.py
22
comfy/sd.py
@ -171,6 +171,7 @@ class VAE:
|
|||||||
self.downscale_ratio = 8
|
self.downscale_ratio = 8
|
||||||
self.upscale_ratio = 8
|
self.upscale_ratio = 8
|
||||||
self.latent_channels = 4
|
self.latent_channels = 4
|
||||||
|
self.latent_dim = 2
|
||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
@ -240,16 +241,22 @@ class VAE:
|
|||||||
self.output_channels = 2
|
self.output_channels = 2
|
||||||
self.upscale_ratio = 2048
|
self.upscale_ratio = 2048
|
||||||
self.downscale_ratio = 2048
|
self.downscale_ratio = 2048
|
||||||
|
self.latent_dim = 1
|
||||||
self.process_output = lambda audio: audio
|
self.process_output = lambda audio: audio
|
||||||
self.process_input = lambda audio: audio
|
self.process_input = lambda audio: audio
|
||||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
|
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
|
||||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||||
|
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
|
||||||
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
|
||||||
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
|
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
|
||||||
self.latent_channels = 12
|
self.latent_channels = 12
|
||||||
|
self.latent_dim = 3
|
||||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||||
|
self.working_dtypes = [torch.float16, torch.float32]
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
@ -361,17 +368,22 @@ class VAE:
|
|||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
|
if self.latent_dim == 3:
|
||||||
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / max(1, memory_used))
|
batch_number = int(free_memory / max(1, memory_used))
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
samples = None
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||||
|
if samples is None:
|
||||||
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
|
Loading…
Reference in New Issue
Block a user