diff --git a/comfy/sd.py b/comfy/sd.py
index f186273e..cfd6fb3c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -394,13 +394,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
     sd = comfy.utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
     clipvision = None
     vae = None
     model = None
+    model_patcher = None
     clip_target = None
 
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
@@ -421,10 +422,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if fp16:
         dtype = torch.float16
 
-    inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
-    offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
-    model.load_model_weights(sd, "model.diffusion_model.")
+    if output_model:
+        inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
+        offload_device = model_management.unet_offload_device()
+        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
+        model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
         vae = VAE()
@@ -444,10 +446,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)
 
-    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
-    if inital_load_device != torch.device("cpu"):
-        print("loaded straight to GPU")
-        model_management.load_model_gpu(model_patcher)
+    if output_model:
+        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        if inital_load_device != torch.device("cpu"):
+            print("loaded straight to GPU")
+            model_management.load_model_gpu(model_patcher)
 
     return (model_patcher, clip, vae, clipvision)
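
For reference, a minimal usage sketch of the new output_model flag, assuming the ComfyUI repo layout (comfy.sd is importable) and a placeholder checkpoint path. With output_model=False the diffusion model is never instantiated or loaded to a device, so the first element of the returned tuple stays None; only the CLIP and VAE are built.

    import comfy.sd

    # Placeholder path -- substitute a real .safetensors/.ckpt checkpoint.
    ckpt = "checkpoints/model.safetensors"

    # Skip building the UNet entirely; only extract CLIP and VAE.
    # Per this patch, model and model_patcher remain None in this case.
    model_patcher, clip, vae, clipvision = comfy.sd.load_checkpoint_guess_config(
        ckpt,
        output_vae=True,
        output_clip=True,
        output_model=False,
    )

    assert model_patcher is None  # UNet weights were never loaded

Initializing model_patcher = None up front is what keeps the final return statement valid on this new path, since the ModelPatcher construction and the straight-to-GPU load are now both gated behind output_model.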