From a7b5eaa7e3d5ecc1ae1d395e50d781124bb4e611 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Feb 2024 04:25:46 -0500
Subject: [PATCH] Forgot to commit this.

---
 comfy/ldm/cascade/stage_c_coder.py | 96 ++++++++++++++++++++++++++++++
 1 file changed, 96 insertions(+)
 create mode 100644 comfy/ldm/cascade/stage_c_coder.py

diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py
new file mode 100644
index 00000000..98c9a0b6
--- /dev/null
+++ b/comfy/ldm/cascade/stage_c_coder.py
@@ -0,0 +1,96 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Stability AI
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+import torch
+import torchvision
+from torch import nn
+
+
+# EfficientNet encoder: compresses images into Stage C latents
+class EfficientNetEncoder(nn.Module):
+    def __init__(self, c_latent=16):
+        super().__init__()
+        self.backbone = torchvision.models.efficientnet_v2_s().features.eval()  # EfficientNetV2-S feature extractor
+        self.mapper = nn.Sequential(
+            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),  # project the 1280 backbone features to c_latent channels
+            nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
+        )
+        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
+        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
+
+    def forward(self, x):
+        x = x * 0.5 + 0.5  # [-1, 1] -> [0, 1]
+        x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])  # ImageNet normalization
+        o = self.mapper(self.backbone(x))
+        return o
+
+
+# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
+class Previewer(nn.Module):
+    def __init__(self, c_in=16, c_hidden=512, c_out=3):
+        super().__init__()
+        self.blocks = nn.Sequential(
+            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden),
+
+            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden),
+
+            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden // 2),
+
+            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden // 2),
+
+            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden // 4),
+
+            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden // 4),
+
+            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden // 4),
+
+            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+            nn.GELU(),
+            nn.BatchNorm2d(c_hidden // 4),
+
+            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return (self.blocks(x) - 0.5) * 2.0  # [0, 1] -> [-1, 1]
+
+
+class StageC_coder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.previewer = Previewer()
+        self.encoder = EfficientNetEncoder()
+
+    def encode(self, x):
+        return self.encoder(x)
+
+    def decode(self, x):
+        return self.previewer(x)
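
Usage sketch (illustrative only, not part of the patch): assuming the ComfyUI root is on the
Python path and images are 3-channel tensors normalized to [-1, 1], the coder maps a 768x768
image to a 16x24x24 Stage C latent (the 32x downsample of the EfficientNetV2-S backbone), and
the Previewer decodes that latent back to a rough 192x192 RGB preview, matching the
"16 x 24 x 24 -> 3 x 192 x 192" comment in the file. The snippet runs with the random
initialization shown here; meaningful latents and previews require the pretrained Stable
Cascade weights to be loaded into the module first.

    import torch
    from comfy.ldm.cascade.stage_c_coder import StageC_coder

    coder = StageC_coder().eval()
    image = torch.rand(1, 3, 768, 768) * 2.0 - 1.0    # dummy image in [-1, 1]
    with torch.no_grad():
        latent = coder.encode(image)    # (1, 16, 24, 24) Stage C latent
        preview = coder.decode(latent)  # (1, 3, 192, 192) preview, values in [-1, 1]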