ComfyUI/model_filemanager/download_models.py

187 lines
7.6 KiB
Python
Raw Normal View History

2024-08-08 19:53:17 +00:00
from __future__ import annotations
2024-08-07 01:07:32 +00:00
import aiohttp
import os
2024-08-07 19:12:01 +00:00
import traceback
import logging
2024-08-07 01:07:32 +00:00
from folder_paths import models_dir
import re
2024-08-08 19:53:17 +00:00
from typing import Callable, Any, Optional, Awaitable, Dict
2024-08-07 01:07:32 +00:00
from enum import Enum
2024-08-07 04:21:51 +00:00
import time
2024-08-07 01:07:32 +00:00
from dataclasses import dataclass
2024-08-08 19:53:17 +00:00
2024-08-07 01:07:32 +00:00
class DownloadStatusType(Enum):
    """Lifecycle states of a model download.

    The member values are the plain strings stored in
    ``DownloadModelStatus.status`` and passed on to progress callbacks.
    """

    PENDING = "pending"          # announced before the HTTP request is issued
    IN_PROGRESS = "in_progress"  # body is being streamed to disk
    COMPLETED = "completed"      # file is on disk (freshly downloaded or already present)
    ERROR = "error"              # input rejected or the download failed


@dataclass
class DownloadModelStatus():
    """Snapshot of a model download's state, as sent to progress callbacks.

    ``status`` holds the plain string value of a ``DownloadStatusType`` member
    so that ``to_dict`` output is directly serialisable.
    """
    status: str
    progress_percentage: float
    message: str
    already_existed: bool = False

    # @dataclass keeps a user-defined __init__, so the field default above only
    # takes effect if it is repeated here (it previously was not).
    def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool = False):
        """Create a status.

        Args:
            status: A ``DownloadStatusType`` member (its ``.value`` is stored);
                a raw status string is also accepted and stored unchanged.
            progress_percentage: Completion percentage reported to callers.
            message: Human-readable description of the current state.
            already_existed: True when the file was found on disk and no
                download was performed. Defaults to False.
        """
        self.status = status.value if isinstance(status, Enum) else status
        self.progress_percentage = progress_percentage
        self.message = message
        self.already_existed = already_existed

    def to_dict(self) -> Dict[str, Any]:
        """Return the status as a plain dict with string keys."""
        return {
            "status": self.status,
            "progress_percentage": self.progress_percentage,
            "message": self.message,
            "already_existed": self.already_existed,
        }


async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
                         model_name: str,
                         model_url: str,
                         model_sub_directory: str,
                         progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                         progress_interval: float = 1.0) -> DownloadModelStatus:
    """
    Download a model file from a given URL into the models directory.

    Args:
        model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
            A function that makes an HTTP request. This makes it easier to mock in unit tests.
        model_name (str):
            The name of the model file to be downloaded. This will be the filename on disk.
            It must be a plain file name: no path separators, drive markers, and
            not "." or "..". Otherwise an ERROR status is returned.
        model_url (str):
            The URL from which to download the model.
        model_sub_directory (str):
            The subdirectory within the main models directory where the model
            should be saved (e.g., 'checkpoints', 'loras', etc.).
        progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
            An asynchronous function to call with progress updates.
        progress_interval (float):
            Minimum number of seconds between IN_PROGRESS updates while streaming.

    Returns:
        DownloadModelStatus: The result of the download operation.
    """
    # Input validation failures are returned directly, without notifying the callback.
    if not validate_model_subdirectory(model_sub_directory):
        return DownloadModelStatus(DownloadStatusType.ERROR, 0, "Invalid model subdirectory", False)

    # model_name is joined onto the models directory with os.path.join, so a
    # name such as "../../x" or an absolute path would write outside it.
    if not _is_safe_model_name(model_name):
        return DownloadModelStatus(DownloadStatusType.ERROR, 0, "Invalid model name", False)

    file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
    existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
    if existing_file:
        return existing_file

    try:
        status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
        await progress_callback(relative_path, status)

        response = await model_download_request(model_url)
        if response.status != 200:
            error_message = f"Failed to download {model_name}. Status code: {response.status}"
            logging.error(error_message)
            status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
            await progress_callback(relative_path, status)
            # Return the same status object that was reported.
            return status

        return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)

    except Exception as e:
        logging.error(f"Error in downloading model: {e}")
        return await handle_download_error(e, model_name, progress_callback, relative_path)


def _is_safe_model_name(model_name: str) -> bool:
    """Return True if *model_name* is a single file name that stays inside its directory."""
    if not isinstance(model_name, str) or not model_name.strip():
        return False
    if model_name in (".", ".."):
        return False
    # Reject both separator flavours regardless of platform, Windows drive /
    # stream markers (":"), and NUL, which open() refuses anyway.
    return not any(ch in model_name for ch in ("/", "\\", ":", "\0"))


