mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 12:35:30 +08:00
332 lines
12 KiB
Python
332 lines
12 KiB
Python
import os
|
|
import logging
|
|
import time
|
|
|
|
import requests
|
|
from tqdm import tqdm
|
|
from folder_paths import get_relative_path, get_full_path
|
|
from app.database.db import create_session, dependencies_available, can_create_session
|
|
import blake3
|
|
import comfy.utils
|
|
|
|
|
|
if dependencies_available():
|
|
from app.database.models import Model
|
|
|
|
|
|
class ModelProcessor:
    """Hash model files, record them in the model database, and download missing ones.

    Every database operation goes through ``create_session()``; the public
    lookup/processing entry points first consult ``can_create_session()`` and
    return ``None`` when a session cannot be created.
    """

    # Extensions accepted by ensure_downloaded(); anything else is rejected as
    # unsafe. Compared case-insensitively.
    _ALLOWED_DOWNLOAD_EXTENSIONS = (
        ".safetensors",
        ".sft",
        ".txt",
        ".csv",
        ".json",
        ".yaml",
    )

    # Number of bytes read per chunk when hashing a file or streaming a download.
    _CHUNK_SIZE = 128 * 1024

    def _validate_path(self, model_path):
        """Resolve ``model_path`` to the value returned by ``get_relative_path``.

        Returns:
            The truthy result of ``get_relative_path(model_path)`` (callers
            unpack it as ``(model_type, model_relative_path)``), or ``None``
            when the file is missing, is not inside a recognized model
            directory, or validation raised.
        """
        try:
            if not self._file_exists(model_path):
                logging.error(f"Model file not found: {model_path}")
                return None

            result = get_relative_path(model_path)
            if not result:
                logging.error(
                    f"Model file not in a recognized model directory: {model_path}"
                )
                return None

            return result
        except Exception as e:
            logging.error(f"Error validating model path {model_path}: {str(e)}")
            return None

    def _file_exists(self, path):
        """Return True if ``path`` exists on disk."""
        return os.path.exists(path)

    def _get_file_size(self, path):
        """Return the size of ``path`` in bytes."""
        return os.path.getsize(path)

    def _get_hasher(self):
        """Return a fresh blake3 hasher."""
        return blake3.blake3()

    def _hash_file(self, model_path):
        """Return the hex digest of ``model_path``, or ``None`` if hashing failed.

        The file is opened unbuffered and read into one reusable buffer so no
        per-chunk bytes objects are allocated.
        """
        try:
            hasher = self._get_hasher()
            with open(model_path, "rb", buffering=0) as f:
                view = memoryview(bytearray(self._CHUNK_SIZE))
                while n := f.readinto(view):
                    hasher.update(view[:n])
            return hasher.hexdigest()
        except Exception as e:
            logging.error(f"Error hashing file {model_path}: {str(e)}")
            return None

    def _get_existing_model(self, session, model_type, model_relative_path):
        """Return the first ``Model`` row matching type and relative path, or ``None``."""
        return (
            session.query(Model)
            .filter(Model.type == model_type)
            .filter(Model.path == model_relative_path)
            .first()
        )

    def _ensure_source_url(self, session, model, source_url):
        """Store ``source_url`` on ``model`` only if it has none yet, then commit."""
        if model.source_url is None:
            model.source_url = source_url
            session.commit()

    def _update_database(
        self,
        session,
        model_type,
        model_path,
        model_relative_path,
        model_hash,
        model,
        source_url,
    ):
        """Create or update the database row for a model file and commit it.

        Args:
            session: Open database session.
            model_type: Model type used to look up / create the row.
            model_path: Path on disk; used for the file size and file name.
            model_relative_path: Path stored in the row.
            model_hash: Hash to store (the algorithm is recorded as blake3
                whenever a hash is given).
            model: Existing row if the caller already has it, else falsy.
            source_url: URL stored on the row (overwrites the previous value).

        Returns:
            The model row, or ``None`` if anything failed (the error is logged).
        """
        try:
            if not model:
                model = self._get_existing_model(
                    session, model_type, model_relative_path
                )

            if not model:
                model = Model(
                    path=model_relative_path,
                    type=model_type,
                    file_name=os.path.basename(model_path),
                )
                session.add(model)

            model.file_size = self._get_file_size(model_path)
            model.hash = model_hash
            if model_hash:
                model.hash_algorithm = "blake3"
            model.source_url = source_url

            session.commit()
            return model
        except Exception as e:
            logging.error(
                f"Error updating database for {model_relative_path}: {str(e)}"
            )
            return None

    def process_file(self, model_path, source_url=None, model_hash=None):
        """
        Process a model file and update the database with metadata.
        If the file is already recorded with a hash and the same file size, it
        is not hashed again.

        A caller-supplied ``model_hash`` is lower-cased and trusted as-is;
        otherwise the file is hashed here.

        Returns the model object, or None if no session can be created, the
        path is invalid, or an error occurs.
        """
        try:
            if not can_create_session():
                return None

            result = self._validate_path(model_path)
            if not result:
                return None
            model_type, model_relative_path = result

            with create_session() as session:
                # Keep attributes loaded after commit so the returned object is
                # still readable once the session has been closed.
                session.expire_on_commit = False

                existing_model = self._get_existing_model(
                    session, model_type, model_relative_path
                )
                if (
                    existing_model
                    and existing_model.hash
                    and existing_model.file_size == self._get_file_size(model_path)
                ):
                    # File exists with hash and same size, no need to process
                    self._ensure_source_url(session, existing_model, source_url)
                    return existing_model

                if model_hash:
                    model_hash = model_hash.lower()
                    logging.info(f"Using provided hash: {model_hash}")
                else:
                    start_time = time.time()
                    logging.info(f"Hashing model {model_relative_path}")
                    model_hash = self._hash_file(model_path)
                    if not model_hash:
                        return None
                    logging.info(
                        f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
                    )

                return self._update_database(
                    session,
                    model_type,
                    model_path,
                    model_relative_path,
                    model_hash,
                    existing_model,
                    source_url,
                )
        except Exception as e:
            logging.error(f"Error processing model file {model_path}: {str(e)}")
            return None

    def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
        """
        Retrieve a model file from the database by hash and optionally by model type.

        If ``session`` is given it is used and left open; otherwise a session
        is created for the query and closed before returning.

        Returns the model object or None if the model doesn't exist or an error occurs.
        """
        # Must be bound before the try block: the finally clause reads it on
        # every exit path, including the early return and any exception raised
        # before a session was created.
        dispose_session = False
        try:
            if not can_create_session():
                return None

            if session is None:
                session = create_session()
                dispose_session = True

            query = session.query(Model).filter(Model.hash == model_hash)
            if model_type is not None:
                query = query.filter(Model.type == model_type)
            return query.first()
        except Exception as e:
            logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
            return None
        finally:
            if dispose_session:
                session.close()

    def retrieve_hash(self, model_path, model_type=None):
        """
        Retrieve the hash of a model file from the database.

        ``model_type`` is accepted for backward compatibility; the type and
        relative path actually queried are the ones derived from
        ``model_path`` by path validation.

        Returns the hash or None if the model doesn't exist or an error occurs.
        """
        try:
            if not can_create_session():
                return None

            # Always validate: the relative path needed for the lookup only
            # comes from validation, whether or not a model_type was passed.
            result = self._validate_path(model_path)
            if not result:
                return None
            model_type, model_relative_path = result

            with create_session() as session:
                model = self._get_existing_model(
                    session, model_type, model_relative_path
                )
                if model and model.hash:
                    return model.hash
            return None
        except Exception as e:
            logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
            return None

    def _validate_file_extension(self, file_name):
        """Validate that the file extension is supported (case-insensitive).

        Raises:
            ValueError: If the extension is not in the download allow-list.
        """
        extension = os.path.splitext(file_name)[1].lower()
        if extension not in self._ALLOWED_DOWNLOAD_EXTENSIONS:
            raise ValueError(f"Unsupported unsafe file for download: {file_name}")

    def _check_existing_file(self, model_type, file_name, expected_hash):
        """Check if file exists and has correct hash.

        Returns:
            The destination path if the file exists and either no hash was
            expected or the recorded hash matches (case-insensitively);
            ``None`` if the file does not exist.

        Raises:
            ValueError: If the file exists but could not be processed or its
                hash differs from ``expected_hash``.
        """
        if expected_hash is not None:
            # Recorded hashes are lower-case hex; compare like with like.
            expected_hash = expected_hash.lower()

        destination_path = get_full_path(model_type, file_name, allow_missing=True)
        if not self._file_exists(destination_path):
            return None

        model = self.process_file(destination_path)
        if model and (expected_hash is None or model.hash == expected_hash):
            logging.debug(
                f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
            )
            return destination_path

        raise ValueError(
            f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
        )

    def _check_existing_file_by_hash(self, hash, type, url):
        """Check if a file with the given hash exists in the database and on disk.

        Returns the on-disk path when both hold (recording ``url`` as the
        source URL if the row has none), otherwise ``None``.
        """
        hash = hash.lower()
        with create_session() as session:
            model = self.retrieve_model_by_hash(hash, type, session)
            if model:
                existing_path = get_full_path(type, model.path)
                if existing_path:
                    logging.debug(
                        f"File {model.path} already exists in the database at {existing_path}"
                    )
                    self._ensure_source_url(session, model, url)
                    return existing_path
                else:
                    logging.debug(
                        f"File {model.path} exists in the database but not on disk"
                    )
        return None

    def _download_file(self, url, destination_path, hasher):
        """Download a file and update the hasher with its contents.

        The HTTP status is checked before anything is written, so an error
        page is never saved as a model. If the transfer fails part-way, the
        partial file is removed so a later call cannot mistake it for a
        complete download.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        # The context manager releases the streamed connection on every path.
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            logging.info(f"Downloading {url} to {destination_path}")

            total_size = int(response.headers.get("content-length", 0))
            # The UI progress bar needs a known total; tqdm copes without one.
            pbar = comfy.utils.ProgressBar(total_size) if total_size > 0 else None

            try:
                with open(destination_path, "wb") as f, tqdm(
                    total=total_size, unit="B", unit_scale=True
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=self._CHUNK_SIZE):
                        if not chunk:
                            continue
                        f.write(chunk)
                        hasher.update(chunk)
                        progress_bar.update(len(chunk))
                        if pbar:
                            pbar.update(len(chunk))
            except BaseException:
                # Includes interruption by the user: never leave a truncated
                # file behind. The original exception is always re-raised.
                try:
                    if self._file_exists(destination_path):
                        self._remove_file(destination_path)
                except OSError as cleanup_error:
                    logging.error(
                        f"Could not remove partial download {destination_path}: {cleanup_error}"
                    )
                raise

    def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
        """Verify that the downloaded file has the expected hash.

        The comparison is case-insensitive. On mismatch the downloaded file is
        deleted before raising.

        Raises:
            ValueError: If ``expected_hash`` is given and does not match.
        """
        if expected_hash is None:
            return
        if calculated_hash.lower() != expected_hash.lower():
            self._remove_file(destination_path)
            raise ValueError(
                f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
            )

    def _remove_file(self, file_path):
        """Remove a file from disk."""
        os.remove(file_path)

    def ensure_downloaded(self, type, url, desired_file_name, hash=None):
        """
        Ensure a model file is downloaded and has the correct hash.

        Lookup order: a database row with the expected hash whose file is on
        disk, then an existing file at the destination path, then a fresh
        download that is verified and recorded in the database.

        Returns the path to the downloaded (or already present) file.

        Raises ValueError if the extension is not allowed, if an existing file
        has a different hash, or if the downloaded content does not match
        the expected hash.
        """
        logging.debug(
            f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
        )

        # Validate file extension
        self._validate_file_extension(desired_file_name)

        # Normalize once: digests are compared as lower-case hex, and an empty
        # string means "no expected hash" for every later step.
        expected_hash = hash.lower() if hash else None

        # Check if file exists with correct hash
        if expected_hash:
            existing_path = self._check_existing_file_by_hash(expected_hash, type, url)
            if existing_path:
                return existing_path

        # Check if file exists locally
        # NOTE(review): whether get_full_path rejects ".." components in
        # desired_file_name is not visible from this file - confirm.
        destination_path = get_full_path(type, desired_file_name, allow_missing=True)
        existing_path = self._check_existing_file(
            type, desired_file_name, expected_hash
        )
        if existing_path:
            return existing_path

        # Download the file
        hasher = self._get_hasher()
        self._download_file(url, destination_path, hasher)

        # Verify hash
        calculated_hash = hasher.hexdigest()
        self._verify_downloaded_hash(calculated_hash, expected_hash, destination_path)

        # Update database
        self.process_file(destination_path, url, calculated_hash)

        # TODO: Notify frontend to reload models

        return destination_path
|
|
|
|
|
|
# Module-level ModelProcessor instance, created once when this module is imported.
model_processor = ModelProcessor()
|