mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
load_checkpoint_guess_config can now optionally output the model.
This commit is contained in:
parent
5b828258f1
commit
72188dffc3
@ -394,13 +394,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||||
sd_keys = sd.keys()
|
sd_keys = sd.keys()
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
vae = None
|
vae = None
|
||||||
model = None
|
model = None
|
||||||
|
model_patcher = None
|
||||||
clip_target = None
|
clip_target = None
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||||
@ -421,6 +422,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
if fp16:
|
if fp16:
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
|
||||||
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||||
@ -444,6 +446,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
if len(left_over) > 0:
|
if len(left_over) > 0:
|
||||||
print("left over keys:", left_over)
|
print("left over keys:", left_over)
|
||||||
|
|
||||||
|
if output_model:
|
||||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
print("loaded straight to GPU")
|
print("loaded straight to GPU")
|
||||||
|
Loading…
Reference in New Issue
Block a user