diff --git a/app/model_manager.py b/app/model_manager.py new file mode 100644 index 00000000..475970d1 --- /dev/null +++ b/app/model_manager.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import os +import time +import logging +import folder_paths +import glob +from aiohttp import web +from PIL import Image +from io import BytesIO +from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types + + +class ModelFileManager: + def __init__(self) -> None: + self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} + + def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None: + return self.cache.get(key, default) + + def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]): + self.cache[key] = value + + def clear_cache(self): + self.cache.clear() + + def add_routes(self, routes): + # NOTE: This is an experiment to replace `/models` + @routes.get("/experiment/models") + async def get_model_folders(request): + model_types = list(folder_paths.folder_names_and_paths.keys()) + folder_black_list = ["configs", "custom_nodes"] + output_folders: list[dict] = [] + for folder in model_types: + if folder in folder_black_list: + continue + output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) + return web.json_response(output_folders) + + # NOTE: This is an experiment to replace `/models/{folder}` + @routes.get("/experiment/models/{folder}") + async def get_all_models(request): + folder = request.match_info.get("folder", None) + if not folder in folder_paths.folder_names_and_paths: + return web.Response(status=404) + files = self.get_model_file_list(folder) + return web.json_response(files) + + @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}") + async def get_model_preview(request): + folder_name = request.match_info.get("folder", None) + path_index = int(request.match_info.get("path_index", None)) + filename = request.match_info.get("filename", None) + + if not folder_name in folder_paths.folder_names_and_paths: + return web.Response(status=404) + + folders = folder_paths.folder_names_and_paths[folder_name] + folder = folders[0][path_index] + full_filename = os.path.join(folder, filename) + + preview_files = self.get_model_previews(full_filename) + default_preview_file = preview_files[0] if len(preview_files) > 0 else None + if default_preview_file is None or not os.path.isfile(default_preview_file): + return web.Response(status=404) + + try: + with Image.open(default_preview_file) as img: + img_bytes = BytesIO() + img.save(img_bytes, format="WEBP") + img_bytes.seek(0) + return web.Response(body=img_bytes.getvalue(), content_type="image/webp") + except: + return web.Response(status=404) + + def get_model_file_list(self, folder_name: str): + folder_name = map_legacy(folder_name) + folders = folder_paths.folder_names_and_paths[folder_name] + output_list: list[dict] = [] + + for index, folder in enumerate(folders[0]): + if not os.path.isdir(folder): + continue + out = self.cache_model_file_list_(folder) + if out is None: + out = self.recursive_search_models_(folder, index) + self.set_cache(folder, out) + output_list.extend(out[0]) + + return output_list + + def cache_model_file_list_(self, folder: str): + model_file_list_cache = self.get_cache(folder) + + if model_file_list_cache is None: + return None + if not os.path.isdir(folder): + return None + if os.path.getmtime(folder) != model_file_list_cache[1]: + return None + for x in model_file_list_cache[1]: + time_modified = model_file_list_cache[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + return model_file_list_cache + + def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]: + if not os.path.isdir(directory): + return [], {}, time.perf_counter() + + excluded_dir_names = [".git"] + # TODO use settings + include_hidden_files = False + + result: list[str] = [] + dirs: dict[str, float] = {} + + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): + subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] + if not include_hidden_files: + subdirs[:] = [d for d in subdirs if not d.startswith(".")] + filenames = [f for f in filenames if not f.startswith(".")] + + filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions) + + for file_name in filenames: + try: + relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) + result.append(relative_path) + except: + logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.") + continue + + for d in subdirs: + path: str = os.path.join(dirpath, d) + try: + dirs[path] = os.path.getmtime(path) + except FileNotFoundError: + logging.warning(f"Warning: Unable to access {path}. Skipping this path.") + continue + + return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() + + def get_model_previews(self, filepath: str) -> list[str]: + dirname = os.path.dirname(filepath) + + if not os.path.exists(dirname): + return [] + + basename = os.path.splitext(filepath)[0] + match_files = glob.glob(f"{basename}.*", recursive=False) + image_files = filter_files_content_types(match_files, "image") + + result: list[str] = [] + + for filename in image_files: + _basename = os.path.splitext(filename)[0] + if _basename == basename: + result.append(filename) + if _basename == f"{basename}.preview": + result.append(filename) + return result + + def __exit__(self, exc_type, exc_value, traceback): + self.clear_cache() diff --git a/server.py b/server.py index 5e86558d..86800984 100644 --- a/server.py +++ b/server.py @@ -29,6 +29,7 @@ import comfy.model_management import node_helpers from app.frontend_management import FrontendManager from app.user_manager import UserManager +from app.model_manager import ModelFileManager from typing import Optional from api_server.routes.internal.internal_routes import InternalRoutes @@ -151,6 +152,7 @@ class PromptServer(): mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' self.user_manager = UserManager() + self.model_file_manager = ModelFileManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = None @@ -220,7 +222,7 @@ class PromptServer(): def get_embeddings(self): embeddings = folder_paths.get_filename_list("embeddings") return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) - + @routes.get("/models") def list_model_types(request): model_types = list(folder_paths.folder_names_and_paths.keys()) @@ -682,6 +684,7 @@ class PromptServer(): def add_routes(self): self.user_manager.add_routes(self.routes) + self.model_file_manager.add_routes(self.routes) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation.