Support loading the Stable Cascade effnet and previewer as a VAE.
The effnet can be used to encode images for img2img with Stage C.
parent 2e4628ac8d
commit 3b2e579926
comfy/sd.py: 42 lines changed
@@ -4,6 +4,7 @@ from enum import Enum
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder
 
 import yaml
 
@@ -158,6 +159,7 @@ class VAE:
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
+        self.upscale_ratio = 8
         self.latent_channels = 4
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
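This hunk introduces upscale_ratio alongside downscale_ratio, so the decode-side output scaling can differ from the encode-side stride (which the Stage C previewer needs). As a sanity check on the VRAM heuristics above (illustrative arithmetic only, not part of the commit), for a 512x512 RGB input in fp16 (2 bytes per element) the estimates come out to roughly a gigabyte each way:

pixel_shape = (1, 3, 512, 512)    # NCHW image, as seen by encode
latent_shape = (1, 4, 64, 64)     # 512 / downscale_ratio(8) = 64, as seen by decode

encode_bytes = 1767 * pixel_shape[2] * pixel_shape[3] * 2         # ~0.93 GB
decode_bytes = 2178 * latent_shape[2] * latent_shape[3] * 64 * 2  # ~1.14 GB
print(encode_bytes / 1e9, decode_bytes / 1e9)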
@@ -176,11 +178,31 @@ class VAE:
             elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                 self.first_stage_model = StageA()
                 self.downscale_ratio = 4
+                self.upscale_ratio = 4
                 #TODO
                 #self.memory_used_encode
                 #self.memory_used_decode
                 self.process_input = lambda image: image
                 self.process_output = lambda image: image
+            elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+                self.first_stage_model = StageC_coder()
+                self.downscale_ratio = 32
+                self.latent_channels = 16
+                new_sd = {}
+                for k in sd:
+                    new_sd["encoder.{}".format(k)] = sd[k]
+                sd = new_sd
+            elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+                self.first_stage_model = StageC_coder()
+                self.latent_channels = 16
+                new_sd = {}
+                for k in sd:
+                    new_sd["previewer.{}".format(k)] = sd[k]
+                sd = new_sd
+            elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+                self.first_stage_model = StageC_coder()
+                self.downscale_ratio = 32
+                self.latent_channels = 16
             else:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
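The three new branches identify which Stable Cascade file was handed in purely by sniffing characteristic state-dict keys, then re-prefix the flat keys so they load into the matching submodule of StageC_coder (the combined file already carries the prefixes). A minimal sketch of that pattern, using the same marker keys as the diff; the helper name and error handling are my own:

def detect_cascade_component(sd):
    if "backbone.1.0.block.0.1.num_batches_tracked" in sd:
        return "encoder"    # effnet only
    if "blocks.11.num_batches_tracked" in sd:
        return "previewer"  # previewer only
    if "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd:
        return None         # combined file, keys already prefixed
    raise ValueError("not a Stable Cascade effnet/previewer checkpoint")

prefix = detect_cascade_component(sd)
if prefix is not None:
    sd = {"{}.{}".format(prefix, k): v for k, v in sd.items()}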
@@ -188,6 +210,7 @@ class VAE:
                 if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
                     ddconfig['ch_mult'] = [1, 2, 4]
                     self.downscale_ratio = 4
+                    self.upscale_ratio = 4
 
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
@@ -213,6 +236,15 @@ class VAE:
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
 
+    def vae_encode_crop_pixels(self, pixels):
+        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
+        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
+            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        return pixels
+
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
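Moving the crop helper onto the VAE object lets it use the model's real stride instead of the 8 the nodes used to hard-code, which matters for the effnet's ratio of 32. A quick worked example of the center crop, assuming an NHWC image tensor as the nodes pass it:

import torch

pixels = torch.zeros(1, 1000, 1000, 3)      # batch, height, width, channels

ratio = 32                                  # effnet stride; SD VAEs use 8
x = (pixels.shape[1] // ratio) * ratio      # 992
x_offset = (pixels.shape[1] % ratio) // 2   # 4, so rows 4..995 survive the crop
pixels = pixels[:, x_offset:x + x_offset, :, :]
# with ratio = 8 the same image needs no crop at all (1000 is a multiple of 8)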
@@ -221,9 +253,9 @@ class VAE:
 
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         output = self.process_output(
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
             / 3.0)
         return output
 
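The only change in this hunk is swapping downscale_ratio for the new upscale_ratio as the tiled_scale output multiplier; the seam-hiding average itself is untouched. Schematically (decode_tiles is a hypothetical stand-in for the comfy.utils.tiled_scale calls), the same latents are decoded with tall, wide, and square tiles and blended so tile borders fall in different places:

out = (decode_tiles(samples, tile_x // 2, tile_y * 2) +   # tall tiles
       decode_tiles(samples, tile_x * 2, tile_y // 2) +   # wide tiles
       decode_tiles(samples, tile_x, tile_y)) / 3.0       # square tiles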
@@ -248,7 +280,7 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
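decode() now sizes its output buffer from upscale_ratio and decodes as many latents per pass as the VRAM estimate allows. Plugging in the illustrative numbers from earlier (again, not part of the commit):

free_memory = 8e9                        # say the device reports 8 GB free
memory_used = 1.14e9                     # decode estimate for one 64x64 fp16 latent
batch_number = max(1, int(free_memory / memory_used))   # -> 7 latents per pass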
@@ -265,6 +297,7 @@ class VAE:
         return output.movedim(1,-1)
 
     def encode(self, pixel_samples):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -284,6 +317,7 @@ class VAE:
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         model_management.load_model_gpu(self.patcher)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
nodes.py: 20 lines changed
@@ -309,18 +309,7 @@ class VAEEncode:
 
     CATEGORY = "latent"
 
-    @staticmethod
-    def vae_encode_crop_pixels(pixels):
-        x = (pixels.shape[1] // 8) * 8
-        y = (pixels.shape[2] // 8) * 8
-        if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % 8) // 2
-            y_offset = (pixels.shape[2] % 8) // 2
-            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
-        return pixels
-
     def encode(self, vae, pixels):
-        pixels = self.vae_encode_crop_pixels(pixels)
         t = vae.encode(pixels[:,:,:,:3])
         return ({"samples":t}, )
 
@@ -336,7 +325,6 @@ class VAEEncodeTiled:
     CATEGORY = "_for_testing"
 
     def encode(self, vae, pixels, tile_size):
-        pixels = VAEEncode.vae_encode_crop_pixels(pixels)
         t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
         return ({"samples":t}, )
 
@@ -350,14 +338,14 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"
 
     def encode(self, vae, pixels, mask, grow_mask_by=6):
-        x = (pixels.shape[1] // 8) * 8
-        y = (pixels.shape[2] // 8) * 8
+        x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
+        y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
 
         pixels = pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % 8) // 2
-            y_offset = (pixels.shape[2] % 8) // 2
+            x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
             pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
             mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
 
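With the crop folded into VAE.encode, the encode nodes no longer carry a hard-coded stride, and VAEEncodeForInpaint reads vae.downscale_ratio directly. End to end, the commit means an effnet checkpoint can be loaded through the ordinary VAE path and used for Stage C img2img. A hypothetical usage sketch (the checkpoint path and image tensor are assumptions; VAE and load_torch_file are ComfyUI's own APIs at this commit):

import torch
import comfy.utils
from comfy.sd import VAE

sd = comfy.utils.load_torch_file("models/vae/effnet_encoder.safetensors")  # path is an assumption
vae = VAE(sd=sd)                          # the key sniffing above selects StageC_coder, stride 32
image = torch.rand(1, 1024, 1024, 3)      # NHWC float image in 0..1, as nodes pass it
latent = vae.encode(image[:, :, :, :3])   # auto-cropped to a multiple of 32; 16-channel Stage C latent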