Persist node instances between executions instead of deleting them.

If the same node id with the same class exists between two executions the
same instance will be used.

This means you can now cache things in nodes for more efficiency.
This commit is contained in:
comfyanonymous 2023-06-29 23:38:56 -04:00
parent 9920367d3c
commit 6e9f28401f

View File

@ -110,7 +110,7 @@ def format_value(x):
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
@ -125,7 +125,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result
@ -136,7 +136,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
obj = class_def()
object_storage[(unique_id, class_type)] = obj
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
@ -256,6 +260,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.old_prompt = {}
self.server = server
@ -322,6 +327,17 @@ class PromptExecutor:
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
@ -349,7 +365,7 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break