mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
20447e9ec9
There were two bugs in PorterDuffImageComposite.
The first one is that it uses the mask input directly as alpha, missing the mask-to-alpha conversion (`1-a`). The fix is similar to c16f5744.
The second one is that all color composition formulas assume alpha-premultiplied values, while the input is not premultiplied.
This change fixes both of these issues.
216 lines
8.1 KiB
Python
216 lines
8.1 KiB
Python
import numpy as np
|
|
import torch
|
|
import comfy.utils
|
|
from enum import Enum
|
|
|
|
def resize_mask(mask, shape):
    """Bilinearly resize a stack of masks to ``(shape[0], shape[1])``.

    The last two dimensions of ``mask`` are treated as the spatial
    dimensions; every leading dimension is flattened into one batch
    dimension. Only the first two entries of ``shape`` are used.

    Returns a tensor of shape ``(batch, shape[0], shape[1])``.
    """
    target_size = (shape[0], shape[1])
    # interpolate() wants (N, C, H, W), so insert a dummy channel axis.
    as_nchw = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
    scaled = torch.nn.functional.interpolate(as_nchw, size=target_size, mode="bilinear")
    # Drop the dummy channel axis again.
    return scaled.squeeze(1)
|
|
|
|
class PorterDuffMode(Enum):
    """Compositing / blending operators selectable for Porter-Duff composition.

    Members are kept in alphabetical order with consecutive integer values;
    iterating the enum yields them in this declaration order.
    """

    ADD = 0
    CLEAR = 1
    DARKEN = 2
    DST = 3
    DST_ATOP = 4
    DST_IN = 5
    DST_OUT = 6
    DST_OVER = 7
    LIGHTEN = 8
    MULTIPLY = 9
    OVERLAY = 10
    SCREEN = 11
    SRC = 12
    SRC_ATOP = 13
    SRC_IN = 14
    SRC_OUT = 15
    SRC_OVER = 16
    XOR = 17
|
|
|
|
|
|
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
    """Composite a source over/with a destination using a Porter-Duff operator.

    ``src_alpha`` and ``dst_alpha`` are given as masks (``1 - alpha``) and the
    images are given non-premultiplied. Internally the masks are converted to
    alpha and the images are premultiplied, the selected operator is applied,
    and the result is converted back.

    Args:
        src_image: Source colour tensor.
        src_alpha: Source mask; must broadcast against ``src_image``.
        dst_image: Destination colour tensor.
        dst_alpha: Destination mask; must broadcast against ``dst_image``.
        mode: The operator to apply.

    Returns:
        ``(out_image, out_mask)`` where ``out_image`` is non-premultiplied and
        clamped to ``[0, 1]`` and ``out_mask`` is ``1 - out_alpha``. Returns
        ``(None, None)`` when ``mode`` is not a known operator.
    """
    # Mask -> alpha.
    sa = 1 - src_alpha
    da = 1 - dst_alpha
    # Premultiply the colours; every formula below assumes premultiplied input.
    sc = src_image * sa
    dc = dst_image * da

    def union_alpha():
        # Coverage of "source or destination".
        return sa + da - sa * da

    def overlay_color():
        # NOTE(review): the else-branch follows the formula as published in the
        # Android PorterDuff.Mode reference; the W3C/SVG premultiplied overlay
        # uses (da - dc) * (sa - sc) instead -- confirm which is intended.
        return torch.where(2 * dc < da, 2 * sc * dc,
                           sa * da - 2 * (da - sc) * (sa - dc))

    # Each entry lazily produces (out_alpha, out_image) in premultiplied space,
    # so only the selected operator is ever evaluated.
    operators = {
        PorterDuffMode.ADD: lambda: (torch.clamp(sa + da, 0, 1), torch.clamp(sc + dc, 0, 1)),
        PorterDuffMode.CLEAR: lambda: (torch.zeros_like(da), torch.zeros_like(dc)),
        PorterDuffMode.DARKEN: lambda: (union_alpha(), (1 - da) * sc + (1 - sa) * dc + torch.min(sc, dc)),
        PorterDuffMode.DST: lambda: (da, dc),
        PorterDuffMode.DST_ATOP: lambda: (sa, sa * dc + (1 - da) * sc),
        PorterDuffMode.DST_IN: lambda: (sa * da, dc * sa),
        PorterDuffMode.DST_OUT: lambda: ((1 - sa) * da, (1 - sa) * dc),
        PorterDuffMode.DST_OVER: lambda: (da + (1 - da) * sa, dc + (1 - da) * sc),
        PorterDuffMode.LIGHTEN: lambda: (union_alpha(), (1 - da) * sc + (1 - sa) * dc + torch.max(sc, dc)),
        PorterDuffMode.MULTIPLY: lambda: (sa * da, sc * dc),
        PorterDuffMode.OVERLAY: lambda: (union_alpha(), overlay_color()),
        PorterDuffMode.SCREEN: lambda: (union_alpha(), sc + dc - sc * dc),
        PorterDuffMode.SRC: lambda: (sa, sc),
        PorterDuffMode.SRC_ATOP: lambda: (da, da * sc + (1 - sa) * dc),
        PorterDuffMode.SRC_IN: lambda: (sa * da, sc * da),
        PorterDuffMode.SRC_OUT: lambda: ((1 - da) * sa, (1 - da) * sc),
        PorterDuffMode.SRC_OVER: lambda: (sa + (1 - sa) * da, sc + (1 - sa) * dc),
        PorterDuffMode.XOR: lambda: ((1 - da) * sa + (1 - sa) * da, (1 - da) * sc + (1 - sa) * dc),
    }

    operator = operators.get(mode) if isinstance(mode, PorterDuffMode) else None
    if operator is None:
        return None, None
    out_alpha, out_image = operator()

    # Un-premultiply; where the result is (almost) fully transparent the colour
    # is undefined, so emit black instead of dividing by ~0.
    out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
    out_image = torch.clamp(out_image, 0, 1)
    # Alpha -> mask.
    out_mask = 1 - out_alpha
    return out_image, out_mask
