diff --git a/nodes.py b/nodes.py index 2c354fd9..578f6c4e 100644 --- a/nodes.py +++ b/nodes.py @@ -115,17 +115,33 @@ class EmptyLatentImage: class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] + crop_methods = ["disabled", "center"] @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),}} + "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), + "crop": (s.crop_methods,)}} RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" - def upscale(self, samples, upscale_method, width, height): - s = torch.nn.functional.interpolate(samples, size=(height // 8, width // 8), mode=upscale_method) + def upscale(self, samples, upscale_method, width, height, crop): + if crop == "center": + old_width = samples.shape[3] + old_height = samples.shape[2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples[:,:,y:old_height-y,x:old_width-x] + else: + s = samples + s = torch.nn.functional.interpolate(s, size=(height // 8, width // 8), mode=upscale_method) return (s,) class KSampler: