diff --git a/execution.py b/execution.py index 7ad17131..53ba2e0f 100644 --- a/execution.py +++ b/execution.py @@ -7,6 +7,7 @@ import threading import heapq import traceback import gc +import inspect import torch import nodes @@ -402,6 +403,10 @@ def validate_inputs(prompt, item, validated): errors = [] valid = True + validate_function_inputs = [] + if hasattr(obj_class, "VALIDATE_INPUTS"): + validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + for x in required_inputs: if x not in inputs: error = { @@ -531,29 +536,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: + if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -580,6 +563,35 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue + if len(validate_function_inputs) > 0: + input_data_all = get_input_data(inputs, obj_class, unique_id) + input_filtered = {} + for x in input_data_all: + if x in validate_function_inputs: + input_filtered[x] = input_data_all[x] + + #ret = obj_class.VALIDATE_INPUTS(**input_filtered) + ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + for x in input_filtered: + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: diff --git a/nodes.py b/nodes.py index 027bf55d..8e3ec947 100644 --- a/nodes.py +++ b/nodes.py @@ -1491,13 +1491,10 @@ class LoadImageMask: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image, channel): + def VALIDATE_INPUTS(s, image): if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) - if channel not in s._color_channels: - return "Invalid color channel: {}".format(channel) - return True class ImageScale: