From 3d2e3a6f29670809aa97b41505fa4e93ce11b98d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 2 Apr 2025 19:32:34 -0400 Subject: [PATCH] Fix alpha image issue in more nodes. --- comfy_extras/nodes_mask.py | 7 ++----- comfy_extras/nodes_post_processing.py | 3 ++- node_helpers.py | 8 ++++++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index e1f0c822..13d2b4ba 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -2,6 +2,7 @@ import numpy as np import scipy.ndimage import torch import comfy.utils +import node_helpers from nodes import MAX_RESOLUTION @@ -87,11 +88,7 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): - if destination.shape[-1] < source.shape[-1]: - source = source[...,:destination.shape[-1]] - elif destination.shape[-1] > source.shape[-1]: - destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = source[..., -1] + destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 68f6ef51..5b954201 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,7 +6,7 @@ import math import comfy.utils import comfy.model_management - +import node_helpers class Blend: def __init__(self): @@ -34,6 +34,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) diff --git a/node_helpers.py b/node_helpers.py index 48da3b09..4f805387 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -44,3 +44,11 @@ def string_to_torch_dtype(string): return torch.float16 if string == "bf16": return torch.bfloat16 + +def image_alpha_fix(destination, source): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = source[..., -1] + return destination, source