From 6976ccc5ca44d2ae98f99289fdbb644de723ebbb Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Tue, 6 Aug 2024 18:07:32 -0700 Subject: [PATCH] Add model downloading endpoint. --- model_filemanager/__init__.py | 2 + model_filemanager/download_models.py | 78 +++++++++++++++++++ server.py | 29 ++++++- tests-unit/prompt_server_test/__init__.py | 0 .../download_models_test.py | 68 ++++++++++++++++ tests-unit/requirements.txt | 1 + 6 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 model_filemanager/__init__.py create mode 100644 model_filemanager/download_models.py create mode 100644 tests-unit/prompt_server_test/__init__.py create mode 100644 tests-unit/prompt_server_test/download_models_test.py diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py new file mode 100644 index 00000000..2b2334ed --- /dev/null +++ b/model_filemanager/__init__.py @@ -0,0 +1,2 @@ +# model_manager/__init__.py +from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py new file mode 100644 index 00000000..7cc29b23 --- /dev/null +++ b/model_filemanager/download_models.py @@ -0,0 +1,78 @@ +import aiohttp +import os +from folder_paths import models_dir +from typing import Callable, Any +from enum import Enum + +from dataclasses import dataclass + +class DownloadStatusType(Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + ERROR = "error" + +@dataclass +class DownloadStatus(): + status: DownloadStatusType + progress_percentage: float + message: str + +@dataclass +class DownloadModelResult(): + status: DownloadStatusType + message: str + already_existed: bool + +async def download_model(session: aiohttp.ClientSession, + model_name: str, + model_url: str, + model_directory: str, + progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult: + """ + Asynchronously downloads a model file from a given URL to a specified directory. + + If the file already exists, return success. + Downloads the file in chunks and reports progress as a percentage through the callback function. + """ + + full_model_dir = os.path.join(models_dir, model_directory) + os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists. + file_path = os.path.join(full_model_dir, model_name) + relative_path = '/'.join([model_directory, model_name]) + if os.path.exists(file_path): + status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists") + await progress_callback(relative_path, status) + return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True} + try: + status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") + await progress_callback(relative_path, status) + + async with session.get(model_url) as response: + if response.status != 200: + error_message = f"Failed to download {model_name}. Status code: {response.status}" + status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) + await progress_callback(relative_path, status) + return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False} + + total_size = int(response.headers.get('Content-Length', 0)) + downloaded = 0 + + with open(file_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + f.write(chunk) + downloaded += len(chunk) + progress = (downloaded / total_size) * 100 if total_size > 0 else 0 + status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}") + await progress_callback(relative_path, status) + + status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") + await progress_callback(relative_path, status) + + return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False} + + except Exception as e: + error_message = f"Error downloading {model_name}: {str(e)}" + status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) + await progress_callback(relative_path, status) + return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False} \ No newline at end of file diff --git a/server.py b/server.py index 1c07f978..8a6e3bf7 100644 --- a/server.py +++ b/server.py @@ -12,7 +12,6 @@ import json import glob import struct import ssl -import hashlib from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from io import BytesIO @@ -28,6 +27,7 @@ import comfy.model_management import node_helpers from app.frontend_management import FrontendManager from app.user_manager import UserManager +from model_filemanager import download_model, DownloadStatus class BinaryEventTypes: @@ -76,6 +76,8 @@ class PromptServer(): self.prompt_queue = None self.loop = loop self.messages = asyncio.Queue() + timeout = aiohttp.ClientTimeout(total=None) # no timeout + self.client_session = aiohttp.ClientSession(timeout=timeout) self.number = 0 middlewares = [cache_control] @@ -559,6 +561,28 @@ class PromptServer(): self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) + + @routes.post("/download") + async def download_handler(request): + async def report_progress(filename: str, status: DownloadStatus): + await self.send_json(filename, { + "progress_percentage": status.progress_percentage, + "status": status.status, + "message": status.message + }) + + data = await request.json() + url = data.get('url') + model_directory = data.get('model_directory') + model_filename = data.get('model_filename') + + if not url or not model_directory or not model_filename: + return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) + + task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress)) + await task + + return web.Response(status=200) def add_routes(self): self.user_manager.add_routes(self.routes) @@ -698,3 +722,6 @@ class PromptServer(): logging.warning(traceback.format_exc()) return json_data + + def close_session(self): + self.client_session.close() diff --git a/tests-unit/prompt_server_test/__init__.py b/tests-unit/prompt_server_test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py new file mode 100644 index 00000000..d5a96cb2 --- /dev/null +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -0,0 +1,68 @@ +import pytest +import aiohttp +import uuid +from unittest.mock import AsyncMock, MagicMock +from model_filemanager import download_model, DownloadStatus, DownloadStatusType + + +async def async_iterator(chunks): + for chunk in chunks: + yield chunk + +@pytest.mark.asyncio +async def test_download_model_success(): + # Create a temporary directory for testing + model_directory = str(uuid.uuid4()) + + # Create a mock session + session = AsyncMock(spec=aiohttp.ClientSession) + + # Mock the response + mock_response = MagicMock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {'Content-Length': '100'} + mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2']) + + session.get.return_value.__aenter__.return_value = mock_response + + # Create a mock progress callback + progress_callback = AsyncMock() + + # Call the function + result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) + + # Assert the expected behavior + assert result['status'] == DownloadStatusType.COMPLETED + assert result['message'] == 'Successfully downloaded model.safetensors' + assert result['already_existed'] is False + relative_path = '/'.join([model_directory, 'model.safetensors']) + progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors')) + + +@pytest.mark.asyncio +async def test_download_model_failure(): + # Create a temporary directory for testing + model_directory = str(uuid.uuid4()) + + # Create a mock session + session = AsyncMock(spec=aiohttp.ClientSession) + + # Mock the response with an error status code + mock_response = MagicMock(spec=aiohttp.ClientResponse) + mock_response.status = 500 + session.get.return_value.__aenter__.return_value = mock_response + + # Create a mock progress callback + progress_callback = AsyncMock() + + # Call the function + result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) + print(result) + + # Assert the expected behavior + assert result['status'] == DownloadStatusType.ERROR + assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500' + assert result['already_existed'] is False + + relative_path = '/'.join([model_directory, 'model.safetensors']) + progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500')) \ No newline at end of file diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 0587502f..1bac12eb 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -1 +1,2 @@ pytest>=7.8.0 +pytest-aiohttp \ No newline at end of file