diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 965e5e7a..bfa787d3 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -5,14 +5,12 @@ import model_management from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions import torch import comfy.utils +import folder_paths class UpscaleModelLoader: - models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "models") - upscale_model_dir = os.path.join(models_dir, "upscale_models") - @classmethod def INPUT_TYPES(s): - return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir), supported_ckpt_extensions), ), + return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), }} RETURN_TYPES = ("UPSCALE_MODEL",) FUNCTION = "load_model" @@ -20,7 +18,7 @@ class UpscaleModelLoader: CATEGORY = "loaders" def load_model(self, model_name): - model_path = os.path.join(self.upscale_model_dir, model_name) + model_path = folder_paths.get_full_path("upscale_models", model_name) sd = load_torch_file(model_path) out = model_loading.load_state_dict(sd).eval() return (out, ) diff --git a/folder_paths.py b/folder_paths.py new file mode 100644 index 00000000..b2752e15 --- /dev/null +++ b/folder_paths.py @@ -0,0 +1,65 @@ +import os + +supported_ckpt_extensions = set(['.ckpt', '.pth']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) +try: + import safetensors.torch + supported_ckpt_extensions.add('.safetensors') + supported_pt_extensions.add('.safetensors') +except: + print("Could not import safetensors, safetensors support disabled.") + + +folder_names_and_paths = {} + + +models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") +folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) +folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) + +folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) +folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) +folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) +folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) +folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) +# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) + +folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) +folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) + + +def add_model_folder(folder_name, full_folder_path): + global folder_names_and_paths + + +def recursive_search(directory): + result = [] + for root, subdir, file in os.walk(directory, followlinks=True): + for filepath in file: + #we os.path,join directory with a blank string to generate a path separator at the end. + result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) + return result + +def filter_files_extensions(files, extensions): + return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) + + + +def get_full_path(folder_name, filename): + global folder_names_and_paths + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + full_path = os.path.join(x, filename) + if os.path.isfile(full_path): + return full_path + + +def get_filename_list(folder_name): + global folder_names_and_paths + output_list = [] + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + output_list += filter_files_extensions(recursive_search(x), folders[1]) + return output_list + + diff --git a/nodes.py b/nodes.py index 9a878b44..93268240 100644 --- a/nodes.py +++ b/nodes.py @@ -23,6 +23,7 @@ import comfy_extras.clip_vision import model_management import importlib +import folder_paths supported_ckpt_extensions = ['.ckpt', '.pth'] supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] try: @@ -208,31 +209,26 @@ class VAEEncodeForInpaint: class CheckpointLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - config_dir = os.path.join(models_dir, "configs") - ckpt_dir = os.path.join(models_dir, "checkpoints") embedding_directory = os.path.join(models_dir, "embeddings") @classmethod def INPUT_TYPES(s): - return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ), - "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}} + return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "loaders" def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): - config_path = os.path.join(self.config_dir, config_name) - ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) + config_path = folder_paths.get_full_path("configs", config_name) + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory) class CheckpointLoaderSimple: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - ckpt_dir = os.path.join(models_dir, "checkpoints") - @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ), + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -240,7 +236,7 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory) return out @@ -261,13 +257,11 @@ class CLIPSetLastLayer: return (clip,) class LoraLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - lora_dir = os.path.join(models_dir, "loras") @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), - "lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ), + "lora_name": (folder_paths.get_filename_list("loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} @@ -277,16 +271,14 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): - lora_path = os.path.join(self.lora_dir, lora_name) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) class VAELoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - vae_dir = os.path.join(models_dir, "vae") @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}} + return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -294,16 +286,14 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - vae_path = os.path.join(self.vae_dir, vae_name) + vae_path = folder_paths.get_full_path("vae", vae_name) vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) class ControlNetLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - controlnet_dir = os.path.join(models_dir, "controlnet") @classmethod def INPUT_TYPES(s): - return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -311,17 +301,15 @@ class ControlNetLoader: CATEGORY = "loaders" def load_controlnet(self, control_net_name): - controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet = comfy.sd.load_controlnet(controlnet_path) return (controlnet,) class DiffControlNetLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - controlnet_dir = os.path.join(models_dir, "controlnet") @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -329,7 +317,7 @@ class DiffControlNetLoader: CATEGORY = "loaders" def load_controlnet(self, model, control_net_name): - controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet = comfy.sd.load_controlnet(controlnet_path, model) return (controlnet,) @@ -378,11 +366,9 @@ class T2IAdapterLoader: return (t2i_adapter,) class CLIPLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - clip_dir = os.path.join(models_dir, "clip") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -390,16 +376,14 @@ class CLIPLoader: CATEGORY = "loaders" def load_clip(self, clip_name): - clip_path = os.path.join(self.clip_dir, clip_name) + clip_path = folder_paths.get_full_path("clip", clip_name) clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory) return (clip,) class CLIPVisionLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - clip_dir = os.path.join(models_dir, "clip_vision") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ), }} RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -407,7 +391,7 @@ class CLIPVisionLoader: CATEGORY = "loaders" def load_clip(self, clip_name): - clip_path = os.path.join(self.clip_dir, clip_name) + clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_vision = comfy_extras.clip_vision.load(clip_path) return (clip_vision,) @@ -427,11 +411,9 @@ class CLIPVisionEncode: return (output,) class StyleModelLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - style_model_dir = os.path.join(models_dir, "style_models") @classmethod def INPUT_TYPES(s): - return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}} + return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" @@ -439,7 +421,7 @@ class StyleModelLoader: CATEGORY = "loaders" def load_style_model(self, style_model_name): - style_model_path = os.path.join(self.style_model_dir, style_model_name) + style_model_path = folder_paths.get_full_path("style_models", style_model_name) style_model = comfy.sd.load_style_model(style_model_path) return (style_model,)