From 9b93b920bee8a390a4242326cf6380d77f83e8de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 26 Jun 2023 12:21:07 -0400 Subject: [PATCH] Add CheckpointSave node to save checkpoints. The created checkpoints contain workflow metadata that can be loaded by dragging them on top of the UI or loading them with the "Load" button. Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI is using for inference on your hardware. To force fp32 use: --force-fp32 Anything that patches the model weights like merging or loras will be saved. The output directory is currently set to: output/checkpoints but that might change in the future. --- comfy/diffusers_convert.py | 4 ++- comfy/model_base.py | 12 ++++++++ comfy/sd.py | 32 +++++++++++++++++++-- comfy/supported_models.py | 29 +++++++++++++++++++ comfy/supported_models_base.py | 12 ++++++++ comfy/utils.py | 14 ++++++++- comfy_extras/nodes_model_merging.py | 44 +++++++++++++++++++++++++++-- nodes.py | 3 +- notebooks/comfyui_colab.ipynb | 1 + web/scripts/app.js | 2 +- web/scripts/pnginfo.js | 5 ++-- web/scripts/ui.js | 2 +- 12 files changed, 147 insertions(+), 13 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 1eab54d4..9688cbd5 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys())) code2idx = {"q": 0, "k": 1, "v": 2} -def convert_text_enc_state_dict_v20(text_enc_dict): +def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): new_state_dict = {} capture_qkv_weight = {} capture_qkv_bias = {} for k, v in text_enc_dict.items(): + if not k.startswith(prefix): + continue if ( k.endswith(".self_attn.q_proj.weight") or k.endswith(".self_attn.k_proj.weight") diff --git a/comfy/model_base.py b/comfy/model_base.py index 923c4348..e4c9391d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import numpy as np +from . import utils class BaseModel(torch.nn.Module): def __init__(self, model_config, v_prediction=False): @@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format + self.model_config = model_config self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.diffusion_model = UNetModel(**unet_config) self.v_prediction = v_prediction @@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) + def state_dict_for_saving(self, clip_state_dict, vae_state_dict): + clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) + unet_state_dict = self.diffusion_model.state_dict() + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) + vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) + if self.get_dtype() == torch.float16: + clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) + vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + class SD21UNCLIP(BaseModel): def __init__(self, model_config, noise_aug_config, v_prediction=True): diff --git a/comfy/sd.py b/comfy/sd.py index dbfbdbe3..21d7b8a5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -545,11 +545,11 @@ class CLIP: if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) try: - self.patcher.patch_model() + self.patch_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) - self.patcher.unpatch_model() + self.unpatch_model() except Exception as e: - self.patcher.unpatch_model() + self.unpatch_model() raise e cond_out = cond @@ -564,6 +564,15 @@ class CLIP: def load_sd(self, sd): return self.cond_stage_model.load_sd(sd) + def get_sd(self): + return self.cond_stage_model.state_dict() + + def patch_model(self): + self.patcher.patch_model() + + def unpatch_model(self): + self.patcher.unpatch_model() + class VAE: def __init__(self, ckpt_path=None, device=None, config=None): if config is None: @@ -665,6 +674,10 @@ class VAE: self.first_stage_model = self.first_stage_model.cpu() return samples + def get_sd(self): + return self.first_stage_model.state_dict() + + def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] #print(current_batch_size, target_batch_size) @@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("left over keys:", left_over) return (ModelPatcher(model), clip, vae, clipvision) + +def save_checkpoint(output_path, model, clip, vae, metadata=None): + try: + model.patch_model() + clip.patch_model() + sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) + utils.save_torch_file(sd, output_path, metadata=metadata) + model.unpatch_model() + clip.unpatch_model() + except Exception as e: + model.unpatch_model() + clip.unpatch_model() + raise e diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 51da9456..6b17b089 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -9,6 +9,8 @@ from . import sdxl_clip from . import supported_models_base from . import latent_formats +from . import diffusers_convert + class SD15(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE): state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {} + replace_prefix[""] = "cond_stage_model.model." + state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) + return state_dict + def clip_target(self): return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) @@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE): state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {} + state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") + replace_prefix["clip_g"] = "conditioner.embedders.0.model" + state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + return state_dict_g + def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) @@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE): state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {} + keys_to_replace = {} + state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") + for k in state_dict: + if k.startswith("clip_l"): + state_dict_g[k] = state_dict[k] + + replace_prefix["clip_g"] = "conditioner.embedders.1.model" + replace_prefix["clip_l"] = "conditioner.embedders.0" + state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + return state_dict_g + def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3312a99d..0b0235ca 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -64,3 +64,15 @@ class BASE: def process_clip_state_dict(self, state_dict): return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "cond_stage_model."} + return state_dict_prefix_replace(state_dict, replace_prefix) + + def process_unet_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "model.diffusion_model."} + return state_dict_prefix_replace(state_dict, replace_prefix) + + def process_vae_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "first_stage_model."} + return state_dict_prefix_replace(state_dict, replace_prefix) + diff --git a/comfy/utils.py b/comfy/utils.py index 7a7f1fa1..b6434905 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -2,10 +2,10 @@ import torch import math import struct import comfy.checkpoint_pickle +import safetensors.torch def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): - import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: if safe_load: @@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False): sd = pl_sd return sd +def save_torch_file(sd, ckpt, metadata=None): + if metadata is not None: + safetensors.torch.save_file(sd, ckpt, metadata=metadata) + else: + safetensors.torch.save_file(sd, ckpt) + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", @@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def convert_sd_to(state_dict, dtype): + keys = list(state_dict.keys()) + for k in keys: + state_dict[k] = state_dict[k].to(dtype) + return state_dict + def safetensors_header(safetensors_path, max_size=100*1024*1024): with open(safetensors_path, "rb") as f: header = f.read(8) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 52b73f70..4f71b203 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -1,4 +1,8 @@ - +import comfy.sd +import comfy.utils +import folder_paths +import json +import os class ModelMergeSimple: @classmethod @@ -49,7 +53,43 @@ class ModelMergeBlocks: m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) return (m, ) +class CheckpointSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip": ("CLIP",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "_for_testing/model_merging" + + def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"prompt": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) + return {} + + NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, - "ModelMergeBlocks": ModelMergeBlocks + "ModelMergeBlocks": ModelMergeBlocks, + "CheckpointSave": CheckpointSave, } diff --git a/nodes.py b/nodes.py index 7280d788..3c096009 100644 --- a/nodes.py +++ b/nodes.py @@ -286,8 +286,7 @@ class SaveLatent: output["latent_tensor"] = samples["samples"] output["latent_format_version_0"] = torch.tensor([]) - safetensors.torch.save_file(output, file, metadata=metadata) - + comfy.utils.save_torch_file(output, file, metadata=metadata) return {} diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c5a209ee..61c277bf 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -144,6 +144,7 @@ "\n", "\n", "# ESRGAN upscale model\n", + "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", "\n", diff --git a/web/scripts/app.js b/web/scripts/app.js index 4e83c40a..09310c7f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1468,7 +1468,7 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); - } else if (file.name?.endsWith(".latent")) { + } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) { const info = await getLatentMetadata(file); if (info.workflow) { this.loadGraphData(JSON.parse(info.workflow)); diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 977b5ac2..c5293dfa 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -55,11 +55,12 @@ export function getLatentMetadata(file) { const dataView = new DataView(safetensorsData.buffer); let header_size = dataView.getUint32(0, true); let offset = 8; - let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size))); r(header.__metadata__); }; - reader.readAsArrayBuffer(file); + var slice = file.slice(0, 1024 * 1024 * 4); + reader.readAsArrayBuffer(slice); }); } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 99e9123a..12fda127 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -545,7 +545,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png,.latent", + accept: ".json,image/png,.latent,.safetensors", style: {display: "none"}, parent: document.body, onchange: () => {