This commit is contained in:
pythongosssss 2025-06-01 15:45:15 +01:00
parent 7f7b3f1695
commit 7d5160f92c
5 changed files with 13 additions and 15 deletions

View File

@ -1,7 +1,7 @@
"""init """init
Revision ID: e9c714da8d57 Revision ID: e9c714da8d57
Revises: Revises:
Create Date: 2025-05-30 20:14:33.772039 Create Date: 2025-05-30 20:14:33.772039
""" """
@ -20,7 +20,6 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('model', op.create_table('model',
sa.Column('type', sa.Text(), nullable=False), sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False), sa.Column('path', sa.Text(), nullable=False),
@ -32,7 +31,6 @@ def upgrade() -> None:
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path') sa.PrimaryKeyConstraint('type', 'path')
) )
# ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:

View File

@ -64,7 +64,7 @@ class ModelProcessor:
.filter(Model.path == model_relative_path) .filter(Model.path == model_relative_path)
.first() .first()
) )
def _ensure_source_url(self, session, model, source_url): def _ensure_source_url(self, session, model, source_url):
if model.source_url is None: if model.source_url is None:
model.source_url = source_url model.source_url = source_url
@ -171,9 +171,9 @@ class ModelProcessor:
try: try:
if not can_create_session(): if not can_create_session():
return return
dispose_session = False dispose_session = False
if session is None: if session is None:
session = create_session() session = create_session()
dispose_session = True dispose_session = True

View File

@ -204,7 +204,7 @@ parser.add_argument(
) )
database_default_path = os.path.abspath( database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
) )
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.") parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")

View File

@ -303,23 +303,23 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
def get_relative_path(full_path: str) -> tuple[str, str] | None: def get_relative_path(full_path: str) -> tuple[str, str] | None:
"""Convert a full path back to a type-relative path. """Convert a full path back to a type-relative path.
Args: Args:
full_path: The full path to the file full_path: The full path to the file
Returns: Returns:
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
""" """
global folder_names_and_paths global folder_names_and_paths
full_path = os.path.normpath(full_path) full_path = os.path.normpath(full_path)
for model_type, (paths, _) in folder_names_and_paths.items(): for model_type, (paths, _) in folder_names_and_paths.items():
for base_path in paths: for base_path in paths:
base_path = os.path.normpath(base_path) base_path = os.path.normpath(base_path)
if full_path.startswith(base_path): if full_path.startswith(base_path):
relative_path = os.path.relpath(full_path, base_path) relative_path = os.path.relpath(full_path, base_path)
return model_type, relative_path return model_type, relative_path
return None return None
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:

View File

@ -195,7 +195,7 @@ def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
def test_process_file_without_hash(model_processor, db_session): def test_process_file_without_hash(model_processor, db_session):
# Test processing file without provided hash # Test processing file without provided hash
model_processor.file_exists[TEST_DESTINATION_PATH] = True model_processor.file_exists[TEST_DESTINATION_PATH] = True
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH): with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
result = model_processor.process_file(TEST_DESTINATION_PATH) result = model_processor.process_file(TEST_DESTINATION_PATH)
assert result is not None assert result is not None
@ -241,13 +241,13 @@ def test_validate_file_extension_valid_extensions(model_processor):
def test_process_file_existing_without_source_url(model_processor, db_session): def test_process_file_existing_without_source_url(model_processor, db_session):
# Test processing an existing file that needs its source URL updated # Test processing an existing file that needs its source URL updated
model_processor.file_exists[TEST_DESTINATION_PATH] = True model_processor.file_exists[TEST_DESTINATION_PATH] = True
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL) result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
assert result is not None assert result is not None
assert result.hash == TEST_EXPECTED_HASH assert result.hash == TEST_EXPECTED_HASH
assert result.source_url == TEST_URL assert result.source_url == TEST_URL
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
assert db_model.source_url == TEST_URL assert db_model.source_url == TEST_URL