mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-06 19:42:08 +08:00

* [feat] Add GetImageSize node to return image dimensions.
  Added a simple GetImageSize node in comfy_extras/nodes_images.py that returns the width and height of input images. The node displays the dimensions in the UI via PromptServer and provides width/height as outputs for further processing.
* Add display name mapping.
* [fix] Add server module mock to unit tests for the PromptServer import.
  Updated the test to mock the server module, preventing import errors caused by the new PromptServer usage in the GetImageSize node. Uses the direct import pattern consistent with the rest of the codebase.
244 lines
10 KiB
Python
from unittest.mock import MagicMock, patch

import torch

# Stand-ins for modules that must not be imported for real in these unit tests:
#   * "nodes"  - importing it would initialize CUDA; the only attribute
#                configured on the stub is MAX_RESOLUTION.
#   * "server" - stubbed so the PromptServer import in the node module succeeds.
mock_nodes = MagicMock(MAX_RESOLUTION=16384)
mock_server = MagicMock()

_stubbed_modules = {
    'nodes': mock_nodes,
    'server': mock_server,
}

# The stubs live in sys.modules only while ImageStitch is being imported;
# patch.dict restores the previous sys.modules entries on exit.
with patch.dict('sys.modules', _stubbed_modules):
    from comfy_extras.nodes_images import ImageStitch
class TestImageStitch:
    """Unit tests for the ImageStitch node.

    Test images are float tensors laid out as (batch, height, width, channels),
    as produced by ``create_test_image``. ``stitch`` is called positionally in
    the order (image1, direction, match-size flag, spacing width, spacing
    color, image2), and its first returned element is the stitched image.
    """

    def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
        """Helper to create test images with specific dimensions"""
        return torch.rand(batch_size, height, width, channels)

    def test_no_image2_passthrough(self):
        """Test that when image2 is None, image1 is returned unchanged"""
        node = ImageStitch()
        image1 = self.create_test_image()

        result = node.stitch(image1, "right", True, 0, "white", image2=None)

        assert len(result) == 1
        assert torch.equal(result[0], image1)

    def test_basic_horizontal_stitch_right(self):
        """Test basic horizontal stitching to the right"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 32 + 24 width

    def test_basic_horizontal_stitch_left(self):
        """Test basic horizontal stitching to the left"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "left", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 24 + 32 width

    def test_basic_vertical_stitch_down(self):
        """Test basic vertical stitching downward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "down", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 32 + 24 height

    def test_basic_vertical_stitch_up(self):
        """Test basic vertical stitching upward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "up", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 24 + 32 height

    def test_size_matching_horizontal(self):
        """Test size matching for horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)  # Different aspect ratio

        result = node.stitch(image1, "right", True, 0, "white", image2)

        # image2 should be resized to match image1's height (64) with preserved aspect ratio
        expected_width = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, 64, expected_width, 3)

    def test_size_matching_vertical(self):
        """Test size matching for vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "down", True, 0, "white", image2)

        # image2 should be resized to match image1's width (64) with preserved aspect ratio
        expected_height = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, expected_height, 64, 3)

    def test_padding_for_mismatched_heights_horizontal(self):
        """Test padding when heights don't match in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=32)
        image2 = self.create_test_image(height=48, width=24)  # Shorter height

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Both images should be padded to height 64
        assert result[0].shape == (1, 64, 56, 3)  # 32 + 24 width, max(64,48) height

    def test_padding_for_mismatched_widths_vertical(self):
        """Test padding when widths don't match in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=64)
        image2 = self.create_test_image(height=24, width=48)  # Narrower width

        result = node.stitch(image1, "down", False, 0, "white", image2)

        # Both images should be padded to width 64
        assert result[0].shape == (1, 56, 64, 3)  # 32 + 24 height, max(64,48) width

    def test_spacing_horizontal(self):
        """Test spacing addition in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)
        spacing_width = 16

        result = node.stitch(image1, "right", False, spacing_width, "white", image2)

        # Expected width: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 32, 72, 3)

    def test_spacing_vertical(self):
        """Test spacing addition in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)
        spacing_width = 16

        result = node.stitch(image1, "down", False, spacing_width, "white", image2)

        # Expected height: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 72, 32, 3)

    def test_spacing_color_values(self):
        """Test that spacing colors are applied correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Test white spacing
        result_white = node.stitch(image1, "right", False, 16, "white", image2)
        # Check that spacing region contains white values (close to 1.0).
        # Axis 2 is width, so columns 32..47 are the 16-pixel gap between images.
        spacing_region = result_white[0][:, :, 32:48, :]
        assert torch.all(spacing_region >= 0.9)  # Should be close to white

        # Test black spacing
        result_black = node.stitch(image1, "right", False, 16, "black", image2)
        spacing_region = result_black[0][:, :, 32:48, :]
        assert torch.all(spacing_region <= 0.1)  # Should be close to black

    def test_odd_spacing_width_made_even(self):
        """Test that odd spacing widths are made even"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Use odd spacing width
        result = node.stitch(image1, "right", False, 15, "white", image2)

        # Should be made even (16), so total width = 32 + 16 + 32 = 80
        assert result[0].shape == (1, 32, 80, 3)

    def test_batch_size_matching(self):
        """Test that different batch sizes are handled correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=32, width=32)
        image2 = self.create_test_image(batch_size=1, height=32, width=32)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should match larger batch size
        assert result[0].shape == (2, 32, 64, 3)

    def test_channel_matching_rgb_to_rgba(self):
        """Test that channel differences are handled (RGB + alpha)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=3)  # RGB
        image2 = self.create_test_image(channels=4)  # RGBA

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_channel_matching_rgba_to_rgb(self):
        """Test that channel differences are handled (RGBA + RGB)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=4)  # RGBA
        image2 = self.create_test_image(channels=3)  # RGB

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_all_color_options(self):
        """Test all available color options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        colors = ["white", "black", "red", "green", "blue"]

        for color in colors:
            result = node.stitch(image1, "right", False, 16, color, image2)
            # Basic shape check; the message identifies which color failed.
            assert result[0].shape == (1, 32, 80, 3), f"color={color!r}"

    def test_all_directions(self):
        """Test all direction options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        directions = ["right", "left", "up", "down"]

        for direction in directions:
            result = node.stitch(image1, direction, False, 0, "white", image2)
            # Compute the expected shape separately: written inline as
            # `assert shape == a if cond else b`, the conditional binds looser
            # than `==`, so for "up"/"down" the assert evaluated the (truthy)
            # tuple `b` and could never fail.
            if direction in ["right", "left"]:
                expected = (1, 32, 64, 3)
            else:
                expected = (1, 64, 32, 3)
            assert result[0].shape == expected, f"direction={direction!r}"

    def test_batch_size_channel_spacing_integration(self):
        """Test integration of batch matching, channel matching, size matching, and spacings"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
        image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)

        result = node.stitch(image1, "right", True, 8, "red", image2)

        # Should handle: batch matching, size matching, channel matching, spacing
        assert result[0].shape[0] == 2  # Batch size matched
        assert result[0].shape[-1] == 4  # Channels matched to max
        assert result[0].shape[1] == 64  # Height from image1 (size matching)

        # Width should be: 48 + 8 (spacing) + resized_image2_width
        expected_image2_width = int(64 * (32 / 32))  # Resized to height 64
        expected_total_width = 48 + 8 + expected_image2_width
        assert result[0].shape[2] == expected_total_width