load_checkpoint_guess_config can now optionally output the model.

This commit is contained in:
comfyanonymous 2023-10-06 13:48:18 -04:00
parent 5b828258f1
commit 72188dffc3

View File

@ -394,13 +394,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
sd = comfy.utils.load_torch_file(ckpt_path) sd = comfy.utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() sd_keys = sd.keys()
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
model = None model = None
model_patcher = None
clip_target = None clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
@ -421,10 +422,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if fp16: if fp16:
dtype = torch.float16 dtype = torch.float16
inital_load_device = model_management.unet_inital_load_device(parameters, dtype) if output_model:
offload_device = model_management.unet_offload_device() inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) offload_device = model_management.unet_offload_device()
model.load_model_weights(sd, "model.diffusion_model.") model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae = VAE() vae = VAE()
@ -444,10 +446,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0: if len(left_over) > 0:
print("left over keys:", left_over) print("left over keys:", left_over)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if output_model:
if inital_load_device != torch.device("cpu"): model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
print("loaded straight to GPU") if inital_load_device != torch.device("cpu"):
model_management.load_model_gpu(model_patcher) print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)