Add IS_CHANGED method to nodes to check if nodes should be executed again.

LoadImage.IS_CHANGED returns the hash of the image so it will execute again
if the image changed on the disk.
This commit is contained in:
comfyanonymous 2023-01-22 21:42:22 -05:00
parent 15f8da2849
commit 9baa48cb33
2 changed files with 51 additions and 25 deletions

67
main.py
View File

@ -10,13 +10,35 @@ import torch
import nodes import nodes
def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = prompt
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
return input_data_all
def recursive_execute(prompt, outputs, current_item, extra_data={}): def recursive_execute(prompt, outputs, current_item, extra_data={}):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
c_obj = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
valid_inputs = c_obj.INPUT_TYPES()
if unique_id in outputs: if unique_id in outputs:
return [] return []
@ -31,28 +53,8 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}):
if input_unique_id not in outputs: if input_unique_id not in outputs:
executed += recursive_execute(prompt, outputs, input_unique_id, extra_data) executed += recursive_execute(prompt, outputs, input_unique_id, extra_data)
input_data_all = {} input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data)
for x in inputs: obj = class_def()
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data
obj = c_obj()
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = prompt
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
return executed + [unique_id] return executed + [unique_id]
@ -61,12 +63,27 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_changed_old = ''
is_changed = ''
if hasattr(class_def, 'IS_CHANGED'):
if 'is_changed' not in prompt[unique_id]:
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
input_data_all = get_input_data(inputs, class_def)
is_changed = class_def.IS_CHANGED(**input_data_all)
prompt[unique_id]['is_changed'] = is_changed
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs: if unique_id not in outputs:
return True return True
to_delete = False to_delete = False
if unique_id not in old_prompt: if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True to_delete = True
elif inputs == old_prompt[unique_id]['inputs']: elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs: for x in inputs:

View File

@ -3,6 +3,7 @@ import torch
import os import os
import sys import sys
import json import json
import hashlib
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -226,6 +227,14 @@ class LoadImage:
image = torch.from_numpy(image[None])[None,] image = torch.from_numpy(image[None])[None,]
return image return image
@classmethod
def IS_CHANGED(s, image):
image_path = os.path.join(s.input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {