mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
feat:Latent Save/Load (#662)
* wip * latent dir * fix * fix * now working * mark todo * remove server.py changes to separate PRt --------- Co-authored-by: Lt.Dr.Data <lt.dr.data@gmail.com>
This commit is contained in:
parent
4088e61aa6
commit
e7f2816c6f
0
input/latents/_input_latents_will_be_put_here
Normal file
0
input/latents/_input_latents_will_be_put_here
Normal file
90
nodes.py
90
nodes.py
@ -29,6 +29,8 @@ import importlib
|
|||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
|
import safetensors.torch as sft
|
||||||
|
|
||||||
def before_node_execution():
|
def before_node_execution():
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
@ -246,6 +248,91 @@ class VAEEncodeForInpaint:
|
|||||||
|
|
||||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLatent:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents")
|
||||||
|
self.type = "output"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT", ),
|
||||||
|
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
|
def map_filename(filename):
|
||||||
|
prefix_len = len(os.path.basename(filename_prefix))
|
||||||
|
prefix = filename[:prefix_len + 1]
|
||||||
|
try:
|
||||||
|
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||||
|
except:
|
||||||
|
digits = 0
|
||||||
|
return (digits, prefix)
|
||||||
|
|
||||||
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|
||||||
|
full_output_folder = os.path.join(self.output_dir, subfolder)
|
||||||
|
|
||||||
|
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
|
||||||
|
print("Saving latent outside the 'input/latents' folder is not allowed.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||||
|
except ValueError:
|
||||||
|
counter = 1
|
||||||
|
except FileNotFoundError:
|
||||||
|
os.makedirs(full_output_folder, exist_ok=True)
|
||||||
|
counter = 1
|
||||||
|
|
||||||
|
# support save metadata for latent sharing
|
||||||
|
prompt_info = ""
|
||||||
|
if prompt is not None:
|
||||||
|
prompt_info = json.dumps(prompt)
|
||||||
|
|
||||||
|
metadata = {"workflow": prompt_info}
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
|
file = f"{filename}_{counter:05}_.latent"
|
||||||
|
file = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
|
sft.save_file(samples, file, metadata=metadata)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLatent:
|
||||||
|
input_dir = os.path.join(folder_paths.get_input_directory(), "latents")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")]
|
||||||
|
return {"required": {"latent": [sorted(files), ]}, }
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", )
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(self, latent):
|
||||||
|
file = folder_paths.get_annotated_filepath(latent, self.input_dir)
|
||||||
|
|
||||||
|
latent = sft.load_file(file, device="cpu")
|
||||||
|
|
||||||
|
return (latent, )
|
||||||
|
|
||||||
|
|
||||||
class CheckpointLoader:
|
class CheckpointLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1235,6 +1322,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
"CheckpointLoader": CheckpointLoader,
|
"CheckpointLoader": CheckpointLoader,
|
||||||
"DiffusersLoader": DiffusersLoader,
|
"DiffusersLoader": DiffusersLoader,
|
||||||
|
|
||||||
|
"LoadLatent": LoadLatent,
|
||||||
|
"SaveLatent": SaveLatent
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
Loading…
Reference in New Issue
Block a user