mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
9b93b920be
The created checkpoints contain workflow metadata that can be loaded by dragging them onto the UI or by loading them with the "Load" button. Checkpoints are saved in fp16 or fp32, depending on the format ComfyUI uses for inference on your hardware. To force fp32, use `--force-fp32`. Anything that patches the model weights, such as merging or LoRAs, will be saved. The output directory is currently set to `output/checkpoints`, but that might change in the future.
96 lines
3.3 KiB
Python
96 lines
3.3 KiB
Python
import comfy.sd
|
|
import comfy.utils
|
|
import folder_paths
|
|
import json
|
|
import os
|
|
|
|
class ModelMergeSimple:
    """Merge two models' diffusion weights using one global ratio.

    Each tensor of ``model2``'s ``diffusion_model.`` state dict is added to a
    clone of ``model1`` as a patch with strength ``1.0 - ratio``; ``ratio``
    itself is passed as the second strength argument of ``add_patches``
    (presumably the weight kept for ``model1``'s own tensor -- confirm
    against the model patcher implementation).
    """

    @classmethod
    def INPUT_TYPES(s):
        # Slider for the blend ratio: 0.0 .. 1.0 in steps of 0.01.
        ratio_widget = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return {
            "required": {
                "model1": ("MODEL",),
                "model2": ("MODEL",),
                "ratio": ratio_widget,
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "_for_testing/model_merging"

    def merge(self, model1, model2, ratio):
        """Return ``(clone,)``: a clone of ``model1`` patched with ``model2``.

        Args:
            model1: Base model; only its clone is patched.
            model2: Model whose ``diffusion_model.`` weights are blended in.
            ratio: Passed as the second strength to ``add_patches``; the
                ``model2`` tensor gets strength ``1.0 - ratio``.
        """
        merged = model1.clone()
        donor = model2.model_state_dict("diffusion_model.")
        donor_strength = 1.0 - ratio
        for name in donor:
            merged.add_patches({name: (donor[name],)}, donor_strength, ratio)
        return (merged,)
|
|
|
|
class ModelMergeBlocks:
    """Merge ``model2`` into ``model1`` with a separate ratio per UNet section.

    Every keyword argument of ``merge`` is a key prefix (relative to the
    ``diffusion_model.`` namespace) mapped to a ratio. The node exposes the
    prefixes ``input``, ``middle`` and ``out``; presumably these cover keys
    such as ``input_blocks.*``, ``middle_block.*`` and ``output_blocks.*`` /
    ``out.*`` -- confirm against the UNet's state-dict naming.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model1": ("MODEL",),
                             "model2": ("MODEL",),
                             "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "_for_testing/model_merging"

    def merge(self, model1, model2, **kwargs):
        """Return a clone of ``model1`` patched with ``model2``'s diffusion weights.

        Args:
            model1: Base model; it is cloned and only the clone is patched.
            model2: Model whose ``diffusion_model.`` state dict is merged in.
            **kwargs: Key prefix -> ratio. A key takes the ratio of the
                *longest* prefix it starts with; keys matching no prefix use
                the first ratio passed. Each key is patched with strength
                ``1.0 - ratio`` for ``model2``'s tensor and ``ratio`` as the
                second strength argument of ``add_patches``.

        Returns:
            A 1-tuple containing the patched clone.
        """
        prefix = "diffusion_model."
        m = model1.clone()
        sd = model2.model_state_dict(prefix)
        # Fallback ratio for keys that match no prefix. Raises StopIteration
        # if no ratios are given; the node always supplies them via INPUT_TYPES.
        default_ratio = next(iter(kwargs.values()))

        for k in sd:
            k_unet = k[len(prefix):]
            ratio = default_ratio
            # Longest matching prefix wins, so a specific prefix such as
            # "input_blocks.10" is never overridden by a shorter prefix
            # ("input_blocks.1") that merely comes later in kwargs.
            best_len = -1
            for arg, arg_ratio in kwargs.items():
                if len(arg) > best_len and k_unet.startswith(arg):
                    ratio = arg_ratio
                    best_len = len(arg)

            m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
        return (m, )
|
|
|
|
class CheckpointSave:
    """Output node that saves model, CLIP and VAE as one ``.safetensors`` checkpoint.

    The workflow ``prompt`` and every ``extra_pnginfo`` entry are stored in the
    checkpoint metadata as JSON-encoded strings.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        # Hidden inputs carry the workflow so it can be embedded as metadata.
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "_for_testing/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Write the checkpoint to ``<folder>/<name>_<counter:05>_.safetensors``.

        Returns an empty dict (this node produces no outputs).
        """
        save_info = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        target_folder, base_name, index = save_info[0], save_info[1], save_info[2]

        # "prompt" is always present; an empty string stands in for no prompt.
        metadata = {"prompt": "" if prompt is None else json.dumps(prompt)}
        if extra_pnginfo is not None:
            metadata.update({key: json.dumps(extra_pnginfo[key]) for key in extra_pnginfo})

        checkpoint_path = os.path.join(target_folder, f"{base_name}_{index:05}_.safetensors")
        comfy.sd.save_checkpoint(checkpoint_path, model, clip, vae, metadata=metadata)
        return {}
|
|
|
|
|
|
# Node type name -> implementing class for the nodes defined in this module
# (presumably read by ComfyUI's node loader -- confirm against the loader).
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSimple=ModelMergeSimple,
    ModelMergeBlocks=ModelMergeBlocks,
    CheckpointSave=CheckpointSave,
)
|