diff --git a/comfy/sd.py b/comfy/sd.py
index d8c0bfa7..00633e10 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -4,6 +4,7 @@ from enum import Enum
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder
 
 import yaml
 
@@ -158,6 +159,7 @@ class VAE:
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
+        self.upscale_ratio = 8
         self.latent_channels = 4
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -176,11 +178,31 @@ class VAE:
             elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                 self.first_stage_model = StageA()
                 self.downscale_ratio = 4
+                self.upscale_ratio = 4
                 #TODO
                 #self.memory_used_encode
                 #self.memory_used_decode
                 self.process_input = lambda image: image
                 self.process_output = lambda image: image
+            elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+                self.first_stage_model = StageC_coder()
+                self.downscale_ratio = 32
+                self.latent_channels = 16
+                new_sd = {}
+                for k in sd:
+                    new_sd["encoder.{}".format(k)] = sd[k]
+                sd = new_sd
+            elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+                self.first_stage_model = StageC_coder()
+                self.latent_channels = 16
+                new_sd = {}
+                for k in sd:
+                    new_sd["previewer.{}".format(k)] = sd[k]
+                sd = new_sd
+            elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+                self.first_stage_model = StageC_coder()
+                self.downscale_ratio = 32
+                self.latent_channels = 16
             else:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -188,6 +210,7 @@ class VAE:
                 if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
                     ddconfig['ch_mult'] = [1, 2, 4]
                     self.downscale_ratio = 4
+                    self.upscale_ratio = 4
 
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
@@ -213,6 +236,15 @@ class VAE:
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
 
+    def vae_encode_crop_pixels(self, pixels):
+        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
+        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
+            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        return pixels
+
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -221,9 +253,9 @@ class VAE:
 
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         output = self.process_output(
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
             / 3.0)
         return output
 
@@ -248,7 +280,7 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
@@ -265,6 +297,7 @@ class VAE:
         return output.movedim(1,-1)
 
     def encode(self, pixel_samples):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -284,6 +317,7 @@ class VAE:
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         model_management.load_model_gpu(self.patcher)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
diff --git a/nodes.py b/nodes.py
index 47203f41..a577c212 100644
--- a/nodes.py
+++ b/nodes.py
@@ -309,18 +309,7 @@ class VAEEncode:
 
     CATEGORY = "latent"
 
-    @staticmethod
-    def vae_encode_crop_pixels(pixels):
-        x = (pixels.shape[1] // 8) * 8
-        y = (pixels.shape[2] // 8) * 8
-        if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % 8) // 2
-            y_offset = (pixels.shape[2] % 8) // 2
-            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
-        return pixels
-
     def encode(self, vae, pixels):
-        pixels = self.vae_encode_crop_pixels(pixels)
         t = vae.encode(pixels[:,:,:,:3])
         return ({"samples":t}, )
 
@@ -336,7 +325,6 @@ class VAEEncodeTiled:
     CATEGORY = "_for_testing"
 
     def encode(self, vae, pixels, tile_size):
-        pixels = VAEEncode.vae_encode_crop_pixels(pixels)
         t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
         return ({"samples":t}, )
 
@@ -350,14 +338,14 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"
 
     def encode(self, vae, pixels, mask, grow_mask_by=6):
-        x = (pixels.shape[1] // 8) * 8
-        y = (pixels.shape[2] // 8) * 8
+        x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
+        y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
 
         pixels = pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % 8) // 2
-            y_offset = (pixels.shape[2] % 8) // 2
+            x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
             pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
             mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
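
Note on the relocated crop: VAE.encode and VAE.encode_tiled now center-crop their input to a multiple of the model's downscale_ratio themselves, replacing the hard-coded multiple-of-8 crop that previously lived in nodes.py, so VAEs with other ratios (stage A's 4, the stage C effnet's 32) crop correctly. A minimal sketch of that behavior, assuming a hypothetical 1022x766 input in ComfyUI's [batch, height, width, channels] IMAGE layout:

import torch

def crop_like_vae(pixels, downscale_ratio):
    # Mirrors VAE.vae_encode_crop_pixels: round each spatial dim down to a
    # multiple of downscale_ratio and center the crop by splitting the remainder.
    x = (pixels.shape[1] // downscale_ratio) * downscale_ratio
    y = (pixels.shape[2] // downscale_ratio) * downscale_ratio
    if pixels.shape[1] != x or pixels.shape[2] != y:
        x_offset = (pixels.shape[1] % downscale_ratio) // 2
        y_offset = (pixels.shape[2] % downscale_ratio) // 2
        pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
    return pixels

img = torch.zeros((1, 1022, 766, 3))   # hypothetical odd-sized image
print(crop_like_vae(img, 8).shape)     # torch.Size([1, 1016, 760, 3])  default SD VAE
print(crop_like_vae(img, 32).shape)    # torch.Size([1, 992, 736, 3])   stage C effnet encoder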
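The new StageC_coder branches are selected by probing state-dict keys, and a standalone effnet or previewer checkpoint has its keys re-namespaced before loading so they match the combined module. An illustrative sketch with dummy tensors standing in for real checkpoint weights; only the probed key names come from the patch, the helper below is hypothetical:

import torch

def route_stage_c_sd(sd):
    # Decide which Stable Cascade stage C component a checkpoint holds and
    # prefix its keys so they line up with the combined StageC_coder module.
    if "backbone.1.0.block.0.1.num_batches_tracked" in sd:            # standalone effnet encoder
        return {"encoder.{}".format(k): v for k, v in sd.items()}
    elif "blocks.11.num_batches_tracked" in sd:                       # standalone previewer decoder
        return {"previewer.{}".format(k): v for k, v in sd.items()}
    elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd:  # combined checkpoint
        return sd                                                     # keys already namespaced
    raise ValueError("not a stage C checkpoint")

effnet_sd = {"backbone.1.0.block.0.1.num_batches_tracked": torch.tensor(0)}
print(list(route_stage_c_sd(effnet_sd)))  # ['encoder.backbone.1.0.block.0.1.num_batches_tracked']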
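Separating upscale_ratio (decode side: sizing the output pixel buffer and tiled_scale's upscale_amount) from downscale_ratio (encode side: cropping and the inpaint mask math) lets the two differ, as in the stage C effnet branch, which sets downscale_ratio to 32 without overriding the default upscale_ratio. A small sketch of the decode-side arithmetic, with hypothetical latent shapes:

def decoded_pixel_shape(latent_shape, upscale_ratio):
    # Same arithmetic as the torch.empty allocation in VAE.decode:
    # the output is 3-channel and spatially scaled by upscale_ratio.
    b, _c, h, w = latent_shape
    return (b, 3, round(h * upscale_ratio), round(w * upscale_ratio))

print(decoded_pixel_shape((1, 4, 64, 64), 8))   # (1, 3, 512, 512)  default SD VAE
print(decoded_pixel_shape((1, 4, 32, 32), 4))   # (1, 3, 128, 128)  stage A / x4-upscaler VAE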