diff --git a/main.py b/main.py index 9983859b..8a4b761c 100644 --- a/main.py +++ b/main.py @@ -59,6 +59,23 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}): outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) return executed + [unique_id] +def recursive_will_execute(prompt, outputs, current_item): + unique_id = current_item + inputs = prompt[unique_id]['inputs'] + will_execute = [] + if unique_id in outputs: + return [] + + for x in inputs: + input_data = inputs[x] + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id not in outputs: + will_execute += recursive_will_execute(prompt, outputs, input_unique_id) + + return will_execute + [unique_id] + def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -120,7 +137,16 @@ class PromptExecutor: current_outputs = set(self.outputs.keys()) executed = [] try: + to_execute = [] for x in prompt: + class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] + if hasattr(class_, 'OUTPUT_NODE'): + to_execute += [(0, x)] + + while len(to_execute) > 0: + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + x = to_execute.pop(0)[-1] + class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] if hasattr(class_, 'OUTPUT_NODE'): if class_.OUTPUT_NODE == True: @@ -132,6 +158,7 @@ class PromptExecutor: valid = False if valid: executed += recursive_execute(prompt, self.outputs, x, extra_data) + except Exception as e: print(traceback.format_exc()) to_delete = []