From 1cb3c98947c36acc14103312c432805d46570a3c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 1 Jun 2025 15:32:02 +0100 Subject: [PATCH] Implement database & model hashing --- alembic.ini | 84 ++++++++++++++++ alembic_db/README.md | 4 + alembic_db/env.py | 69 +++++++++++++ alembic_db/script.py.mako | 28 ++++++ alembic_db/versions/565b08122d00_init.py | 34 +++++++ app/database/db.py | 90 +++++++++++++++++ app/database/models.py | 50 ++++++++++ app/frontend_management.py | 91 +++++++++++------ app/model_processor.py | 122 +++++++++++++++++++++++ comfy/cli_args.py | 6 ++ comfy/utils.py | 4 + folder_paths.py | 21 ++++ main.py | 8 +- requirements.txt | 2 + utils/install_util.py | 19 ++++ 15 files changed, 601 insertions(+), 31 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic_db/README.md create mode 100644 alembic_db/env.py create mode 100644 alembic_db/script.py.mako create mode 100644 alembic_db/versions/565b08122d00_init.py create mode 100644 app/database/db.py create mode 100644 app/database/models.py create mode 100644 app/model_processor.py create mode 100644 utils/install_util.py diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..12f18712f --- /dev/null +++ b/alembic.ini @@ -0,0 +1,84 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = alembic_db + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic_db/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 
+# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = sqlite:///user/comfyui.db + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME diff --git a/alembic_db/README.md b/alembic_db/README.md new file mode 100644 index 000000000..3b808c7ca --- /dev/null +++ b/alembic_db/README.md @@ -0,0 +1,4 @@ +## Generate new revision + +1. Update models in `/app/database/models.py` +2. Run `alembic revision --autogenerate -m "{your message}"` diff --git a/alembic_db/env.py b/alembic_db/env.py new file mode 100644 index 000000000..d278cfc53 --- /dev/null +++ b/alembic_db/env.py @@ -0,0 +1,69 @@ +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + + +from app.database.models import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako new file mode 100644 index 000000000..480b130d6 --- /dev/null +++ b/alembic_db/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic_db/versions/565b08122d00_init.py b/alembic_db/versions/565b08122d00_init.py new file mode 100644 index 000000000..9a8a51fbc --- /dev/null +++ b/alembic_db/versions/565b08122d00_init.py @@ -0,0 +1,34 @@ +"""init + +Revision ID: 565b08122d00 +Revises: +Create Date: 2025-05-29 19:15:56.230322 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '565b08122d00' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.create_table('model', + sa.Column('type', sa.Text(), nullable=False), + sa.Column('path', sa.Text(), nullable=False), + sa.Column('hash', sa.Text(), nullable=True), + sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.PrimaryKeyConstraint('type', 'path') + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_table('model') diff --git a/app/database/db.py b/app/database/db.py new file mode 100644 index 000000000..d17fa4f1f --- /dev/null +++ b/app/database/db.py @@ -0,0 +1,90 @@ +import logging +import os +import shutil +from utils.install_util import get_missing_requirements_message +from comfy.cli_args import args + +Session = None + + +def can_create_session(): + return Session is not None + + +try: + import alembic + import sqlalchemy +except ImportError as e: + logging.error(get_missing_requirements_message()) + raise e + +from alembic import command +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from alembic.script import ScriptDirectory +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +def get_alembic_config(): + root_path = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + + config = Config(config_path) + config.set_main_option("script_location", scripts_path) + config.set_main_option("sqlalchemy.url", args.database_url) + 
+ return config + + +def get_db_path(): + url = args.database_url + if url.startswith("sqlite:///"): + return url.split("///")[1] + else: + raise ValueError(f"Unsupported database URL '{url}'.") + + +def init_db(): + db_url = args.database_url + logging.debug(f"Database URL: {db_url}") + + config = get_alembic_config() + + # Check if we need to upgrade + engine = create_engine(db_url) + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(config) + target_rev = script.get_current_head() + + if current_rev != target_rev: + # Backup the database pre upgrade + db_path = get_db_path() + backup_path = db_path + ".bkp" + if os.path.exists(db_path): + shutil.copy(db_path, backup_path) + else: + backup_path = None + + try: + command.upgrade(config, target_rev) + logging.info(f"Database upgraded from {current_rev} to {target_rev}") + except Exception as e: + if backup_path: + # Restore the database from backup if upgrade fails + shutil.copy(backup_path, db_path) + os.remove(backup_path) + logging.error(f"Error upgrading database: {e}") + raise e + + global Session + Session = sessionmaker(bind=engine) + + +def create_session(): + return Session() diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 000000000..d2c1e042d --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,50 @@ +from sqlalchemy import ( + Column, + Text, + DateTime, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql import func + +Base = declarative_base() + + +def to_dict(obj): + fields = obj.__table__.columns.keys() + return { + field: (val.to_dict() if hasattr(val, "to_dict") else val) + for field in fields + if (val := getattr(obj, field)) + } + + +class Model(Base): + """ + SQLAlchemy model representing a model file in the system. + + This class defines the database schema for storing information about model files, + including their type, path, hash, and when they were added to the system. + + Attributes: + type (Text): The type of the model, this is the name of the folder in the models folder (primary key) + path (Text): The file path of the model relative to the type folder (primary key) + hash (Text): A sha256 hash of the model file + date_added (DateTime): Timestamp of when the model was added to the system + """ + + __tablename__ = "model" + + type = Column(Text, primary_key=True) + path = Column(Text, primary_key=True) + hash = Column(Text) + date_added = Column(DateTime, server_default=func.now()) + + def to_dict(self): + """ + Convert the model instance to a dictionary representation. 
+ + Returns: + dict: A dictionary containing the attributes of the model + """ + dict = to_dict(self) + return dict diff --git a/app/frontend_management.py b/app/frontend_management.py index d9ef8c921..3e54e4d51 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -16,26 +16,15 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired +from utils.install_util import get_missing_requirements_message, requirements_path from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger -# The path to the requirements.txt file -req_path = Path(__file__).parents[1] / "requirements.txt" - - def frontend_install_warning_message(): - """The warning message to display when the frontend version is not up to date.""" - - extra = "" - if sys.flags.no_user_site: - extra = "-s " return f""" -Please install the updated requirements.txt file by running: -{sys.executable} {extra}-m pip install -r {req_path} +{get_missing_requirements_message()} This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. - -If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem """.strip() @@ -48,7 +37,7 @@ def check_frontend_version(): try: frontend_version_str = version("comfyui-frontend-package") frontend_version = parse_version(frontend_version_str) - with open(req_path, "r", encoding="utf-8") as f: + with open(requirements_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: app.logger.log_startup_warning( @@ -162,10 +151,30 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: + """ + A class to manage ComfyUI frontend versions and installations. + + This class handles the initialization and management of different frontend versions, + including the default frontend from the pip package and custom frontend versions + from GitHub repositories. + + Attributes: + CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored. + """ + CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") @classmethod def default_frontend_path(cls) -> str: + """ + Get the path to the default frontend installation from the pip package. + + Returns: + str: The path to the default frontend static files. + + Raises: + SystemExit: If the comfyui-frontend-package is not installed. + """ try: import comfyui_frontend_package @@ -186,6 +195,15 @@ comfyui-frontend-package is not installed. @classmethod def templates_path(cls) -> str: + """ + Get the path to the workflow templates. + + Returns: + str: The path to the workflow templates directory. + + Raises: + SystemExit: If the comfyui-workflow-templates package is not installed. + """ try: import comfyui_workflow_templates @@ -221,12 +239,17 @@ comfyui-workflow-templates is not installed. @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ + Parse a version string into its components. + + The version string should be in the format: 'owner/repo@version' + where version can be either a semantic version (v1.2.3) or 'latest'. + Args: value (str): The version string to parse. - + Returns: - tuple[str, str]: A tuple containing provider name and version. - + tuple[str, str, str]: A tuple containing (owner, repo, version). + Raises: argparse.ArgumentTypeError: If the version string is invalid. 
""" @@ -242,18 +265,22 @@ comfyui-workflow-templates is not installed. cls, version_string: str, provider: Optional[FrontEndProvider] = None ) -> str: """ - Initializes the frontend for the specified version. - + Initialize a frontend version without error handling. + + This method attempts to initialize a specific frontend version, either from + the default pip package or from a custom GitHub repository. It will download + and extract the frontend files if necessary. + Args: - version_string (str): The version string. - provider (FrontEndProvider, optional): The provider to use. Defaults to None. - + version_string (str): The version string specifying which frontend to use. + provider (FrontEndProvider, optional): The provider to use for custom frontends. + Returns: str: The path to the initialized frontend. - + Raises: - Exception: If there is an error during the initialization process. - main error source might be request timeout or invalid URL. + Exception: If there is an error during initialization (e.g., network timeout, + invalid URL, or missing assets). """ if version_string == DEFAULT_VERSION_STRING: check_frontend_version() @@ -305,13 +332,17 @@ comfyui-workflow-templates is not installed. @classmethod def init_frontend(cls, version_string: str) -> str: """ - Initializes the frontend with the specified version string. - + Initialize a frontend version with error handling. + + This is the main method to initialize a frontend version. It wraps init_frontend_unsafe + with error handling, falling back to the default frontend if initialization fails. + Args: - version_string (str): The version string to initialize the frontend with. - + version_string (str): The version string specifying which frontend to use. + Returns: - str: The path of the initialized frontend. + str: The path to the initialized frontend. If initialization fails, + returns the path to the default frontend. 
""" try: return cls.init_frontend_unsafe(version_string) diff --git a/app/model_processor.py b/app/model_processor.py new file mode 100644 index 000000000..980940262 --- /dev/null +++ b/app/model_processor.py @@ -0,0 +1,122 @@ +import hashlib +import os +import logging +import time +from app.database.models import Model +from app.database.db import create_session +from folder_paths import get_relative_path + + +class ModelProcessor: + def _validate_path(self, model_path): + try: + if not os.path.exists(model_path): + logging.error(f"Model file not found: {model_path}") + return None + + result = get_relative_path(model_path) + if not result: + logging.error( + f"Model file not in a recognized model directory: {model_path}" + ) + return None + + return result + except Exception as e: + logging.error(f"Error validating model path {model_path}: {str(e)}") + return None + + def _hash_file(self, model_path): + try: + h = hashlib.sha256() + with open(model_path, "rb", buffering=0) as f: + b = bytearray(128 * 1024) + mv = memoryview(b) + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() + except Exception as e: + logging.error(f"Error hashing file {model_path}: {str(e)}") + return None + + def _get_existing_model(self, session, model_type, model_relative_path): + return ( + session.query(Model) + .filter(Model.type == model_type) + .filter(Model.path == model_relative_path) + .first() + ) + + def _update_database( + self, session, model_type, model_relative_path, model_hash, model=None + ): + try: + if not model: + model = self._get_existing_model( + session, model_type, model_relative_path + ) + + if not model: + model = Model( + path=model_relative_path, + type=model_type, + ) + session.add(model) + + model.hash = model_hash + session.commit() + return model + except Exception as e: + logging.error( + f"Error updating database for {model_relative_path}: {str(e)}" + ) + + def process_file(self, model_path): + try: + result = self._validate_path(model_path) + if not result: + return + model_type, model_relative_path = result + + with create_session() as session: + existing_model = self._get_existing_model( + session, model_type, model_relative_path + ) + if existing_model and existing_model.hash: + # File exists with hash, no need to process + return existing_model + + start_time = time.time() + logging.info(f"Hashing model {model_relative_path}") + model_hash = self._hash_file(model_path) + if not model_hash: + return + logging.info( + f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)" + ) + + return self._update_database(session, model_type, model_relative_path, model_hash) + except Exception as e: + logging.error(f"Error processing model file {model_path}: {str(e)}") + + def retrieve_hash(self, model_path, model_type=None): + try: + if model_type is not None: + result = self._validate_path(model_path) + if not result: + return None + model_type, model_relative_path = result + + with create_session() as session: + model = self._get_existing_model( + session, model_type, model_relative_path + ) + if model and model.hash: + return model.hash + return None + except Exception as e: + logging.error(f"Error retrieving hash for {model_path}: {str(e)}") + return None + + +model_processor = ModelProcessor() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 4fb675f99..154491fe0 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -203,6 +203,12 @@ parser.add_argument( help="Set the base URL for the ComfyUI API. 
(default: https://api.comfy.org)", ) +database_default_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") +) +parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/utils.py b/comfy/utils.py index 1f8d71292..547ce9fc9 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,6 +20,7 @@ import torch import math import struct +from app.model_processor import model_processor import comfy.checkpoint_pickle import safetensors.torch import numpy as np @@ -53,6 +54,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: device = torch.device("cpu") metadata = None + + model_processor.process_file(ckpt) + if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: diff --git a/folder_paths.py b/folder_paths.py index f0b3fd103..452409bf0 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -299,6 +299,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str: return full_path +def get_relative_path(full_path: str) -> tuple[str, str] | None: + """Convert a full path back to a type-relative path. + + Args: + full_path: The full path to the file + + Returns: + tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise + """ + global folder_names_and_paths + full_path = os.path.normpath(full_path) + + for model_type, (paths, _) in folder_names_and_paths.items(): + for base_path in paths: + base_path = os.path.normpath(base_path) + if full_path.startswith(base_path): + relative_path = os.path.relpath(full_path, base_path) + return model_type, relative_path + + return None + def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: folder_name = map_legacy(folder_name) global folder_names_and_paths diff --git a/main.py b/main.py index fb1f8d20b..d6f8193c4 100644 --- a/main.py +++ b/main.py @@ -147,7 +147,6 @@ def cuda_malloc_warning(): if cuda_malloc_warning: logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") - def prompt_worker(q, server_instance): current_time: float = 0.0 cache_type = execution.CacheType.CLASSIC @@ -237,6 +236,12 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) +def setup_database(): + try: + from app.database.db import init_db + init_db() + except Exception as e: + logging.error(f"Failed to initialize database. 
Please report this error as in future the database will be required: {e}") def start_comfyui(asyncio_loop=None): """ @@ -266,6 +271,7 @@ def start_comfyui(asyncio_loop=None): hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() + setup_database() prompt_server.add_routes() hijack_progress(prompt_server) diff --git a/requirements.txt b/requirements.txt index b98dc1268..ea51f24ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,8 @@ Pillow scipy tqdm psutil +alembic +SQLAlchemy #non essential dependencies: kornia>=0.7.1 diff --git a/utils/install_util.py b/utils/install_util.py new file mode 100644 index 000000000..5e6d51a2d --- /dev/null +++ b/utils/install_util.py @@ -0,0 +1,19 @@ +from pathlib import Path +import sys + +# The path to the requirements.txt file +requirements_path = Path(__file__).parents[1] / "requirements.txt" + + +def get_missing_requirements_message(): + """The warning message to display when a package is missing.""" + + extra = "" + if sys.flags.no_user_site: + extra = "-s " + return f""" +Please install the updated requirements.txt file by running: +{sys.executable} {extra}-m pip install -r {requirements_path} + +If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. +""".strip()
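
---
A minimal sketch (not part of the patch) of how the pieces introduced here fit together: init_db() applying the Alembic migrations, the Session factory, the `model` table, and model_processor.process_file(). It assumes the new requirements (alembic, SQLAlchemy) are installed and that it runs from the ComfyUI root so the default sqlite:///user/comfyui.db URL resolves; the checkpoint path is hypothetical, any file under a registered models folder works.

    # Sketch under the assumptions above; the checkpoint path is hypothetical.
    import logging

    from app.database.db import init_db, create_session, can_create_session
    from app.database.models import Model
    from app.model_processor import model_processor

    logging.basicConfig(level=logging.INFO)

    # Create or upgrade the SQLite database; main.py does this via setup_database().
    init_db()

    if can_create_session():
        # Hash a model file and upsert it into the `model` table. A second call is
        # cheap: an existing row that already has a hash is returned without rehashing.
        model_processor.process_file("models/checkpoints/example.safetensors")

        # Read back what was recorded, using the session factory set up by init_db().
        with create_session() as session:
            for m in session.query(Model).all():
                print(m.type, m.path, m.hash, m.date_added)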
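The revision workflow in alembic_db/README.md can also be driven programmatically through the get_alembic_config() helper added in app/database/db.py. A hedged sketch, not part of the patch: the revision message is only an example, and the database must already be upgraded to head (init_db() does this on startup) for autogenerate to diff cleanly.

    # Programmatic equivalent of the README's
    #   alembic revision --autogenerate -m "{your message}"
    # Autogenerate compares Base.metadata (wired up in alembic_db/env.py)
    # against the schema of the database pointed to by --database-url.
    from alembic import command

    from app.database.db import get_alembic_config

    config = get_alembic_config()
    command.revision(config, message="example: add a new column", autogenerate=True)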