From 8d4e06324fb6c477c1f7f409c857c33d0b3b0ce2 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 3 Dec 2024 02:46:00 -0800 Subject: [PATCH] Add union link connection type support (#5806) * Add union type support * Move code * nit --- comfy_execution/validation.py | 32 ++++++++ execution.py | 6 +- .../validate_node_input_test.py | 75 +++++++++++++++++++ 3 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 comfy_execution/validation.py create mode 100644 tests-unit/execution_test/validate_node_input_test.py diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py new file mode 100644 index 00000000..43fb6426 --- /dev/null +++ b/comfy_execution/validation.py @@ -0,0 +1,32 @@ +from __future__ import annotations + + +def validate_node_input( + received_type: str, input_type: str, strict: bool = False +) -> bool: + """ + received_type and input_type are both strings of the form "T1,T2,...". + + If strict is True, the input_type must contain the received_type. + For example, if received_type is "STRING" and input_type is "STRING,INT", + this will return True. But if received_type is "STRING,INT" and input_type is + "INT", this will return False. + + If strict is False, the input_type must have overlap with the received_type. + For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT", + this will return True. + """ + # If the types are exactly the same, we can return immediately + if received_type == input_type: + return True + + # Split the type strings into sets for comparison + received_types = set(t.strip() for t in received_type.split(",")) + input_types = set(t.strip() for t in input_type.split(",")) + + if strict: + # In strict mode, all received types must be in the input types + return received_types.issubset(input_types) + else: + # In non-strict mode, there must be at least one type in common + return len(received_types.intersection(input_types)) > 0 diff --git a/execution.py b/execution.py index 768e35ab..929ef85f 100644 --- a/execution.py +++ b/execution.py @@ -16,6 +16,7 @@ import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.validation import validate_node_input from comfy.cli_args import args class ExecutionResult(Enum): @@ -527,7 +528,6 @@ class PromptExecutor: comfy.model_management.unload_all_models() - def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and received_type != type_input: - details = f"{x}, {received_type} != {type_input}" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): + details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", diff --git a/tests-unit/execution_test/validate_node_input_test.py b/tests-unit/execution_test/validate_node_input_test.py new file mode 100644 index 00000000..d6605e97 --- /dev/null +++ b/tests-unit/execution_test/validate_node_input_test.py @@ -0,0 +1,75 @@ +import pytest +from comfy_execution.validation import validate_node_input + + +def test_exact_match(): + """Test cases where types match exactly""" + assert validate_node_input("STRING", "STRING") + assert validate_node_input("STRING,INT", "STRING,INT") + assert ( + validate_node_input("INT,STRING", "STRING,INT") + ) # Order shouldn't matter + + +def test_strict_mode(): + """Test strict mode validation""" + # Should pass - received type is subset of input type + assert validate_node_input("STRING", "STRING,INT", strict=True) + assert validate_node_input("INT", "STRING,INT", strict=True) + assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True) + + # Should fail - received type is not subset of input type + assert not validate_node_input("STRING,INT", "STRING", strict=True) + assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True) + assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True) + + +def test_non_strict_mode(): + """Test non-strict mode validation (default behavior)""" + # Should pass - types have overlap + assert validate_node_input("STRING,BOOLEAN", "STRING,INT") + assert validate_node_input("STRING,INT", "INT,BOOLEAN") + assert validate_node_input("STRING", "STRING,INT") + + # Should fail - no overlap in types + assert not validate_node_input("BOOLEAN", "STRING,INT") + assert not validate_node_input("FLOAT", "STRING,INT") + assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT") + + +def test_whitespace_handling(): + """Test that whitespace is handled correctly""" + assert validate_node_input("STRING, INT", "STRING,INT") + assert validate_node_input("STRING,INT", "STRING, INT") + assert validate_node_input(" STRING , INT ", "STRING,INT") + assert validate_node_input("STRING,INT", " STRING , INT ") + + +def test_empty_strings(): + """Test behavior with empty strings""" + assert validate_node_input("", "") + assert not validate_node_input("STRING", "") + assert not validate_node_input("", "STRING") + + +def test_single_vs_multiple(): + """Test single type against multiple types""" + assert validate_node_input("STRING", "STRING,INT,BOOLEAN") + assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False) + assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True) + + +@pytest.mark.parametrize( + "received,input_type,strict,expected", + [ + ("STRING", "STRING", False, True), + ("STRING,INT", "STRING,INT", False, True), + ("STRING", "STRING,INT", True, True), + ("STRING,INT", "STRING", True, False), + ("BOOLEAN", "STRING,INT", False, False), + ("STRING,BOOLEAN", "STRING,INT", False, True), + ], +) +def test_parametrized_cases(received, input_type, strict, expected): + """Parametrized test cases for various scenarios""" + assert validate_node_input(received, input_type, strict) == expected