Fix VALIDATE_INPUTS getting called multiple times.

Allow VALIDATE_INPUTS to only validate specific inputs.
This commit is contained in:
comfyanonymous 2023-12-29 17:33:30 -05:00
parent 12e822c6c8
commit 04b713dda1
2 changed files with 36 additions and 27 deletions

View File

@ -7,6 +7,7 @@ import threading
import heapq
import traceback
import gc
import inspect
import torch
import nodes
@ -402,6 +403,10 @@ def validate_inputs(prompt, item, validated):
errors = []
valid = True
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs:
if x not in inputs:
error = {
@ -531,29 +536,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@ -580,6 +563,35 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:

View File

@ -1491,13 +1491,10 @@ class LoadImageMask:
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image, channel):
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True
class ImageScale: