Preparing to add another function to load checkpoints.

This commit is contained in:
comfyanonymous 2023-02-26 17:04:35 -05:00
parent 3cd7d84b53
commit 9f4214e534

View File

@ -26,12 +26,7 @@ def load_torch_file(ckpt):
sd = pl_sd
return sd
def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
print(f"Loading model from {ckpt}")
sd = load_torch_file(ckpt)
model = instantiate_from_config(config.model)
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False)
k = list(sd.keys())
@ -654,5 +649,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
model = instantiate_from_config(config.model)
sd = load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
return (ModelPatcher(model), clip, vae)