|
|
|
|
|
|
class PorterDuffImageComposite:
    """Node that composites two image batches with a Porter-Duff operator."""

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "composite"
    CATEGORY = "mask/compositing"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: two image/mask pairs plus the operator name."""
        mode_names = [mode.name for mode in PorterDuffMode]
        required = {
            "source": ("IMAGE",),
            "source_alpha": ("MASK",),
            "destination": ("IMAGE",),
            "destination_alpha": ("MASK",),
            "mode": (mode_names, {"default": PorterDuffMode.DST.name}),
        }
        return {"required": required}

    def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
        """Composite each source sample with the matching destination sample.

        Samples are paired along the first dimension; the output batch is as
        long as the shortest of the four inputs. Per sample, the destination
        mask is first resized to the destination image, then the source image
        and source mask are resized to their destination counterparts.

        Args:
            source, destination: Image batches, indexed per sample as
                ``(H, W, C)``; both must have the same channel count.
            source_alpha, destination_alpha: Mask batches, indexed per sample
                as ``(H, W)``.
            mode: Name of a ``PorterDuffMode`` member.

        Returns:
            ``(images, masks)`` stacked along a new first dimension.
        """

        def fit_to(tensor, reference):
            # Resize an (H, W, C) tensor to the spatial size of ``reference``.
            # common_upscale works on channel-first batches, hence the permutes;
            # presumably its size arguments are (width, height) -- verify
            # against comfy.utils.
            channel_first = tensor.unsqueeze(0).permute(0, 3, 1, 2)
            resized = comfy.utils.common_upscale(channel_first, reference.shape[1], reference.shape[0], upscale_method='bicubic', crop='center')
            return resized.permute(0, 2, 3, 1).squeeze(0)

        composed_images = []
        composed_masks = []

        # zip() stops at the shortest input, which bounds the batch size.
        for src_image, src_mask, dst_image, dst_mask in zip(source, source_alpha, destination, destination_alpha):
            assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels

            # Give the masks a trailing channel axis so they broadcast over colour.
            src_alpha = src_mask.unsqueeze(2)
            dst_alpha = dst_mask.unsqueeze(2)

            # The destination image dictates the working resolution.
            if dst_alpha.shape[:2] != dst_image.shape[:2]:
                dst_alpha = fit_to(dst_alpha, dst_image)
            if src_image.shape != dst_image.shape:
                src_image = fit_to(src_image, dst_image)
            if src_alpha.shape != dst_alpha.shape:
                src_alpha = fit_to(src_alpha, dst_alpha)

            out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])

            composed_images.append(out_image)
            composed_masks.append(out_alpha.squeeze(2))

        return (torch.stack(composed_images), torch.stack(composed_masks))
|
|
|
|
|
|
class SplitImageWithAlpha:
    """Node that splits an image batch into colour channels and a mask."""

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "split_image_with_alpha"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a single image batch."""
        required = {"image": ("IMAGE",)}
        return {"required": required}

    def split_image_with_alpha(self, image: torch.Tensor):
        """Separate the first three channels from the fourth (alpha) channel.

        Each sample is indexed as ``(H, W, C)``. Samples without a fourth
        channel are treated as fully opaque.

        Returns:
            ``(colours, mask)`` where ``colours`` holds the first three
            channels and ``mask`` is ``1 - alpha``.
        """
        colour_planes = []
        alpha_planes = []
        for sample in image:
            colour_planes.append(sample[:, :, :3])
            if sample.shape[2] > 3:
                alpha_planes.append(sample[:, :, 3])
            else:
                # No alpha channel present: assume alpha == 1 everywhere.
                alpha_planes.append(torch.ones_like(sample[:, :, 0]))

        # The MASK output is the inverse of alpha.
        mask = 1.0 - torch.stack(alpha_planes)
        return (torch.stack(colour_planes), mask)
|
|
|
|
|
|
class JoinImageWithAlpha:
    """Node that attaches a mask to an image batch as its alpha channel."""

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "join_image_with_alpha"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: an image batch and a mask batch."""
        required = {
            "image": ("IMAGE",),
            "alpha": ("MASK",),
        }
        return {"required": required}

    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
        """Append ``1 - alpha`` as a fourth channel to each image sample.

        The mask is resized to the image's spatial size first. Only the first
        three channels of each image sample are kept. The output batch is as
        long as the shorter of ``image`` and ``alpha``.

        Returns:
            A one-element tuple holding the stacked four-channel images.
        """
        # Batch length is taken from the inputs as given, before resizing.
        batch_size = min(len(image), len(alpha))

        # Mask -> alpha, matched to the image's (H, W).
        opacity = 1.0 - resize_mask(alpha, image.shape[1:])

        joined = [
            torch.cat((image[index][:, :, :3], opacity[index].unsqueeze(2)), dim=2)
            for index in range(batch_size)
        ]
        return (torch.stack(joined),)
|
|
|
|
|
|
# Node identifier -> implementing class, for the nodes defined in this module.
NODE_CLASS_MAPPINGS = {
    "PorterDuffImageComposite": PorterDuffImageComposite,
    "SplitImageWithAlpha": SplitImageWithAlpha,
    "JoinImageWithAlpha": JoinImageWithAlpha,
}
|
|
|
|
|
|
# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "PorterDuffImageComposite": "Porter-Duff Image Composite",
    "SplitImageWithAlpha": "Split Image with Alpha",
    "JoinImageWithAlpha": "Join Image with Alpha",
}
|