From ffc56c53c9cccfcc21c92fe14cb095bb32ea2744 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:22:38 -0400 Subject: [PATCH] Add a node_errors to the /prompt error json response. "node_errors" contains a dict keyed by node ids. The contents are a message and a list of dependent outputs. --- execution.py | 27 ++++++++++++++++----------- server.py | 4 ++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index 35f04434..212e789c 100644 --- a/execution.py +++ b/execution.py @@ -299,18 +299,18 @@ def validate_inputs(prompt, item, validated): required_inputs = class_inputs['required'] for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) + return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) + return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) r = validate_inputs(prompt, o_id, validated) if r[0] == False: validated[o_id] = r @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) @@ -338,13 +338,13 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for r in ret: if r != True: - return (False, "{}, {}".format(class_type, r)) + return (False, "{}, {}".format(class_type, r), unique_id) else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) - ret = (True, "") + ret = (True, "", unique_id) validated[unique_id] = ret return ret @@ -356,10 +356,11 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs") + return (False, "Prompt has no outputs", [], []) good_outputs = set() errors = [] + node_errors = {} validated = {} for o in outputs: valid = False @@ -368,6 +369,7 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] + node_id = m[2] except Exception as e: print(traceback.format_exc()) valid = False @@ -379,12 +381,15 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) - return (True, "", list(good_outputs)) + return (True, "", list(good_outputs), node_errors) class PromptQueue: diff --git a/server.py b/server.py index 701c0e7a..8429a63f 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) - return web.json_response({"error": valid[1]}, status=400) + return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt"}, status=400) + return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @routes.post("/queue") async def post_queue(request):