mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-05-10 14:36:07 +00:00
Fix stable cascade VAE on some lowvram machines.
This commit is contained in:
parent
29832b3b61
commit
0952569493
@ -19,6 +19,10 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
class vector_quantize(Function):
|
class vector_quantize(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -121,15 +125,15 @@ class ResBlock(nn.Module):
|
|||||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||||
self.depthwise = nn.Sequential(
|
self.depthwise = nn.Sequential(
|
||||||
nn.ReplicationPad2d(1),
|
nn.ReplicationPad2d(1),
|
||||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
ops.Conv2d(c, c, kernel_size=3, groups=c)
|
||||||
)
|
)
|
||||||
|
|
||||||
# channelwise
|
# channelwise
|
||||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||||
self.channelwise = nn.Sequential(
|
self.channelwise = nn.Sequential(
|
||||||
nn.Linear(c, c_hidden),
|
ops.Linear(c, c_hidden),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(c_hidden, c),
|
ops.Linear(c_hidden, c),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||||
@ -171,16 +175,16 @@ class StageA(nn.Module):
|
|||||||
# Encoder blocks
|
# Encoder blocks
|
||||||
self.in_block = nn.Sequential(
|
self.in_block = nn.Sequential(
|
||||||
nn.PixelUnshuffle(2),
|
nn.PixelUnshuffle(2),
|
||||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||||
)
|
)
|
||||||
down_blocks = []
|
down_blocks = []
|
||||||
for i in range(levels):
|
for i in range(levels):
|
||||||
if i > 0:
|
if i > 0:
|
||||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||||
down_blocks.append(block)
|
down_blocks.append(block)
|
||||||
down_blocks.append(nn.Sequential(
|
down_blocks.append(nn.Sequential(
|
||||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||||
))
|
))
|
||||||
self.down_blocks = nn.Sequential(*down_blocks)
|
self.down_blocks = nn.Sequential(*down_blocks)
|
||||||
@ -191,7 +195,7 @@ class StageA(nn.Module):
|
|||||||
|
|
||||||
# Decoder blocks
|
# Decoder blocks
|
||||||
up_blocks = [nn.Sequential(
|
up_blocks = [nn.Sequential(
|
||||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||||
)]
|
)]
|
||||||
for i in range(levels):
|
for i in range(levels):
|
||||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||||
@ -199,11 +203,11 @@ class StageA(nn.Module):
|
|||||||
up_blocks.append(block)
|
up_blocks.append(block)
|
||||||
if i < levels - 1:
|
if i < levels - 1:
|
||||||
up_blocks.append(
|
up_blocks.append(
|
||||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||||
padding=1))
|
padding=1))
|
||||||
self.up_blocks = nn.Sequential(*up_blocks)
|
self.up_blocks = nn.Sequential(*up_blocks)
|
||||||
self.out_block = nn.Sequential(
|
self.out_block = nn.Sequential(
|
||||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||||
nn.PixelShuffle(2),
|
nn.PixelShuffle(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -232,17 +236,17 @@ class Discriminator(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
d = max(depth - 3, 3)
|
d = max(depth - 3, 3)
|
||||||
layers = [
|
layers = [
|
||||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
]
|
]
|
||||||
for i in range(depth - 1):
|
for i in range(depth - 1):
|
||||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||||
layers.append(nn.InstanceNorm2d(c_out))
|
layers.append(nn.InstanceNorm2d(c_out))
|
||||||
layers.append(nn.LeakyReLU(0.2))
|
layers.append(nn.LeakyReLU(0.2))
|
||||||
self.encoder = nn.Sequential(*layers)
|
self.encoder = nn.Sequential(*layers)
|
||||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||||
self.logits = nn.Sigmoid()
|
self.logits = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, x, cond=None):
|
def forward(self, x, cond=None):
|
||||||
|
@ -19,6 +19,9 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
# EfficientNet
|
# EfficientNet
|
||||||
class EfficientNetEncoder(nn.Module):
|
class EfficientNetEncoder(nn.Module):
|
||||||
@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||||
self.mapper = nn.Sequential(
|
self.mapper = nn.Sequential(
|
||||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||||
)
|
)
|
||||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||||
@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x * 0.5 + 0.5
|
x = x * 0.5 + 0.5
|
||||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
|
||||||
o = self.mapper(self.backbone(x))
|
o = self.mapper(self.backbone(x))
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
|||||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.blocks = nn.Sequential(
|
self.blocks = nn.Sequential(
|
||||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden),
|
nn.BatchNorm2d(c_hidden),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden),
|
nn.BatchNorm2d(c_hidden),
|
||||||
|
|
||||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 2),
|
nn.BatchNorm2d(c_hidden // 2),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 2),
|
nn.BatchNorm2d(c_hidden // 2),
|
||||||
|
|
||||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -581,7 +581,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
loaded_memory = loaded_model.model_loaded_memory()
|
loaded_memory = loaded_model.model_loaded_memory()
|
||||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||||
|
|
||||||
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
|
Loading…
Reference in New Issue
Block a user