From 9f4214e53413915fcaf7e6146f1e90d302b6a9a2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 26 Feb 2023 17:04:35 -0500 Subject: [PATCH] Preparing to add another function to load checkpoints. --- comfy/sd.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e5baddca..0b55b8b1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -26,12 +26,7 @@ def load_torch_file(ckpt): sd = pl_sd return sd -def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]): - print(f"Loading model from {ckpt}") - - sd = load_torch_file(ckpt) - model = instantiate_from_config(config.model) - +def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) k = list(sd.keys()) @@ -654,5 +649,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] - model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to) + model = instantiate_from_config(config.model) + sd = load_torch_file(ckpt_path) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) return (ModelPatcher(model), clip, vae)