diff --git a/nodes.py b/nodes.py index d3fe5f46..efa9522d 100644 --- a/nodes.py +++ b/nodes.py @@ -26,6 +26,14 @@ try: except: print("Could not import safetensors, safetensors support disabled.") +def recursive_search(directory): + result = [] + for root, subdir, file in os.walk(directory, followlinks=True): + for filepath in file: + #we remove the first character to remove the path separator. + result.append(os.path.join(root, filepath).replace(directory,'')[1:]) + return result + def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) @@ -119,8 +127,8 @@ class CheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "config_name": (filter_files_extensions(os.listdir(s.config_dir), '.yaml'), ), - "ckpt_name": (filter_files_extensions(os.listdir(s.ckpt_dir), supported_ckpt_extensions), )}} + return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ), + "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -138,7 +146,7 @@ class LoraLoader: def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), - "lora_name": (filter_files_extensions(os.listdir(s.lora_dir), supported_pt_extensions), ), + "lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} @@ -157,7 +165,7 @@ class VAELoader: vae_dir = os.path.join(models_dir, "vae") @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (filter_files_extensions(os.listdir(s.vae_dir), supported_pt_extensions), )}} + return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -174,7 +182,7 @@ class CLIPLoader: clip_dir = os.path.join(models_dir, "clip") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(os.listdir(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}), }} RETURN_TYPES = ("CLIP",)