Support loading the Stable Cascade effnet and previewer as a VAE.

The effnet can be used to encode images for img2img with Stage C.
This commit is contained in:
comfyanonymous 2024-02-19 04:06:49 -05:00
parent 2e4628ac8d
commit 3b2e579926
2 changed files with 42 additions and 20 deletions

View File

@ -4,6 +4,7 @@ from enum import Enum
from comfy import model_management from comfy import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
import yaml import yaml
@ -158,6 +159,7 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8 self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4 self.latent_channels = 4
self.process_input = lambda image: image * 2.0 - 1.0 self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@ -176,11 +178,31 @@ class VAE:
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA() self.first_stage_model = StageA()
self.downscale_ratio = 4 self.downscale_ratio = 4
self.upscale_ratio = 4
#TODO #TODO
#self.memory_used_encode #self.memory_used_encode
#self.memory_used_decode #self.memory_used_decode
self.process_input = lambda image: image self.process_input = lambda image: image
self.process_output = lambda image: image self.process_output = lambda image: image
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["encoder.{}".format(k)] = sd[k]
sd = new_sd
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["previewer.{}".format(k)] = sd[k]
sd = new_sd
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
else: else:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@ -188,6 +210,7 @@ class VAE:
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4] ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4 self.downscale_ratio = 4
self.upscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
@ -213,6 +236,15 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def vae_encode_crop_pixels(self, pixels):
x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -221,9 +253,9 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 3.0)
return output return output
@ -248,7 +280,7 @@ class VAE:
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
@ -265,6 +297,7 @@ class VAE:
return output.movedim(1,-1) return output.movedim(1,-1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@ -284,6 +317,7 @@ class VAE:
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)

View File

@ -309,18 +309,7 @@ class VAEEncode:
CATEGORY = "latent" CATEGORY = "latent"
@staticmethod
def vae_encode_crop_pixels(pixels):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def encode(self, vae, pixels): def encode(self, vae, pixels):
pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3]) t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, ) return ({"samples":t}, )
@ -336,7 +325,6 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size): def encode(self, vae, pixels, tile_size):
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, ) return ({"samples":t}, )
@ -350,14 +338,14 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 8) * 8 x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
y = (pixels.shape[2] // 8) * 8 y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone() pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2 x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
y_offset = (pixels.shape[2] % 8) // 2 y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]