Add model downloading endpoint.

This commit is contained in:
Robin Huang 2024-08-06 18:07:32 -07:00
parent b334605a66
commit 6976ccc5ca
6 changed files with 177 additions and 1 deletions

View File

@ -0,0 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType

View File

@ -0,0 +1,78 @@
import aiohttp
import os
from folder_paths import models_dir
from typing import Callable, Any
from enum import Enum
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadStatus():
status: DownloadStatusType
progress_percentage: float
message: str
@dataclass
class DownloadModelResult():
status: DownloadStatusType
message: str
already_existed: bool
async def download_model(session: aiohttp.ClientSession,
model_name: str,
model_url: str,
model_directory: str,
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
"""
Asynchronously downloads a model file from a given URL to a specified directory.
If the file already exists, return success.
Downloads the file in chunks and reports progress as a percentage through the callback function.
"""
full_model_dir = os.path.join(models_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists.
file_path = os.path.join(full_model_dir, model_name)
relative_path = '/'.join([model_directory, model_name])
if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True}
try:
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status)
async with session.get(model_url) as response:
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False}
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
with open(file_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
f.write(chunk)
downloaded += len(chunk)
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False}
except Exception as e:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False}

View File

@ -12,7 +12,6 @@ import json
import glob
import struct
import ssl
import hashlib
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@ -28,6 +27,7 @@ import comfy.model_management
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus
class BinaryEventTypes:
@ -76,6 +76,8 @@ class PromptServer():
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
self.number = 0
middlewares = [cache_control]
@ -560,6 +562,28 @@ class PromptServer():
return web.Response(status=200)
@routes.post("/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus):
await self.send_json(filename, {
"progress_percentage": status.progress_percentage,
"status": status.status,
"message": status.message
})
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress))
await task
return web.Response(status=200)
def add_routes(self):
self.user_manager.add_routes(self.routes)
@ -698,3 +722,6 @@ class PromptServer():
logging.warning(traceback.format_exc())
return json_data
def close_session(self):
self.client_session.close()

View File

@ -0,0 +1,68 @@
import pytest
import aiohttp
import uuid
from unittest.mock import AsyncMock, MagicMock
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
async def async_iterator(chunks):
for chunk in chunks:
yield chunk
@pytest.mark.asyncio
async def test_download_model_success():
# Create a temporary directory for testing
model_directory = str(uuid.uuid4())
# Create a mock session
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '100'}
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
session.get.return_value.__aenter__.return_value = mock_response
# Create a mock progress callback
progress_callback = AsyncMock()
# Call the function
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
# Assert the expected behavior
assert result['status'] == DownloadStatusType.COMPLETED
assert result['message'] == 'Successfully downloaded model.safetensors'
assert result['already_existed'] is False
relative_path = '/'.join([model_directory, 'model.safetensors'])
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
@pytest.mark.asyncio
async def test_download_model_failure():
# Create a temporary directory for testing
model_directory = str(uuid.uuid4())
# Create a mock session
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response with an error status code
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 500
session.get.return_value.__aenter__.return_value = mock_response
# Create a mock progress callback
progress_callback = AsyncMock()
# Call the function
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
print(result)
# Assert the expected behavior
assert result['status'] == DownloadStatusType.ERROR
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
assert result['already_existed'] is False
relative_path = '/'.join([model_directory, 'model.safetensors'])
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))

View File

@ -1 +1,2 @@
pytest>=7.8.0
pytest-aiohttp