def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
    """Work out where a model file goes and make sure its folder exists.

    Args:
        model_name: File name of the model.
        model_directory: Subdirectory under ``models_base_dir``.
        models_base_dir: Root models directory.

    Returns:
        ``(file_path, relative_path)``. ``file_path`` is the OS path of the model
        file. ``relative_path`` is ``"<model_directory>/<model_name>"``, always
        joined with ``/``.
    """
    target_dir = os.path.join(models_base_dir, model_directory)
    # Side effect: the target folder is created eagerly, even if nothing is downloaded.
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, model_name), "/".join((model_directory, model_name))


async def check_file_exists(file_path: str,
                            model_name: str,
                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                            relative_path: str) -> Optional[DownloadModelStatus]:
    """Short-circuit a download when the target file is already on disk.

    If ``file_path`` exists, a COMPLETED status with ``already_existed=True``
    is awaited through ``progress_callback`` and returned. Otherwise nothing is
    reported and ``None`` is returned.
    """
    if not os.path.exists(file_path):
        return None

    existing = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
    await progress_callback(relative_path, existing)
    return existing


async def track_download_progress(response: aiohttp.ClientResponse,
                                  file_path: str,
                                  model_name: str,
                                  progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                  relative_path: str,
                                  interval: float = 1.0) -> DownloadModelStatus:
    """Stream the body of ``response`` into ``file_path`` while reporting progress.

    Progress updates are sent as follows:
    - an IN_PROGRESS status whenever at least ``interval`` seconds have passed
      since the previous update;
    - one more IN_PROGRESS status after the last chunk;
    - a final COMPLETED status.

    If writing the file fails (or the task is cancelled), the partially written
    file is deleted. Otherwise ``check_file_exists`` would later report the
    truncated file as an already-downloaded model.

    Returns:
        DownloadModelStatus: COMPLETED on success. On any exception, the ERROR
        status produced by ``handle_download_error``.
    """
    try:
        # A missing or malformed Content-Length only disables percentage
        # reporting; it should not abort the download.
        try:
            total_size = int(response.headers.get('Content-Length', 0))
        except (TypeError, ValueError):
            total_size = 0
        downloaded = 0
        last_update_time = time.time()

        async def update_progress():
            nonlocal last_update_time
            progress = (downloaded / total_size) * 100 if total_size > 0 else 0
            # Content-Length may describe the encoded body, so the decoded byte
            # count can exceed it; never report more than 100%.
            progress = min(progress, 100)
            status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
            await progress_callback(relative_path, status)
            last_update_time = time.time()

        file_created = False
        try:
            with open(file_path, 'wb') as f:
                file_created = True
                # Iterator is driven through __anext__ explicitly rather than
                # `async for`; NOTE(review): presumably so objects exposing only
                # __anext__ (e.g. test doubles) keep working — confirm before changing.
                chunk_iterator = response.content.iter_chunked(8192)
                while True:
                    try:
                        chunk = await chunk_iterator.__anext__()
                    except StopAsyncIteration:
                        break
                    f.write(chunk)
                    downloaded += len(chunk)

                    if time.time() - last_update_time >= interval:
                        await update_progress()
        except BaseException:
            # BaseException so task cancellation (not an Exception subclass)
            # also cleans up; only delete a file this call actually created.
            if file_created:
                try:
                    os.remove(file_path)
                except OSError:
                    logging.warning(f"Could not remove partial download: {file_path}")
            raise

        await update_progress()

        logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
        status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
        await progress_callback(relative_path, status)

        return status
    except Exception as e:
        logging.error(f"Error in track_download_progress: {e}")
        logging.error(traceback.format_exc())
        return await handle_download_error(e, model_name, progress_callback, relative_path)


async def handle_download_error(e: Exception,
                                model_name: str,
                                progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                relative_path: str) -> DownloadModelStatus:
    """Report a failed download through ``progress_callback`` and return the ERROR status.

    Args:
        e: The exception that aborted the download; its text goes into the message.
        model_name: File name of the model, used in the message.
        progress_callback: Awaited with ``(relative_path, status)``. The
            annotation now says Awaitable, matching the sibling functions,
            because the result is awaited.
        relative_path: Forwarded as the first argument to ``progress_callback``.

    Returns:
        DownloadModelStatus: An ERROR status with 0 progress and ``already_existed=False``.
    """
    error_message = f"Error downloading {model_name}: {str(e)}"
    status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
    await progress_callback(relative_path, status)
    return status


def validate_model_subdirectory(model_subdirectory: str) -> bool:
    """
    Validate that the model subdirectory is safe.

    A safe subdirectory is a single path component of 1-50 ASCII letters,
    digits, underscores or hyphens. The whitelist therefore excludes path
    separators, "..", and any other character that could escape the models
    directory.

    Args:
        model_subdirectory (str): The subdirectory for the specific model type.

    Returns:
        bool: True if the subdirectory is safe, False otherwise.
    """
    if len(model_subdirectory) > 50:
        return False
    # fullmatch instead of match + '$': '$' also matches just before a trailing
    # newline, which previously let values like "loras\n" through.
    return re.fullmatch(r'[a-zA-Z0-9_-]+', model_subdirectory) is not None