Implement modelspec metadata in CheckpointSave for SDXL and refiner.

This commit is contained in:
comfyanonymous 2023-07-25 22:02:26 -04:00
parent 727588d076
commit 5e3ac1928a

View File

@ -1,5 +1,7 @@
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
import comfy.model_base
import folder_paths import folder_paths
import json import json
import os import os
@ -100,6 +102,31 @@ class CheckpointSave:
prompt_info = json.dumps(prompt) prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info} metadata = {"prompt": prompt_info}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if extra_pnginfo is not None: if extra_pnginfo is not None:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x]) metadata[x] = json.dumps(extra_pnginfo[x])