From 58ed0f2da438aaf253f9880578d694ad917819f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Apr 2023 01:28:15 -0400 Subject: [PATCH] Fix loading SD1.5 diffusers checkpoint. --- comfy/diffusers_convert.py | 4 +++- nodes.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index a31c1c11..950137f2 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -272,7 +272,8 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb # magic v2 = diffusers_unet_conf["sample_size"] == 96 - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' if v2: if v_pred: @@ -290,6 +291,7 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb scale_factor = model_config_params['scale_factor'] vae_config = model_config_params['first_stage_config'] vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") diff --git a/nodes.py b/nodes.py index 274ae2f1..025e4fcb 100644 --- a/nodes.py +++ b/nodes.py @@ -231,7 +231,7 @@ class DiffusersLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "advanced/loaders" def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): @@ -240,7 +240,7 @@ class DiffusersLoader: if model_path in paths: model_path = os.path.join(search_path, model_path) break - search_paths = folder_paths.get_folder_paths("diffusers") + return load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))