diff --git a/comfy/sd.py b/comfy/sd.py index fe60205d..39f88fad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -318,6 +318,37 @@ class VAE: pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples + def decode_tiled(self, samples): + tile_x = tile_y = 64 + overlap = 8 + model_management.unload_model() + output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu") + self.first_stage_model = self.first_stage_model.to(self.device) + for b in range(samples.shape[0]): + s = samples[b:b+1] + out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu") + out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu") + for y in range(0, s.shape[2], tile_y - overlap): + for x in range(0, s.shape[3], tile_x - overlap): + s_in = s[:,:,y:y+tile_y,x:x+tile_x] + + pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device)) + pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0) + ps = pixel_samples.cpu() + mask = torch.ones_like(ps) + feather = overlap * 8 + for t in range(feather): + mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) + mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1)) + mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1)) + mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) + out[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += ps * mask + out_div[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += mask + + output[b:b+1] = out/out_div + self.first_stage_model = self.first_stage_model.cpu() + return output.movedim(1,-1) + def encode(self, pixel_samples): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) diff --git a/nodes.py b/nodes.py index 88fe664a..38d0ce6e 100644 --- a/nodes.py +++ b/nodes.py @@ -106,6 +106,21 @@ class VAEDecode: def decode(self, vae, samples): return (vae.decode(samples["samples"]), ) +class VAEDecodeTiled: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "decode" + + CATEGORY = "_for_testing" + + def decode(self, vae, samples): + return (vae.decode_tiled(samples["samples"]), ) + class VAEEncode: def __init__(self, device="cpu"): self.device = device @@ -789,6 +804,7 @@ NODE_CLASS_MAPPINGS = { "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, + "VAEDecodeTiled": VAEDecodeTiled, } CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")