mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Add a function to load a unet from a state dict.
This commit is contained in:
parent
f30b992b18
commit
c45d1b9b67
14
comfy/sd.py
14
comfy/sd.py
@ -481,20 +481,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_unet(unet_path): #load unet in diffusers format
|
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
if "input_blocks.0.0.weight" in sd: #ldm
|
if "input_blocks.0.0.weight" in sd: #ldm
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
return None
|
||||||
new_sd = sd
|
new_sd = sd
|
||||||
|
|
||||||
else: #diffusers
|
else: #diffusers
|
||||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
||||||
@ -514,6 +512,14 @@ def load_unet(unet_path): #load unet in diffusers format
|
|||||||
print("left over keys in unet:", left_over)
|
print("left over keys in unet:", left_over)
|
||||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||||
|
|
||||||
|
def load_unet(unet_path):
|
||||||
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
|
model = load_unet_state_dict(sd)
|
||||||
|
if model is None:
|
||||||
|
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||||
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
|
return model
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||||
model_management.load_models_gpu([model, clip.load_model()])
|
model_management.load_models_gpu([model, clip.load_model()])
|
||||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||||
|
Loading…
Reference in New Issue
Block a user