import os import sys import copy import json import logging import threading import heapq import traceback import gc import inspect from typing import List, Literal, NamedTuple, Optional import torch import nodes import comfy.model_management def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: input_data = inputs[x] if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: input_data_all[x] = (None,) continue obj = outputs[input_unique_id][output_index] input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] return input_data_all def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists input_is_list = False if hasattr(obj, "INPUT_IS_LIST"): input_is_list = obj.INPUT_IS_LIST if len(input_data_all) == 0: max_len_input = 0 else: max_len_input = max([len(x) for x in input_data_all.values()]) # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): d_new = dict() for k,v in d.items(): d_new[k] = v[i if len(v) > i else -1] return d_new results = [] if input_is_list: if allow_interrupt: nodes.before_node_execution() results.append(getattr(obj, func)(**input_data_all)) elif max_len_input == 0: if allow_interrupt: nodes.before_node_execution() results.append(getattr(obj, func)()) else: for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) return results def get_output_data(obj, input_data_all): results = [] uis = [] return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) for r in return_values: if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) if 'result' in r: results.append(r['result']) else: results.append(r) output = [] if len(results) > 0: # check which outputs need concatenating output_is_list = [False] * len(results[0]) if hasattr(obj, "OUTPUT_IS_LIST"): output_is_list = obj.OUTPUT_IS_LIST # merge node execution results for i, is_list in zip(range(len(results[0])), output_is_list): if is_list: output.append([x for o in results for x in o[i]]) else: output.append([o[i] for o in results]) ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui def format_value(x): if x is None: return None elif isinstance(x, (int, float, bool, str)): return x else: return str(x) def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: return (True, None, None) for x in inputs: input_data = inputs[x] if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) if result[0] is not True: # Another node failed further upstream return result input_data_all = None try: input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = object_storage.get((unique_id, class_type), None) if obj is None: obj = class_def() object_storage[(unique_id, class_type)] = obj output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") # skip formatting inputs/outputs error_details = { "node_id": unique_id, } return (False, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) input_data_formatted = {} if input_data_all is not None: input_data_formatted = {} for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] output_data_formatted = {} for node_id, node_outputs in outputs.items(): output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] logging.error("!!! Exception during processing !!!") logging.error(traceback.format_exc()) error_details = { "node_id": unique_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, "current_outputs": output_data_formatted } return (False, error_details, ex) executed.add(unique_id) return (True, None, None) def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] will_execute = [] if unique_id in outputs: return [] for x in inputs: input_data = inputs[x] if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: will_execute += recursive_will_execute(prompt, outputs, input_unique_id) return will_execute + [unique_id] def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] is_changed_old = '' is_changed = '' to_delete = False if hasattr(class_def, 'IS_CHANGED'): if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: #is_changed = class_def.IS_CHANGED(**input_data_all) is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True else: is_changed = prompt[unique_id]['is_changed'] if unique_id not in outputs: return True if not to_delete: if is_changed != is_changed_old: to_delete = True elif unique_id not in old_prompt: to_delete = True elif inputs == old_prompt[unique_id]['inputs']: for x in inputs: input_data = inputs[x] if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id in outputs: to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) else: to_delete = True if to_delete: break else: to_delete = True if to_delete: d = outputs.pop(unique_id) del d return to_delete class PromptExecutor: def __init__(self, server): self.server = server self.reset() def reset(self): self.outputs = {} self.object_storage = {} self.outputs_ui = {} self.status_messages = [] self.success = True self.old_prompt = {} def add_message(self, event, data, broadcast: bool): self.status_messages.append((event, data)) if self.server.client_id is not None or broadcast: self.server.send_sync(event, data, self.server.client_id) def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] class_type = prompt[node_id]["class_type"] # First, send back the status to the frontend depending # on the exception type if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, "node_id": node_id, "node_type": class_type, "executed": list(executed), } self.add_message("execution_interrupted", mes, broadcast=True) else: mes = { "prompt_id": prompt_id, "node_id": node_id, "node_type": class_type, "executed": list(executed), "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], "current_outputs": error["current_outputs"], } self.add_message("execution_error", mes, broadcast=False) # Next, remove the subsequent outputs since they will not be executed to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): to_delete += [o] if o in self.old_prompt: d = self.old_prompt.pop(o) del d for o in to_delete: d = self.outputs.pop(o) del d def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: self.server.client_id = extra_data["client_id"] else: self.server.client_id = None self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] for o in self.outputs: if o not in prompt: to_delete += [o] for o in to_delete: d = self.outputs.pop(o) del d to_delete = [] for o in self.object_storage: if o[0] not in prompt: to_delete += [o] else: p = prompt[o[0]] if o[1] != p['class_type']: to_delete += [o] for o in to_delete: d = self.object_storage.pop(o) del d for x in prompt: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) for x in list(self.outputs_ui.keys()): if x not in current_outputs: d = self.outputs_ui.pop(x) del d comfy.model_management.cleanup_models() self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) executed = set() output_node_id = None to_execute = [] for node_id in list(execute_outputs): to_execute += [(0, node_id)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) output_node_id = to_execute.pop(0)[-1] # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) if self.success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: return validated[unique_id] inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] errors = [] valid = True validate_function_inputs = [] if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args for x in required_inputs: if x not in inputs: error = { "type": "required_input_missing", "message": "Required input is missing", "details": f"{x}", "extra_info": { "input_name": x } } errors.append(error) continue val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: error = { "type": "bad_linked_input", "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val } } errors.append(error) continue o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: received_type = r[val[1]] details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", "details": details, "extra_info": { "input_name": x, "input_config": info, "received_type": received_type, "linked_node": val } } errors.append(error) continue try: r = validate_inputs(prompt, o_id, validated) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False continue except Exception as ex: typ, _, tb = sys.exc_info() valid = False exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_inner_validation", "message": "Exception when validating inner node", "details": str(ex), "extra_info": { "input_name": x, "input_config": info, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "linked_node": val } }] validated[o_id] = (False, reasons, o_id) continue else: try: if type_input == "INT": val = int(val) inputs[x] = val if type_input == "FLOAT": val = float(val) inputs[x] = val if type_input == "STRING": val = str(val) inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", "message": f"Failed to convert an input value to a {type_input} value", "details": f"{x}, {val}, {ex}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, "exception_message": str(ex) } } errors.append(error) continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: error = { "type": "value_smaller_than_min", "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, } } errors.append(error) continue if "max" in info[1] and val > info[1]["max"]: error = { "type": "value_bigger_than_max", "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, } } errors.append(error) continue if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info list_info = "" # Don't send back gigantic lists like if they're lots of # scanned model filepaths if len(type_input) > 20: list_info = f"(list of length {len(type_input)})" input_config = None else: list_info = str(type_input) error = { "type": "value_not_in_list", "message": "Value not in list", "details": f"{x}: '{val}' not in {list_info}", "extra_info": { "input_name": x, "input_config": input_config, "received_value": val, } } errors.append(error) continue if len(validate_function_inputs) > 0: input_data_all = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs: input_filtered[x] = input_data_all[x] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") for x in input_filtered: for i, r in enumerate(ret): if r is not True: details = f"{x}" if r is not False: details += f" - {str(r)}" error = { "type": "custom_validation_failed", "message": "Custom validation failed for node", "details": details, "extra_info": { "input_name": x, "input_config": info, "received_value": val, } } errors.append(error) continue if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: ret = (True, [], unique_id) validated[unique_id] = ret return ret def full_type_name(klass): module = klass.__module__ if module == 'builtins': return klass.__qualname__ return module + '.' + klass.__qualname__ def validate_prompt(prompt): outputs = set() for x in prompt: class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True: outputs.add(x) if len(outputs) == 0: error = { "type": "prompt_no_outputs", "message": "Prompt has no outputs", "details": "", "extra_info": {} } return (False, error, [], []) good_outputs = set() errors = [] node_errors = {} validated = {} for o in outputs: valid = False reasons = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] reasons = m[1] except Exception as ex: typ, _, tb = sys.exc_info() valid = False exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", "details": str(ex), "extra_info": { "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] validated[o] = (False, reasons, o) if valid is True: good_outputs.add(o) else: logging.error(f"Failed to validate prompt for output {o}:") if len(reasons) > 0: logging.error("* (prompt):") for reason in reasons: logging.error(f" - {reason['message']}: {reason['details']}") errors += [(o, reasons)] for node_id, result in validated.items(): valid = result[0] reasons = result[1] # If a node upstream has errors, the nodes downstream will also # be reported as invalid, but there will be no errors attached. # So don't return those nodes as having errors in the response. if valid is not True and len(reasons) > 0: if node_id not in node_errors: class_type = prompt[node_id]['class_type'] node_errors[node_id] = { "errors": reasons, "dependent_outputs": [], "class_type": class_type } logging.error(f"* {class_type} {node_id}:") for reason in reasons: logging.error(f" - {reason['message']}: {reason['details']}") node_errors[node_id]["dependent_outputs"].append(o) logging.error("Output will be ignored") if len(good_outputs) == 0: errors_list = [] for o, errors in errors: for error in errors: errors_list.append(f"{error['message']}: {error['details']}") errors_list = "\n".join(errors_list) error = { "type": "prompt_outputs_failed_validation", "message": "Prompt outputs failed validation", "details": errors_list, "extra_info": {} } return (False, error, list(good_outputs), node_errors) return (True, None, list(good_outputs), node_errors) MAXIMUM_HISTORY_SIZE = 10000 class PromptQueue: def __init__(self, server): self.server = server self.mutex = threading.RLock() self.not_empty = threading.Condition(self.mutex) self.task_counter = 0 self.queue = [] self.currently_running = {} self.history = {} self.flags = {} server.prompt_queue = self def put(self, item): with self.mutex: heapq.heappush(self.queue, item) self.server.queue_updated() self.not_empty.notify() def get(self, timeout=None): with self.not_empty: while len(self.queue) == 0: self.not_empty.wait(timeout=timeout) if timeout is not None and len(self.queue) == 0: return None item = heapq.heappop(self.queue) i = self.task_counter self.currently_running[i] = copy.deepcopy(item) self.task_counter += 1 self.server.queue_updated() return (item, i) class ExecutionStatus(NamedTuple): status_str: Literal['success', 'error'] completed: bool messages: List[str] def task_done(self, item_id, outputs, status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) status_dict: Optional[dict] = None if status is not None: status_dict = copy.deepcopy(status._asdict()) self.history[prompt[1]] = { "prompt": prompt, "outputs": copy.deepcopy(outputs), 'status': status_dict, } self.server.queue_updated() def get_current_queue(self): with self.mutex: out = [] for x in self.currently_running.values(): out += [x] return (out, copy.deepcopy(self.queue)) def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) def wipe_queue(self): with self.mutex: self.queue = [] self.server.queue_updated() def delete_queue_item(self, function): with self.mutex: for x in range(len(self.queue)): if function(self.queue[x]): if len(self.queue) == 1: self.wipe_queue() else: self.queue.pop(x) heapq.heapify(self.queue) self.server.queue_updated() return True return False def get_history(self, prompt_id=None, max_items=None, offset=-1): with self.mutex: if prompt_id is None: out = {} i = 0 if offset < 0 and max_items is not None: offset = len(self.history) - max_items for k in self.history: if i >= offset: out[k] = self.history[k] if max_items is not None and len(out) >= max_items: break i += 1 return out elif prompt_id in self.history: return {prompt_id: copy.deepcopy(self.history[prompt_id])} else: return {} def wipe_history(self): with self.mutex: self.history = {} def delete_history_item(self, id_to_delete): with self.mutex: self.history.pop(id_to_delete, None) def set_flag(self, name, data): with self.mutex: self.flags[name] = data self.not_empty.notify() def get_flags(self, reset=True): with self.mutex: if reset: ret = self.flags self.flags = {} return ret else: return self.flags.copy()