Add model downloading endpoint.

This commit is contained in:
Robin Huang 2024-08-06 18:07:32 -07:00
parent b334605a66
commit 6976ccc5ca
6 changed files with 177 additions and 1 deletions

View File

@ -0,0 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType

View File

@ -0,0 +1,78 @@
import aiohttp
import os
from folder_paths import models_dir
from typing import Callable, Any
from enum import Enum
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadStatus():
status: DownloadStatusType
progress_percentage: float
message: str
@dataclass
class DownloadModelResult():
status: DownloadStatusType
message: str
already_existed: bool
async def download_model(session: aiohttp.ClientSession,
model_name: str,
model_url: str,
model_directory: str,
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
"""
Asynchronously downloads a model file from a given URL to a specified directory.
If the file already exists, return success.
Downloads the file in chunks and reports progress as a percentage through the callback function.
"""
full_model_dir = os.path.join(models_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists.
file_path = os.path.join(full_model_dir, model_name)
relative_path = '/'.join([model_directory, model_name])
if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True}
try:
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status)
async with session.get(model_url) as response:
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False}
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
with open(file_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
f.write(chunk)
downloaded += len(chunk)
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False}
except Exception as e:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False}

View File

@ -12,7 +12,6 @@ import json
import glob import glob
import struct import struct
import ssl import ssl
import hashlib
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@ -28,6 +27,7 @@ import comfy.model_management
import node_helpers import node_helpers
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus
class BinaryEventTypes: class BinaryEventTypes:
@ -76,6 +76,8 @@ class PromptServer():
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
self.number = 0 self.number = 0
middlewares = [cache_control] middlewares = [cache_control]
@ -560,6 +562,28 @@ class PromptServer():
return web.Response(status=200) return web.Response(status=200)
@routes.post("/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus):
await self.send_json(filename, {
"progress_percentage": status.progress_percentage,
"status": status.status,
"message": status.message
})
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress))
await task
return web.Response(status=200)
def add_routes(self): def add_routes(self):
self.user_manager.add_routes(self.routes) self.user_manager.add_routes(self.routes)
@ -698,3 +722,6 @@ class PromptServer():
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
return json_data return json_data
def close_session(self):
self.client_session.close()

View File

@ -0,0 +1,68 @@
import pytest
import aiohttp
import uuid
from unittest.mock import AsyncMock, MagicMock
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
async def async_iterator(chunks):
for chunk in chunks:
yield chunk
@pytest.mark.asyncio
async def test_download_model_success():
# Create a temporary directory for testing
model_directory = str(uuid.uuid4())
# Create a mock session
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '100'}
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
session.get.return_value.__aenter__.return_value = mock_response
# Create a mock progress callback
progress_callback = AsyncMock()
# Call the function
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
# Assert the expected behavior
assert result['status'] == DownloadStatusType.COMPLETED
assert result['message'] == 'Successfully downloaded model.safetensors'
assert result['already_existed'] is False
relative_path = '/'.join([model_directory, 'model.safetensors'])
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
@pytest.mark.asyncio
async def test_download_model_failure():
# Create a temporary directory for testing
model_directory = str(uuid.uuid4())
# Create a mock session
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response with an error status code
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 500
session.get.return_value.__aenter__.return_value = mock_response
# Create a mock progress callback
progress_callback = AsyncMock()
# Call the function
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
print(result)
# Assert the expected behavior
assert result['status'] == DownloadStatusType.ERROR
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
assert result['already_existed'] is False
relative_path = '/'.join([model_directory, 'model.safetensors'])
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))

View File

@ -1 +1,2 @@
pytest>=7.8.0 pytest>=7.8.0
pytest-aiohttp