From 7bf381bc9e1adbb15b5e89938f48ca9742c21f79 Mon Sep 17 00:00:00 2001
From: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
Date: Fri, 28 Mar 2025 11:39:56 +0800
Subject: [PATCH] Add model management and database

- use sqlalchemy + alembic + sqlite for db
- extract model data and previews
- endpoints for db interactions
- add tests
---
 alembic.ini                                | 119 +++++++++
 alembic_db/README.md                       |   3 +
 alembic_db/env.py                          |  75 ++++++
 alembic_db/script.py.mako                  |  28 ++
 alembic_db/versions/2fb22c4fff36_init.py   |  58 ++++
 app/database/db.py                         | 116 ++++++++
 app/database/models.py                     |  76 ++++++
 app/model_manager.py                       | 234 ++++++++++++++---
 app/model_processor.py                     | 263 +++++++++++++++++++
 comfy/cli_args.py                          |   6 +
 main.py                                    |   6 +
 requirements.txt                           |   2 +
 tests-unit/app_test/model_manager_test.py  | 306 ++++++++++++++++++++++
 utils/web.py                               |  12 +
 14 files changed, 1264 insertions(+), 40 deletions(-)
 create mode 100644 alembic.ini
 create mode 100644 alembic_db/README.md
 create mode 100644 alembic_db/env.py
 create mode 100644 alembic_db/script.py.mako
 create mode 100644 alembic_db/versions/2fb22c4fff36_init.py
 create mode 100644 app/database/db.py
 create mode 100644 app/database/models.py
 create mode 100644 app/model_processor.py
 create mode 100644 utils/web.py

diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 000000000..cd1924956
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,119 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+# Use forward slashes (/) also on windows to provide an os agnostic path
+script_location = alembic_db
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+# for all available tokens
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
+# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to ZoneInfo()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to alembic_db/versions. When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are:
+#
+# version_path_separator = :
+# version_path_separator = ;
+# version_path_separator = space
+# version_path_separator = newline
+#
+# Use os.pathsep. Default configuration used for new projects.
+version_path_separator = os
+
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+sqlalchemy.url = sqlite:///user/comfyui.db
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts. See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+# hooks = ruff
+# ruff.type = exec
+# ruff.executable = %(here)s/.venv/bin/ruff
+# ruff.options = check --fix REVISION_SCRIPT_FILENAME
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARNING
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARNING
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/alembic_db/README.md b/alembic_db/README.md
new file mode 100644
index 000000000..4d12f1ed6
--- /dev/null
+++ b/alembic_db/README.md
@@ -0,0 +1,3 @@
+## Generate new revision
+1. Update models in `/app/database/models.py`
+2. Run `alembic revision --autogenerate -m "{your message}"`
diff --git a/alembic_db/env.py b/alembic_db/env.py
new file mode 100644
index 000000000..6903eedde
--- /dev/null
+++ b/alembic_db/env.py
@@ -0,0 +1,75 @@
+from logging.config import fileConfig
+
+from sqlalchemy import engine_from_config
+from sqlalchemy import pool
+
+from alembic import context
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+
+from app.database.models import Base
+target_metadata = Base.metadata
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well. By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def run_migrations_online() -> None:
+    """Run migrations in 'online' mode.
+
+    In this scenario we need to create an Engine
+    and associate a connection with the context.
+
+    """
+    connectable = engine_from_config(
+        config.get_section(config.config_ini_section, {}),
+        prefix="sqlalchemy.",
+        poolclass=pool.NullPool,
+    )
+
+    with connectable.connect() as connection:
+        context.configure(
+            connection=connection, target_metadata=target_metadata
+        )
+
+        with context.begin_transaction():
+            context.run_migrations()
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()
diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako
new file mode 100644
index 000000000..480b130d6
--- /dev/null
+++ b/alembic_db/script.py.mako
@@ -0,0 +1,28 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision: str = ${repr(up_revision)}
+down_revision: Union[str, None] = ${repr(down_revision)}
+branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
+depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    ${downgrades if downgrades else "pass"}
diff --git a/alembic_db/versions/2fb22c4fff36_init.py b/alembic_db/versions/2fb22c4fff36_init.py
new file mode 100644
index 000000000..a21636c21
--- /dev/null
+++ b/alembic_db/versions/2fb22c4fff36_init.py
@@ -0,0 +1,58 @@
+"""init
+
+Revision ID: 2fb22c4fff36
+Revises:
+Create Date: 2025-03-27 19:00:47.686079
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '2fb22c4fff36'
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('model',
+    sa.Column('type', sa.Text(), nullable=False),
+    sa.Column('path', sa.Text(), nullable=False),
+    sa.Column('title', sa.Text(), nullable=True),
+    sa.Column('description', sa.Text(), nullable=True),
+    sa.Column('architecture', sa.Text(), nullable=True),
+    sa.Column('hash', sa.Text(), nullable=True),
+    sa.Column('source_url', sa.Text(), nullable=True),
+    sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
+    sa.PrimaryKeyConstraint('type', 'path')
+    )
+    op.create_table('tag',
+    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+    sa.Column('name', sa.Text(), nullable=False),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('name')
+    )
+    op.create_table('model_tag',
+    sa.Column('model_type', sa.Text(), nullable=False),
+    sa.Column('model_path', sa.Text(), nullable=False),
+    sa.Column('tag_id', sa.Integer(), nullable=False),
+    sa.ForeignKeyConstraint(['model_type', 'model_path'], ['model.type', 'model.path'], ondelete='CASCADE'),
+    sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('model_type', 'model_path', 'tag_id')
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('model_tag')
+    op.drop_table('tag')
+    op.drop_table('model')
+    # ### end Alembic commands ###
diff --git a/app/database/db.py b/app/database/db.py
new file mode 100644
index 000000000..dceab5030
--- /dev/null
+++ b/app/database/db.py
@@ -0,0 +1,116 @@
+import logging
+import os
+import shutil
+import sys
+from app.database.models import Tag
+from comfy.cli_args import args
+
+try:
+    import alembic
+    import sqlalchemy
+except ImportError as e:
+    req_path = os.path.abspath(
+        os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
+    )
+    logging.error(
+        f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
+    )
+    exit(-1)
+
+from alembic import command
+from alembic.config import Config
+from alembic.runtime.migration import MigrationContext
+from alembic.script import ScriptDirectory
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+Session = None
+
+
+def get_alembic_config():
+    root_path = os.path.join(os.path.dirname(__file__), "../..")
+    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
+    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
+
+    config = Config(config_path)
+    config.set_main_option("script_location", scripts_path)
+    config.set_main_option("sqlalchemy.url", args.database_url)
+
+    return config
+
+
+def get_db_path():
+    url = args.database_url
+    if url.startswith("sqlite:///"):
+        return url.split("///")[1]
+    else:
+        raise ValueError(f"Unsupported database URL '{url}'.")
+
+
+def init_db():
+    db_url = args.database_url
+    logging.debug(f"Database URL: {db_url}")
+
+    config = get_alembic_config()
+
+    # Check if we need to upgrade
+    engine = create_engine(db_url)
+    conn = engine.connect()
+
+    context = MigrationContext.configure(conn)
+    current_rev = context.get_current_revision()
+
+    script = ScriptDirectory.from_config(config)
+    target_rev = script.get_current_head()
+
+    if current_rev != target_rev:
+        # Backup the database pre upgrade
+        db_path = get_db_path()
+        backup_path = db_path + ".bkp"
+        if os.path.exists(db_path):
+            shutil.copy(db_path, backup_path)
+        else:
+            backup_path = None
+
+        try:
+            command.upgrade(config, target_rev)
+            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
+        except Exception as e:
+            if backup_path:
+                # Restore the database from backup if upgrade fails
+                shutil.copy(backup_path, db_path)
+                os.remove(backup_path)
+            logging.error(f"Error upgrading database: {e}")
+            raise e
+
+    global Session
+    Session = sessionmaker(bind=engine)
+
+    if not current_rev:
+        # Init db, populate models
+        from app.model_processor import model_processor
+
+        session = create_session()
+        model_processor.populate_models(session)
+
+        # populate tags
+        tags = (
+            "character",
+            "style",
+            "concept",
+            "clothing",
+            "pose",
+            "background",
+            "vehicle",
+            "object",
+            "animal",
+            "action",
+        )
+        for tag in tags:
+            session.add(Tag(name=tag))
+
+        session.commit()
+
+
+def create_session():
+    return Session()
diff --git a/app/database/models.py b/app/database/models.py
new file mode 100644
index 000000000..52cb89bec
--- /dev/null
+++ b/app/database/models.py
@@ -0,0 +1,76 @@
+from sqlalchemy import (
+    Column,
+    Integer,
+    Text,
+    DateTime,
+    Table,
+    ForeignKeyConstraint,
+)
+from sqlalchemy.orm import relationship, declarative_base
+from sqlalchemy.sql import func
+
+Base = declarative_base()
+
+
+def to_dict(obj):
+    fields = obj.__table__.columns.keys()
+    return {
+        field: (val.to_dict() if hasattr(val, "to_dict") else val)
+        for field in fields
+        if (val := getattr(obj, field))
+    }
+
+
+ModelTag = Table(
+    "model_tag",
+    Base.metadata,
+    Column(
+        "model_type",
+        Text,
+        primary_key=True,
+    ),
+    Column(
+        "model_path",
+        Text,
+        primary_key=True,
+    ),
+    Column("tag_id", Integer, primary_key=True),
+    ForeignKeyConstraint(
+        ["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
+    ),
+    ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
+)
+
+
+class Model(Base):
+    __tablename__ = "model"
+
+    type = Column(Text, primary_key=True)
+    path = Column(Text, primary_key=True)
+    title = Column(Text)
+    description = Column(Text)
+    architecture = Column(Text)
+    hash = Column(Text)
+    source_url = Column(Text)
+    date_added = Column(DateTime, server_default=func.now())
+
+    # Relationship with tags
+    tags = relationship("Tag", secondary=ModelTag, back_populates="models")
+
+    def to_dict(self):
+        dict = to_dict(self)
+        dict["tags"] = [tag.to_dict() for tag in self.tags]
+        return dict
+
+
+class Tag(Base):
+    __tablename__ = "tag"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    name = Column(Text, nullable=False, unique=True)
+
+    # Relationship with models
+    models = relationship("Model", secondary=ModelTag, back_populates="tags")
+
+    def to_dict(self):
+        return to_dict(self)
diff --git a/app/model_manager.py b/app/model_manager.py
index 74d942fb8..592246481 100644
--- a/app/model_manager.py
+++ b/app/model_manager.py
@@ -1,19 +1,30 @@
 from __future__ import annotations
 
 import os
-import base64
-import json
 import time
 import logging
+from app.database.db import create_session
 import folder_paths
-import glob
-import comfy.utils
 
 from aiohttp import web
 from PIL import Image
 from io import BytesIO
-from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
+from folder_paths import map_legacy, filter_files_extensions, get_full_path
+from app.database.models import Tag, Model
+from app.model_processor import get_model_previews, model_processor
+from utils.web import dumps
+from sqlalchemy.orm import joinedload
+import sqlalchemy.exc
 
+def bad_request(message: str):
+    return web.json_response({"error": message}, status=400)
+
+def missing_field(field: str):
+    return bad_request(f"{field} is required")
+
+def not_found(message: str):
+    return web.json_response({"error": message + " not found"}, status=404)
+
 class ModelFileManager:
     def __init__(self) -> None:
         self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@@ -62,7 +73,7 @@ class ModelFileManager:
             folder = folders[0][path_index]
             full_filename = os.path.join(folder, filename)
 
-            previews = self.get_model_previews(full_filename)
+            previews = get_model_previews(full_filename)
             default_preview = previews[0] if len(previews) > 0 else None
             if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
                 return web.Response(status=404)
@@ -76,6 +87,183 @@ class ModelFileManager:
             except:
                 return web.Response(status=404)
 
+        @routes.get("/v2/models")
+        async def get_models(request):
+            with create_session() as session:
+                model_path = request.query.get("path", None)
+                model_type = request.query.get("type", None)
+                query = session.query(Model).options(joinedload(Model.tags))
+                if model_path:
+                    query = query.filter(Model.path == model_path)
+                if model_type:
+                    query = query.filter(Model.type == model_type)
+                models = query.all()
+                if model_path and model_type:
+                    if len(models) == 0:
+                        return not_found("Model")
+                    return web.json_response(models[0].to_dict(), dumps=dumps)
+
+                return web.json_response([model.to_dict() for model in models], dumps=dumps)
+
+        @routes.post("/v2/models")
+        async def add_model(request):
+            with create_session() as session:
+                data = await request.json()
+                model_type = data.get("type", None)
+                model_path = data.get("path", None)
+
+                if not model_type:
+                    return missing_field("type")
+                if not model_path:
+                    return missing_field("path")
+
+                tags = data.pop("tags", [])
+                fields = Model.metadata.tables["model"].columns.keys()
+
+                # Validate keys are valid model fields
+                for key in data.keys():
+                    if key not in fields:
+                        return bad_request(f"Invalid field: {key}")
+
+                # Validate file exists
+                if not get_full_path(model_type, model_path):
+                    return not_found(f"File '{model_type}/{model_path}'")
+
+                model = Model()
+                for field in fields:
+                    if field in data:
+                        setattr(model, field, data[field])
+
+                model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
+                for tag in tags:
+                    if tag not in [t.id for t in model.tags]:
+                        return not_found(f"Tag '{tag}'")
+
+                try:
+                    session.add(model)
+                    session.commit()
+                except sqlalchemy.exc.IntegrityError as e:
+                    session.rollback()
+                    return bad_request(e.orig.args[0])
+
+                model_processor.run()
+
+                return web.json_response(model.to_dict(), dumps=dumps)
+
+        @routes.delete("/v2/models")
+        async def delete_model(request):
+            with create_session() as session:
+                model_path = request.query.get("path", None)
+                model_type = request.query.get("type", None)
+                if not model_path:
+                    return missing_field("path")
+                if not model_type:
+                    return missing_field("type")
+
+                full_path = get_full_path(model_type, model_path)
+                if full_path:
+                    return bad_request("Model file exists, please delete the file before deleting the model record.")
+
+                model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
+                if not model:
+                    return not_found("Model")
+                session.delete(model)
+                session.commit()
+                return web.Response(status=204)
+
+        @routes.get("/v2/tags")
+        async def get_tags(request):
+            with create_session() as session:
+                tags = session.query(Tag).all()
+                return web.json_response(
+                    [{"id": tag.id, "name": tag.name} for tag in tags]
+                )
+
+        @routes.post("/v2/tags")
+        async def create_tag(request):
+            with create_session() as session:
+                data = await request.json()
+                name = data.get("name", None)
+                if not name:
+                    return missing_field("name")
+                tag = Tag(name=name)
+                session.add(tag)
+                session.commit()
+                return web.json_response({"id": tag.id, "name": tag.name})
+
+        @routes.delete("/v2/tags")
+        async def delete_tag(request):
+            with create_session() as session:
+                tag_id = request.query.get("id", None)
+                if not tag_id:
+                    return missing_field("id")
+                tag = session.query(Tag).filter(Tag.id == tag_id).first()
+                if not tag:
+                    return not_found("Tag")
+                session.delete(tag)
+                session.commit()
+                return web.Response(status=204)
+
+        @routes.post("/v2/models/tags")
+        async def add_model_tag(request):
+            with create_session() as session:
+                data = await request.json()
+                tag_id = data.get("tag", None)
+                model_path = data.get("path", None)
+                model_type = data.get("type", None)
+
+                if tag_id is None:
+                    return missing_field("tag")
+                if model_path is None:
+                    return missing_field("path")
+                if model_type is None:
+                    return missing_field("type")
+
+                try:
+                    tag_id = int(tag_id)
+                except ValueError:
+                    return bad_request("Invalid tag id")
+
+                tag = session.query(Tag).filter(Tag.id == tag_id).first()
+                model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
+                if not model:
+                    return not_found("Model")
+                model.tags.append(tag)
+                session.commit()
+                return web.json_response(model.to_dict(), dumps=dumps)
+
+        @routes.delete("/v2/models/tags")
+        async def delete_model_tag(request):
+            with create_session() as session:
+                tag_id = request.query.get("tag", None)
+                model_path = request.query.get("path", None)
+                model_type = request.query.get("type", None)
+
+                if tag_id is None:
+                    return missing_field("tag")
+                if model_path is None:
+                    return missing_field("path")
+                if model_type is None:
+                    return missing_field("type")
+
+                try:
+                    tag_id = int(tag_id)
+                except ValueError:
+                    return bad_request("Invalid tag id")
+
+                model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
+                if not model:
+                    return not_found("Model")
+                model.tags = [tag for tag in model.tags if tag.id != tag_id]
+                session.commit()
+                return web.Response(status=204)
+
+
+
+        @routes.get("/v2/models/missing")
+        async def get_missing_models(request):
+            return web.json_response(model_processor.missing_models)
+
     def get_model_file_list(self, folder_name: str):
         folder_name = map_legacy(folder_name)
         folders = folder_paths.folder_names_and_paths[folder_name]
@@ -146,39 +334,5 @@ class ModelFileManager:
 
         return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
 
-    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
-        dirname = os.path.dirname(filepath)
-
-        if not os.path.exists(dirname):
-            return []
-
-        basename = os.path.splitext(filepath)[0]
-        match_files = glob.glob(f"{basename}.*", recursive=False)
-        image_files = filter_files_content_types(match_files, "image")
-        safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
-        safetensors_metadata = {}
-
-        result: list[str | BytesIO] = []
-
-        for filename in image_files:
-            _basename = os.path.splitext(filename)[0]
-            if _basename == basename:
-                result.append(filename)
-            if _basename == f"{basename}.preview":
-                result.append(filename)
-
-        if safetensors_file:
-            safetensors_filepath = os.path.join(dirname, safetensors_file)
-            header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
-            if header:
-                safetensors_metadata = json.loads(header)
-            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
-            if safetensors_images:
-                safetensors_images = json.loads(safetensors_images)
-                for image in safetensors_images:
-                    result.append(BytesIO(base64.b64decode(image)))
-
-        return result
-
     def __exit__(self, exc_type, exc_value, traceback):
         self.clear_cache()
diff --git a/app/model_processor.py b/app/model_processor.py
new file mode 100644
index 000000000..782ef474f
--- /dev/null
+++ b/app/model_processor.py
@@ -0,0 +1,263 @@
+import base64
+from datetime import datetime
+import glob
+import hashlib
+from io import BytesIO
+import json
+import logging
+import os
+import threading
+import time
+import comfy.utils
+from app.database.models import Model
+from app.database.db import create_session
+from comfy.cli_args import args
+from folder_paths import (
+    filter_files_content_types,
+    get_full_path,
+    folder_names_and_paths,
+    get_filename_list,
+)
+from PIL import Image
+from urllib import request
+
+
+def get_model_previews(
+    filepath: str, check_metadata: bool = True
+) -> list[str | BytesIO]:
+    dirname = os.path.dirname(filepath)
+
+    if not os.path.exists(dirname):
+        return []
+
+    basename = os.path.splitext(filepath)[0]
+    match_files = glob.glob(f"{basename}.*", recursive=False)
+    image_files = filter_files_content_types(match_files, "image")
+
+    result: list[str | BytesIO] = []
+
+    for filename in image_files:
+        _basename = os.path.splitext(filename)[0]
+        if _basename == basename:
+            result.append(filename)
+        if _basename == f"{basename}.preview":
+            result.append(filename)
+
+    if not check_metadata:
+        return result
+
+    safetensors_file = next(
+        filter(lambda x: x.endswith(".safetensors"), match_files), None
+    )
+    safetensors_metadata = {}
+
+    if safetensors_file:
+        safetensors_filepath = os.path.join(dirname, safetensors_file)
+        header = comfy.utils.safetensors_header(
+            safetensors_filepath, max_size=8 * 1024 * 1024
+        )
+        if header:
+            safetensors_metadata = json.loads(header)
+        safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
+            "ssmd_cover_images", None
+        )
+        if safetensors_images:
+            safetensors_images = json.loads(safetensors_images)
+            for image in safetensors_images:
+                result.append(BytesIO(base64.b64decode(image)))
+
+    return result
+
+
+class ModelProcessor:
+    def __init__(self):
+        self._thread = None
+        self._lock = threading.Lock()
+        self._run = False
+        self.missing_models = []
+
+    def run(self):
+        if args.disable_model_processing:
+            return
+
+        if self._thread is None:
+            # Lock to prevent multiple threads from starting
+            with self._lock:
+                self._run = True
+                if self._thread is None:
+                    self._thread = threading.Thread(target=self._process_models)
+                    self._thread.daemon = True
+                    self._thread.start()
+
+    def populate_models(self, session):
+        # Ensure database state matches filesystem
+
+        existing_models = session.query(Model).all()
+
+        for folder_name in folder_names_and_paths.keys():
+            if folder_name == "custom_nodes" or folder_name == "configs":
+                continue
+            seen = set()
+            files = get_filename_list(folder_name)
+
+            for file in files:
+                if file in seen:
+                    logging.warning(f"Skipping duplicate named model: {file}")
+                    continue
+                seen.add(file)
+
+                existing_model = None
+                for model in existing_models:
+                    if model.path == file and model.type == folder_name:
+                        existing_model = model
+                        break
+
+                if existing_model:
+                    # Model already exists in db, remove from list and skip
+                    existing_models.remove(existing_model)
+                    continue
+
+                file_path = get_full_path(folder_name, file)
+
+                model = Model(
+                    path=file,
+                    type=folder_name,
+                    date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
+                )
+                session.add(model)
+
+        for model in existing_models:
+            if not get_full_path(model.type, model.path):
+                logging.warning(f"Model {model.path} not found")
+                self.missing_models.append({"type": model.type, "path": model.path})
+
+        session.commit()
+
+    def _get_models(self, session):
+        models = session.query(Model).filter(Model.hash == None).all()
+        return models
+
+    def _process_file(self, model_path):
+        is_safetensors = model_path.endswith(".safetensors")
+        metadata = {}
+        h = hashlib.sha256()
+
+        with open(model_path, "rb", buffering=0) as f:
+            if is_safetensors:
+                # Read header length (8 bytes)
+                header_size_bytes = f.read(8)
+                header_len = int.from_bytes(header_size_bytes, "little")
+                h.update(header_size_bytes)
+
+                # Read header
+                header_bytes = f.read(header_len)
+                h.update(header_bytes)
+                try:
+                    metadata = json.loads(header_bytes)
+                except json.JSONDecodeError:
+                    pass
+
+            # Read rest of file
+            b = bytearray(128 * 1024)
+            mv = memoryview(b)
+            while n := f.readinto(mv):
+                h.update(mv[:n])
+
+        return h.hexdigest(), metadata
+
+    def _populate_info(self, model, metadata):
+        model.title = metadata.get("modelspec.title", None)
+        model.description = metadata.get("modelspec.description", None)
+        model.architecture = metadata.get("modelspec.architecture", None)
+
+    def _extract_image(self, model_path, metadata):
+        # check if image already exists
+        if len(get_model_previews(model_path, check_metadata=False)) > 0:
+            return
+
+        image_path = os.path.splitext(model_path)[0] + ".webp"
+        if os.path.exists(image_path):
+            return
+
+        cover_images = metadata.get("ssmd_cover_images", None)
+        image = None
+        if cover_images:
+            try:
+                cover_images = json.loads(cover_images)
+                if len(cover_images) > 0:
+                    image_data = cover_images[0]
+                    image = Image.open(BytesIO(base64.b64decode(image_data)))
+            except Exception as e:
+                logging.warning(
+                    f"Error extracting cover image for model {model_path}: {e}"
+                )
+
+        if not image:
+            thumbnail = metadata.get("modelspec.thumbnail", None)
+            if thumbnail:
+                try:
+                    response = request.urlopen(thumbnail)
+                    image = Image.open(response)
+                except Exception as e:
+                    logging.warning(
+                        f"Error extracting thumbnail for model {model_path}: {e}"
+                    )
+
+        if image:
+            image.thumbnail((512, 512))
+            image.save(image_path)
+            image.close()
+
+    def _process_models(self):
+        with create_session() as session:
+            checked = set()
+            self.populate_models(session)
+
+            while self._run:
+                self._run = False
+
+                models = self._get_models(session)
+
+                if len(models) == 0:
+                    break
+
+                for model in models:
+                    # prevent looping on the same model if it crashes
+                    if model.path in checked:
+                        continue
+
+                    checked.add(model.path)
+
+                    try:
+                        time.sleep(0)
+                        now = time.time()
+                        model_path = get_full_path(model.type, model.path)
+
+                        if not model_path:
+                            logging.warning(f"Model {model.path} not found")
+                            self.missing_models.append(model.path)
+                            continue
+
+                        logging.debug(f"Processing model {model_path}")
+                        hash, header = self._process_file(model_path)
+                        logging.debug(
+                            f"Processed model {model_path} in {time.time() - now} seconds"
+                        )
+                        model.hash = hash
+
+                        if header:
+                            metadata = header.get("__metadata__", None)
+
+                            if metadata:
+                                self._populate_info(model, metadata)
+                                self._extract_image(model_path, metadata)
+
+                        session.commit()
+                    except Exception as e:
+                        logging.error(f"Error processing model {model.path}: {e}")
+
+        with self._lock:
+            self._thread = None
+
+
+model_processor = ModelProcessor()
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 812798bf8..570de28c9 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -178,6 +178,12 @@ parser.add_argument(
 
 parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
 
+database_default_path = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
+)
+parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:
diff --git a/main.py b/main.py
index c5c9f4e09..0ce17192a 100644
--- a/main.py
+++ b/main.py
@@ -138,6 +138,8 @@ import server
 from server import BinaryEventTypes
 import nodes
 import comfy.model_management
+from app.database.db import init_db
+from app.model_processor import model_processor
 
 def cuda_malloc_warning():
     device = comfy.model_management.get_torch_device()
@@ -262,6 +264,7 @@ def start_comfyui(asyncio_loop=None):
 
     cuda_malloc_warning()
 
+    init_db()
     prompt_server.add_routes()
     hijack_progress(prompt_server)
 
@@ -269,6 +272,9 @@ def start_comfyui(asyncio_loop=None):
 
     if args.quick_test_for_ci:
         exit(0)
+
+    # Scan for changed model files and update db
+    model_processor.run()
 
     os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
     call_on_start = None
diff --git a/requirements.txt b/requirements.txt
index 4c2c0b2b2..9c8fb3585 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,8 @@ Pillow
 scipy
 tqdm
 psutil
+alembic
+SQLAlchemy
 
 #non essential dependencies:
 kornia>=0.7.1
diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py
index ae59206f6..1da1bd590 100644
--- a/tests-unit/app_test/model_manager_test.py
+++ b/tests-unit/app_test/model_manager_test.py
@@ -7,11 +7,33 @@
 from PIL import Image
 from aiohttp import web
 from unittest.mock import patch
 from app.model_manager import ModelFileManager
+from app.database.models import Base, Model, Tag
+from comfy.cli_args import args
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
 
 pytestmark = (
     pytest.mark.asyncio
 )  # This applies the asyncio mark to all test functions in the module
 
+@pytest.fixture
+def session():
+    # Configure in-memory database
+    args.database_url = "sqlite:///:memory:"
+
+    # Create engine and session factory
+    engine = create_engine(args.database_url)
+    Session = sessionmaker(bind=engine)
+
+    # Create all tables
+    Base.metadata.create_all(engine)
+
+    # Patch Session factory
+    with patch('app.database.db.Session', Session):
+        yield Session()
+
+    Base.metadata.drop_all(engine)
+
 @pytest.fixture
 def model_manager():
     return ModelFileManager()
@@ -60,3 +82,287 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
 
     # Clean up
     img.close()
+
+async def test_get_models(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    model.tags.append(tag)
+    session.add(tag)
+    session.add(model)
+    session.commit()
+
+    client = await aiohttp_client(app)
+    resp = await client.get('/v2/models')
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data) == 1
+    assert data[0]['path'] == 'model1.safetensors'
+    assert len(data[0]['tags']) == 1
+    assert data[0]['tags'][0]['name'] == 'test_tag'
+
+async def test_add_model(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    session.add(tag)
+    session.commit()
+    tag_id = tag.id
+
+    with patch('app.model_manager.model_processor') as mock_processor:
+        with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+            client = await aiohttp_client(app)
+            resp = await client.post('/v2/models', json={
+                'type': 'checkpoints',
+                'path': 'model1.safetensors',
+                'title': 'Test Model',
+                'tags': [tag_id]
+            })
+
+            assert resp.status == 200
+            data = await resp.json()
+            assert data['path'] == 'model1.safetensors'
+            assert len(data['tags']) == 1
+            assert data['tags'][0]['name'] == 'test_tag'
+
+    # Ensure that models are re-processed after adding
+    mock_processor.run.assert_called_once()
+
+async def test_delete_model(aiohttp_client, app, session):
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    with patch('app.model_manager.get_full_path', return_value=None):
+        client = await aiohttp_client(app)
+        resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
+        assert resp.status == 204
+
+    # Verify model was deleted
+    model = session.query(Model).first()
+    assert model is None
+
+async def test_delete_model_file_exists(aiohttp_client, app, session):
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+        client = await aiohttp_client(app)
+        resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
+        assert resp.status == 400
+
+        data = await resp.json()
+        assert "file exists" in data["error"].lower()
+
+    # Verify model was not deleted
+    model = session.query(Model).first()
+    assert model is not None
+    assert model.path == 'model1.safetensors'
+
+async def test_get_tags(aiohttp_client, app, session):
+    tags = [Tag(name='tag1'), Tag(name='tag2')]
+    for tag in tags:
+        session.add(tag)
+    session.commit()
+
+    client = await aiohttp_client(app)
+    resp = await client.get('/v2/tags')
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data) == 2
+    assert {t['name'] for t in data} == {'tag1', 'tag2'}
+
+async def test_create_tag(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/tags', json={'name': 'new_tag'})
+    assert resp.status == 200
+    data = await resp.json()
+    assert data['name'] == 'new_tag'
+
+    # Verify tag was created
+    tag = session.query(Tag).first()
+    assert tag.name == 'new_tag'
+
+async def test_delete_tag(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    session.add(tag)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.delete(f'/v2/tags?id={tag_id}')
+    assert resp.status == 204
+
+    # Verify tag was deleted
+    tag = session.query(Tag).first()
+    assert tag is None
+
+async def test_add_model_tag(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(tag)
+    session.add(model)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models/tags', json={
+        'tag': tag_id,
+        'type': 'checkpoints',
+        'path': 'model1.safetensors'
+    })
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data['tags']) == 1
+    assert data['tags'][0]['name'] == 'test_tag'
+
+async def test_delete_model_tag(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    model.tags.append(tag)
+    session.add(tag)
+    session.add(model)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
+    assert resp.status == 204
+
+    # Verify tag was removed
+    model = session.query(Model).first()
+    assert len(model.tags) == 0
+
+async def test_add_model_duplicate(aiohttp_client, app, session):
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+        client = await aiohttp_client(app)
+        resp = await client.post('/v2/models', json={
+            'type': 'checkpoints',
+            'path': 'model1.safetensors',
+            'title': 'Duplicate Model'
+        })
+        assert resp.status == 400
+
+async def test_add_model_missing_fields(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models', json={})
+    assert resp.status == 400
+
+async def test_add_tag_missing_name(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/tags', json={})
+    assert resp.status == 400
+
+async def test_delete_model_not_found(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
+    assert resp.status == 404
+
+async def test_delete_tag_not_found(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.delete('/v2/tags?id=999')
+    assert resp.status == 404
+
+async def test_add_model_missing_path(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models', json={
+        'type': 'checkpoints',
+        'title': 'Test Model'
+    })
+    assert resp.status == 400
+    data = await resp.json()
+    assert "path" in data["error"].lower()
+
+async def test_add_model_invalid_field(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models', json={
+        'type': 'checkpoints',
+        'path': 'model1.safetensors',
+        'invalid_field': 'some value'
+    })
+    assert resp.status == 400
+    data = await resp.json()
+    assert "invalid field" in data["error"].lower()
+
+async def test_add_model_nonexistent_file(aiohttp_client, app, session):
+    with patch('app.model_manager.get_full_path', return_value=None):
+        client = await aiohttp_client(app)
+        resp = await client.post('/v2/models', json={
+            'type': 'checkpoints',
+            'path': 'nonexistent.safetensors'
+        })
+        assert resp.status == 404
+        data = await resp.json()
+        assert "file" in data["error"].lower()
+
+async def test_add_model_invalid_tag(aiohttp_client, app, session):
+    with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+        client = await aiohttp_client(app)
+        resp = await client.post('/v2/models', json={
+            'type': 'checkpoints',
+            'path': 'model1.safetensors',
+            'tags': [999]  # Non-existent tag ID
+        })
+        assert resp.status == 404
+        data = await resp.json()
+        assert "tag" in data["error"].lower()
+
+async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
+    # Create a tag but no model
+    tag = Tag(name='test_tag')
+    session.add(tag)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models/tags', json={
+        'tag': tag_id,
+        'type': 'checkpoints',
+        'path': 'nonexistent.safetensors'
+    })
+    assert resp.status == 404
+    data = await resp.json()
+    assert "model" in data["error"].lower()
+
+async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
+    # Create a model first
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    client = await aiohttp_client(app)
+    resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoints&path=model1.safetensors')
+    assert resp.status == 400
+    data = await resp.json()
+    assert "invalid tag id" in data["error"].lower()
+
diff --git a/utils/web.py b/utils/web.py
new file mode 100644
index 000000000..a2deb1846
--- /dev/null
+++ b/utils/web.py
@@ -0,0 +1,12 @@
+import json
+from datetime import datetime
+
+
+class DateTimeEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, datetime):
+            return obj.isoformat()
+        return super().default(obj)
+
+
+dumps = DateTimeEncoder().encode
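
Reviewer notes (not part of the patch):

A minimal sketch of exercising the new /v2 endpoints, assuming a running ComfyUI instance with this patch applied on the default address 127.0.0.1:8188 (hypothetical), aiohttp >= 3.8 for base_url support, and a checkpoint file that already exists under the matching models folder; the payload fields mirror the handlers in app/model_manager.py.

import asyncio
import aiohttp

BASE = "http://127.0.0.1:8188"  # assumed default server address

async def main():
    async with aiohttp.ClientSession(base_url=BASE) as http:
        # Create a tag; the handler responds with {"id": ..., "name": ...}.
        async with http.post("/v2/tags", json={"name": "style"}) as resp:
            tag = await resp.json()

        # Register a model record; "type"/"path" must resolve via
        # get_full_path, i.e. the file has to exist on disk, or a 404
        # is returned. The filename here is hypothetical.
        payload = {
            "type": "checkpoints",
            "path": "model1.safetensors",
            "title": "Example Model",
            "tags": [tag["id"]],
        }
        async with http.post("/v2/models", json=payload) as resp:
            print(resp.status, await resp.json())

        # List records; add ?type=...&path=... to fetch a single model.
        async with http.get("/v2/models") as resp:
            print(await resp.json())

asyncio.run(main())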
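One detail worth calling out in app/model_processor.py: _process_file feeds the 8-byte safetensors header length, the header bytes, and the remainder of the file through a single sha256, so the stored Model.hash is simply the digest of the entire file; the split reads exist only to parse the header metadata in the same pass. A standalone equivalence sketch (the file path is hypothetical):

import hashlib

def sha256_file(path: str, chunk_size: int = 128 * 1024) -> str:
    # Produces the same digest _process_file stores, without header parsing.
    h = hashlib.sha256()
    with open(path, "rb", buffering=0) as f:
        b = bytearray(chunk_size)
        mv = memoryview(b)
        while n := f.readinto(mv):
            h.update(mv[:n])
    return h.hexdigest()

# sha256_file("model1.safetensors") should equal that file's Model.hash row.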
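Finally, a sketch of reading the database directly through the ORM layer added in app/database. This assumes init_db() has run first (it applies any pending alembic migrations and binds the module-level Session factory; note create_session() fails before that), and the "checkpoints" filter value is just a placeholder.

from app.database.db import create_session, init_db
from app.database.models import Model
from sqlalchemy.orm import joinedload

init_db()  # upgrade schema if needed, then bind the Session factory
with create_session() as session:
    models = (
        session.query(Model)
        .options(joinedload(Model.tags))  # same eager load the GET handler uses
        .filter(Model.type == "checkpoints")
        .all()
    )
    for m in models:
        print(m.path, m.hash, [t.name for t in m.tags])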