From 56196ab0f72c8f671bd85b425744f80f02c823ea Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 4 Apr 2023 10:57:34 -0400 Subject: [PATCH] use common_upcale in blend --- comfy_extras/nodes_post_processing.py | 29 +++++---------------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 322f3ca8..703deaab 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -3,6 +3,8 @@ import torch import torch.nn.functional as F from PIL import Image +import comfy.utils + class Blend: def __init__(self): @@ -31,7 +33,9 @@ class Blend: def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): if image1.shape != image2.shape: - image2 = self.crop_and_resize(image2, image1.shape) + image2 = image2.permute(0, 3, 1, 2) + image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') + image2 = image2.permute(0, 2, 3, 1) blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor @@ -55,29 +59,6 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) - def crop_and_resize(self, img: torch.Tensor, target_shape: tuple): - batch_size, img_h, img_w, img_c = img.shape - _, target_h, target_w, _ = target_shape - img_aspect_ratio = img_w / img_h - target_aspect_ratio = target_w / target_h - - # Crop center of the image to the target aspect ratio - if img_aspect_ratio > target_aspect_ratio: - new_width = int(img_h * target_aspect_ratio) - left = (img_w - new_width) // 2 - img = img[:, :, left:left + new_width, :] - else: - new_height = int(img_w / target_aspect_ratio) - top = (img_h - new_height) // 2 - img = img[:, top:top + new_height, :, :] - - # Resize to target size - img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False) - img = img.permute(0, 2, 3, 1) - - return img - class Blur: def __init__(self): pass