load_checkpoint_guess_config can now optionally output the model.

comfyanonymous 2023-10-06 13:48:18 -04:00
parent 5b828258f1
commit 72188dffc3


@@ -394,13 +394,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
     sd = comfy.utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
     clipvision = None
     vae = None
     model = None
+    model_patcher = None
     clip_target = None
 
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
@@ -421,6 +422,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if fp16:
         dtype = torch.float16
-    inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
-    offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
+    if output_model:
+        inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
+        offload_device = model_management.unet_offload_device()
+        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
@@ -444,6 +446,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)
-    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
-    if inital_load_device != torch.device("cpu"):
-        print("loaded straight to GPU")
+    if output_model:
+        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        if inital_load_device != torch.device("cpu"):
+            print("loaded straight to GPU")