Add a basic ImageScale node.

It's pretty much the same as the LatentUpscale node for now but for images
in pixel space.
This commit is contained in:
comfyanonymous 2023-02-04 15:53:29 -05:00
parent bff0e11941
commit 4225d1cb9f

View File

@ -186,6 +186,23 @@ class EmptyLatentImage:
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (latent, )
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
@ -202,21 +219,7 @@ class LatentUpscale:
CATEGORY = "latent"
def upscale(self, samples, upscale_method, width, height, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
s = torch.nn.functional.interpolate(s, size=(height // 8, width // 8), mode=upscale_method)
s = common_upscale(samples, width // 8, height // 8, upscale_method, crop)
return (s,)
class LatentRotate:
@ -505,7 +508,26 @@ class LoadImage:
m.update(f.read())
return m.digest().hex()
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
"crop": (s.crop_methods,)}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image"
def upscale(self, image, upscale_method, width, height, crop):
samples = image.movedim(-1,1)
s = common_upscale(samples, width, height, upscale_method, crop)
s = s.movedim(1,-1)
return (s,)
NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
@ -518,6 +540,7 @@ NODE_CLASS_MAPPINGS = {
"LatentUpscale": LatentUpscale,
"SaveImage": SaveImage,
"LoadImage": LoadImage,
"ImageScale": ImageScale,
"ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea,
"KSamplerAdvanced": KSamplerAdvanced,