mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Support properly saving CosXL checkpoints.
This commit is contained in:
parent
d644b6bcd8
commit
30abc324c2
@ -600,7 +600,7 @@ def load_unet(unet_path):
|
|||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||||
clip_sd = None
|
clip_sd = None
|
||||||
load_models = [model]
|
load_models = [model]
|
||||||
if clip is not None:
|
if clip is not None:
|
||||||
@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||||||
model_management.load_models_gpu(load_models)
|
model_management.load_models_gpu(load_models)
|
||||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||||
|
for k in extra_keys:
|
||||||
|
sd[k] = extra_keys[k]
|
||||||
|
|
||||||
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
|
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||||
|
@ -2,7 +2,9 @@ import comfy.sd
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_base
|
import comfy.model_base
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.model_sampling
|
||||||
|
|
||||||
|
import torch
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||||
# "v2-inpainting"
|
# "v2-inpainting"
|
||||||
|
|
||||||
|
extra_keys = {}
|
||||||
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
|
||||||
|
if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
|
||||||
|
extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
|
||||||
|
extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()
|
||||||
|
|
||||||
if model.model.model_type == comfy.model_base.ModelType.EPS:
|
if model.model.model_type == comfy.model_base.ModelType.EPS:
|
||||||
metadata["modelspec.predict_key"] = "epsilon"
|
metadata["modelspec.predict_key"] = "epsilon"
|
||||||
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
||||||
@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
|
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||||
|
|
||||||
class CheckpointSave:
|
class CheckpointSave:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user