Add a CLIPSave node to save CLIP model weights.

This commit is contained in:
comfyanonymous 2023-10-10 01:24:49 -04:00
parent d1a0abd40b
commit 877553843f

View File

@ -179,6 +179,62 @@ class CheckpointSave:
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {} return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
for x in k:
current_clip_sd[x] = clip_sd.pop(x)
if len(current_clip_sd) == 0:
continue
p = prefix[:-1]
replace_prefix = {}
filename_prefix_ = filename_prefix
if len(p) > 0:
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
class VAESave: class VAESave:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -220,5 +276,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd, "ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave, "CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple, "CLIPMergeSimple": CLIPMergeSimple,
"CLIPSave": CLIPSave,
"VAESave": VAESave, "VAESave": VAESave,
} }