diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py new file mode 100644 index 00000000..cec105fc --- /dev/null +++ b/comfy_execution/validation.py @@ -0,0 +1,39 @@ +from __future__ import annotations + + +def validate_node_input( + received_type: str, input_type: str, strict: bool = False +) -> bool: + """ + received_type and input_type are both strings of the form "T1,T2,...". + + If strict is True, the input_type must contain the received_type. + For example, if received_type is "STRING" and input_type is "STRING,INT", + this will return True. But if received_type is "STRING,INT" and input_type is + "INT", this will return False. + + If strict is False, the input_type must have overlap with the received_type. + For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT", + this will return True. + + Supports pre-union type extension behaviour of ``__ne__`` overrides. + """ + # If the types are exactly the same, we can return immediately + # Use pre-union behaviour: inverse of `__ne__` + if not received_type != input_type: + return True + + # Not equal, and not strings + if not isinstance(received_type, str) or not isinstance(input_type, str): + return False + + # Split the type strings into sets for comparison + received_types = set(t.strip() for t in received_type.split(",")) + input_types = set(t.strip() for t in input_type.split(",")) + + if strict: + # In strict mode, all received types must be in the input types + return received_types.issubset(input_types) + else: + # In non-strict mode, there must be at least one type in common + return len(received_types.intersection(input_types)) > 0 diff --git a/execution.py b/execution.py index 768e35ab..929ef85f 100644 --- a/execution.py +++ b/execution.py @@ -16,6 +16,7 @@ import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.validation import validate_node_input from comfy.cli_args import args class ExecutionResult(Enum): @@ -527,7 +528,6 @@ class PromptExecutor: comfy.model_management.unload_all_models() - def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and received_type != type_input: - details = f"{x}, {received_type} != {type_input}" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): + details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", diff --git a/tests-unit/execution_test/validate_node_input_test.py b/tests-unit/execution_test/validate_node_input_test.py new file mode 100644 index 00000000..85a0c960 --- /dev/null +++ b/tests-unit/execution_test/validate_node_input_test.py @@ -0,0 +1,119 @@ +import pytest +from comfy_execution.validation import validate_node_input + + +def test_exact_match(): + """Test cases where types match exactly""" + assert validate_node_input("STRING", "STRING") + assert validate_node_input("STRING,INT", "STRING,INT") + assert validate_node_input("INT,STRING", "STRING,INT") # Order shouldn't matter + + +def test_strict_mode(): + """Test strict mode validation""" + # Should pass - received type is subset of input type + assert validate_node_input("STRING", "STRING,INT", strict=True) + assert validate_node_input("INT", "STRING,INT", strict=True) + assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True) + + # Should fail - received type is not subset of input type + assert not validate_node_input("STRING,INT", "STRING", strict=True) + assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True) + assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True) + + +def test_non_strict_mode(): + """Test non-strict mode validation (default behavior)""" + # Should pass - types have overlap + assert validate_node_input("STRING,BOOLEAN", "STRING,INT") + assert validate_node_input("STRING,INT", "INT,BOOLEAN") + assert validate_node_input("STRING", "STRING,INT") + + # Should fail - no overlap in types + assert not validate_node_input("BOOLEAN", "STRING,INT") + assert not validate_node_input("FLOAT", "STRING,INT") + assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT") + + +def test_whitespace_handling(): + """Test that whitespace is handled correctly""" + assert validate_node_input("STRING, INT", "STRING,INT") + assert validate_node_input("STRING,INT", "STRING, INT") + assert validate_node_input(" STRING , INT ", "STRING,INT") + assert validate_node_input("STRING,INT", " STRING , INT ") + + +def test_empty_strings(): + """Test behavior with empty strings""" + assert validate_node_input("", "") + assert not validate_node_input("STRING", "") + assert not validate_node_input("", "STRING") + + +def test_single_vs_multiple(): + """Test single type against multiple types""" + assert validate_node_input("STRING", "STRING,INT,BOOLEAN") + assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False) + assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True) + + +def test_non_string(): + """Test non-string types""" + obj1 = object() + obj2 = object() + assert validate_node_input(obj1, obj1) + assert not validate_node_input(obj1, obj2) + + +class NotEqualsOverrideTest(str): + """Test class for ``__ne__`` override.""" + + def __ne__(self, value: object) -> bool: + if self == "*" or value == "*": + return False + if self == "LONGER_THAN_2": + return not len(value) > 2 + raise TypeError("This is a class for unit tests only.") + + +def test_ne_override(): + """Test ``__ne__`` any override""" + any = NotEqualsOverrideTest("*") + invalid_type = "INVALID_TYPE" + obj = object() + assert validate_node_input(any, any) + assert validate_node_input(any, invalid_type) + assert validate_node_input(any, obj) + assert validate_node_input(any, {}) + assert validate_node_input(any, []) + assert validate_node_input(any, [1, 2, 3]) + + +def test_ne_custom_override(): + """Test ``__ne__`` custom override""" + special = NotEqualsOverrideTest("LONGER_THAN_2") + + assert validate_node_input(special, special) + assert validate_node_input(special, "*") + assert validate_node_input(special, "INVALID_TYPE") + assert validate_node_input(special, [1, 2, 3]) + + # Should fail + assert not validate_node_input(special, [1, 2]) + assert not validate_node_input(special, "TY") + + +@pytest.mark.parametrize( + "received,input_type,strict,expected", + [ + ("STRING", "STRING", False, True), + ("STRING,INT", "STRING,INT", False, True), + ("STRING", "STRING,INT", True, True), + ("STRING,INT", "STRING", True, False), + ("BOOLEAN", "STRING,INT", False, False), + ("STRING,BOOLEAN", "STRING,INT", False, True), + ], +) +def test_parametrized_cases(received, input_type, strict, expected): + """Parametrized test cases for various scenarios""" + assert validate_node_input(received, input_type, strict) == expected