mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
load_checkpoint_guess_config can now optionally output the model.
This commit is contained in:
parent
5b828258f1
commit
72188dffc3
21
comfy/sd.py
21
comfy/sd.py
@ -394,13 +394,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
clip_target = None
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
@ -421,10 +422,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if fp16:
|
||||
dtype = torch.float16
|
||||
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
vae = VAE()
|
||||
@ -444,10 +446,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if len(left_over) > 0:
|
||||
print("left over keys:", left_over)
|
||||
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user