Add folder_paths so models can be in multiple paths.

This commit is contained in:
comfyanonymous 2023-03-17 17:57:57 -04:00
parent 51bbbf8d64
commit e1a9e26968
3 changed files with 89 additions and 44 deletions

View File

@ -5,14 +5,12 @@ import model_management
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
import torch import torch
import comfy.utils import comfy.utils
import folder_paths
class UpscaleModelLoader: class UpscaleModelLoader:
models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "models")
upscale_model_dir = os.path.join(models_dir, "upscale_models")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir), supported_ckpt_extensions), ), return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
}} }}
RETURN_TYPES = ("UPSCALE_MODEL",) RETURN_TYPES = ("UPSCALE_MODEL",)
FUNCTION = "load_model" FUNCTION = "load_model"
@ -20,7 +18,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_model(self, model_name): def load_model(self, model_name):
model_path = os.path.join(self.upscale_model_dir, model_name) model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = load_torch_file(model_path) sd = load_torch_file(model_path)
out = model_loading.load_state_dict(sd).eval() out = model_loading.load_state_dict(sd).eval()
return (out, ) return (out, )

65
folder_paths.py Normal file
View File

@ -0,0 +1,65 @@
import os
supported_ckpt_extensions = set(['.ckpt', '.pth'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
try:
import safetensors.torch
supported_ckpt_extensions.add('.safetensors')
supported_pt_extensions.add('.safetensors')
except:
print("Could not import safetensors, safetensors support disabled.")
folder_names_and_paths = {}
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
def add_model_folder(folder_name, full_folder_path):
global folder_names_and_paths
def recursive_search(directory):
result = []
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we os.path,join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
return result
def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
def get_full_path(folder_name, filename):
global folder_names_and_paths
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
full_path = os.path.join(x, filename)
if os.path.isfile(full_path):
return full_path
def get_filename_list(folder_name):
global folder_names_and_paths
output_list = []
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
output_list += filter_files_extensions(recursive_search(x), folders[1])
return output_list

View File

@ -23,6 +23,7 @@ import comfy_extras.clip_vision
import model_management import model_management
import importlib import importlib
import folder_paths
supported_ckpt_extensions = ['.ckpt', '.pth'] supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try: try:
@ -208,31 +209,26 @@ class VAEEncodeForInpaint:
class CheckpointLoader: class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
config_dir = os.path.join(models_dir, "configs")
ckpt_dir = os.path.join(models_dir, "checkpoints")
embedding_directory = os.path.join(models_dir, "embeddings") embedding_directory = os.path.join(models_dir, "embeddings")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ), return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}} "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE") RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = os.path.join(self.config_dir, config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
ckpt_dir = os.path.join(models_dir, "checkpoints")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ), return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}} }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE") RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
@ -240,7 +236,7 @@ class CheckpointLoaderSimple:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
return out return out
@ -261,13 +257,11 @@ class CLIPSetLastLayer:
return (clip,) return (clip,)
class LoraLoader: class LoraLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
lora_dir = os.path.join(models_dir, "loras")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ), "clip": ("CLIP", ),
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ), "lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}} }}
@ -277,16 +271,14 @@ class LoraLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_lora(self, model, clip, lora_name, strength_model, strength_clip): def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
lora_path = os.path.join(self.lora_dir, lora_name) lora_path = folder_paths.get_full_path("loras", lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
return (model_lora, clip_lora) return (model_lora, clip_lora)
class VAELoader: class VAELoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
vae_dir = os.path.join(models_dir, "vae")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}} return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
RETURN_TYPES = ("VAE",) RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae" FUNCTION = "load_vae"
@ -294,16 +286,14 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
vae_path = os.path.join(self.vae_dir, vae_name) vae_path = folder_paths.get_full_path("vae", vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path) vae = comfy.sd.VAE(ckpt_path=vae_path)
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
RETURN_TYPES = ("CONTROL_NET",) RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet" FUNCTION = "load_controlnet"
@ -311,17 +301,15 @@ class ControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, control_net_name): def load_controlnet(self, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name) controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path) controlnet = comfy.sd.load_controlnet(controlnet_path)
return (controlnet,) return (controlnet,)
class DiffControlNetLoader: class DiffControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
RETURN_TYPES = ("CONTROL_NET",) RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet" FUNCTION = "load_controlnet"
@ -329,7 +317,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name): def load_controlnet(self, model, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name) controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path, model) controlnet = comfy.sd.load_controlnet(controlnet_path, model)
return (controlnet,) return (controlnet,)
@ -378,11 +366,9 @@ class T2IAdapterLoader:
return (t2i_adapter,) return (t2i_adapter,)
class CLIPLoader: class CLIPLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
}} }}
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
@ -390,16 +376,14 @@ class CLIPLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name):
clip_path = os.path.join(self.clip_dir, clip_name) clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory) clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
return (clip,) return (clip,)
class CLIPVisionLoader: class CLIPVisionLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip_vision")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
}} }}
RETURN_TYPES = ("CLIP_VISION",) RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
@ -407,7 +391,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name):
clip_path = os.path.join(self.clip_dir, clip_name) clip_path = folder_paths.get_full_path("clip_vision", clip_name)
clip_vision = comfy_extras.clip_vision.load(clip_path) clip_vision = comfy_extras.clip_vision.load(clip_path)
return (clip_vision,) return (clip_vision,)
@ -427,11 +411,9 @@ class CLIPVisionEncode:
return (output,) return (output,)
class StyleModelLoader: class StyleModelLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
style_model_dir = os.path.join(models_dir, "style_models")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}} return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
RETURN_TYPES = ("STYLE_MODEL",) RETURN_TYPES = ("STYLE_MODEL",)
FUNCTION = "load_style_model" FUNCTION = "load_style_model"
@ -439,7 +421,7 @@ class StyleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_style_model(self, style_model_name): def load_style_model(self, style_model_name):
style_model_path = os.path.join(self.style_model_dir, style_model_name) style_model_path = folder_paths.get_full_path("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path) style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,) return (style_model,)