diff --git a/nodes.py b/nodes.py index aaec64f8..7e737d0c 100644 --- a/nodes.py +++ b/nodes.py @@ -17,9 +17,11 @@ import comfy.samplers import comfy.sd supported_ckpt_extensions = ['.ckpt'] +supported_pt_extensions = ['.ckpt', '.pt'] try: import safetensors.torch supported_ckpt_extensions += ['.safetensors'] + supported_pt_extensions += ['.safetensors'] except: print("Could not import safetensors, safetensors support disabled.") @@ -132,7 +134,7 @@ class VAELoader: vae_dir = os.path.join(models_dir, "vae") @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (filter_files_extensions(os.listdir(s.vae_dir), supported_ckpt_extensions), )}} + return {"required": { "vae_name": (filter_files_extensions(os.listdir(s.vae_dir), supported_pt_extensions), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae"