mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 18:35:17 +00:00
Add model downloading endpoint.
This commit is contained in:
parent
b334605a66
commit
6976ccc5ca
2
model_filemanager/__init__.py
Normal file
2
model_filemanager/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# model_manager/__init__.py
|
||||||
|
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType
|
78
model_filemanager/download_models.py
Normal file
78
model_filemanager/download_models.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import aiohttp
|
||||||
|
import os
|
||||||
|
from folder_paths import models_dir
|
||||||
|
from typing import Callable, Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
class DownloadStatusType(Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadStatus():
|
||||||
|
status: DownloadStatusType
|
||||||
|
progress_percentage: float
|
||||||
|
message: str
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadModelResult():
|
||||||
|
status: DownloadStatusType
|
||||||
|
message: str
|
||||||
|
already_existed: bool
|
||||||
|
|
||||||
|
async def download_model(session: aiohttp.ClientSession,
|
||||||
|
model_name: str,
|
||||||
|
model_url: str,
|
||||||
|
model_directory: str,
|
||||||
|
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult:
|
||||||
|
"""
|
||||||
|
Asynchronously downloads a model file from a given URL to a specified directory.
|
||||||
|
|
||||||
|
If the file already exists, return success.
|
||||||
|
Downloads the file in chunks and reports progress as a percentage through the callback function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
full_model_dir = os.path.join(models_dir, model_directory)
|
||||||
|
os.makedirs(full_model_dir, exist_ok=True) # Ensure the directory exists.
|
||||||
|
file_path = os.path.join(full_model_dir, model_name)
|
||||||
|
relative_path = '/'.join([model_directory, model_name])
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return {"status": DownloadStatusType.COMPLETED, "message": f"{model_name} already exists", "already_existed": True}
|
||||||
|
try:
|
||||||
|
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
async with session.get(model_url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
|
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return {"status": DownloadStatusType.ERROR, "message": f"Failed to download {model_name}. Status code: {response.status} ", "already_existed": False}
|
||||||
|
|
||||||
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
|
downloaded = 0
|
||||||
|
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
async for chunk in response.content.iter_chunked(8192):
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
|
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
|
return {"status": DownloadStatusType.COMPLETED, "message": f"Successfully downloaded {model_name}", "already_existed": False}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
|
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||||
|
await progress_callback(relative_path, status)
|
||||||
|
return {"status": DownloadStatusType.ERROR, "message": error_message, "already_existed": False}
|
29
server.py
29
server.py
@ -12,7 +12,6 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
import hashlib
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -28,6 +27,7 @@ import comfy.model_management
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
|
from model_filemanager import download_model, DownloadStatus
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
@ -76,6 +76,8 @@ class PromptServer():
|
|||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
self.number = 0
|
self.number = 0
|
||||||
|
|
||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
@ -559,6 +561,28 @@ class PromptServer():
|
|||||||
self.prompt_queue.delete_history_item(id_to_delete)
|
self.prompt_queue.delete_history_item(id_to_delete)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
@routes.post("/download")
|
||||||
|
async def download_handler(request):
|
||||||
|
async def report_progress(filename: str, status: DownloadStatus):
|
||||||
|
await self.send_json(filename, {
|
||||||
|
"progress_percentage": status.progress_percentage,
|
||||||
|
"status": status.status,
|
||||||
|
"message": status.message
|
||||||
|
})
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
url = data.get('url')
|
||||||
|
model_directory = data.get('model_directory')
|
||||||
|
model_filename = data.get('model_filename')
|
||||||
|
|
||||||
|
if not url or not model_directory or not model_filename:
|
||||||
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
|
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress))
|
||||||
|
await task
|
||||||
|
|
||||||
|
return web.Response(status=200)
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.user_manager.add_routes(self.routes)
|
self.user_manager.add_routes(self.routes)
|
||||||
@ -698,3 +722,6 @@ class PromptServer():
|
|||||||
logging.warning(traceback.format_exc())
|
logging.warning(traceback.format_exc())
|
||||||
|
|
||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
|
def close_session(self):
|
||||||
|
self.client_session.close()
|
||||||
|
0
tests-unit/prompt_server_test/__init__.py
Normal file
0
tests-unit/prompt_server_test/__init__.py
Normal file
68
tests-unit/prompt_server_test/download_models_test.py
Normal file
68
tests-unit/prompt_server_test/download_models_test.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import pytest
|
||||||
|
import aiohttp
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
|
||||||
|
|
||||||
|
|
||||||
|
async def async_iterator(chunks):
|
||||||
|
for chunk in chunks:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_success():
|
||||||
|
# Create a temporary directory for testing
|
||||||
|
model_directory = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create a mock session
|
||||||
|
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||||
|
|
||||||
|
# Mock the response
|
||||||
|
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.headers = {'Content-Length': '100'}
|
||||||
|
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
|
||||||
|
|
||||||
|
session.get.return_value.__aenter__.return_value = mock_response
|
||||||
|
|
||||||
|
# Create a mock progress callback
|
||||||
|
progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||||
|
|
||||||
|
# Assert the expected behavior
|
||||||
|
assert result['status'] == DownloadStatusType.COMPLETED
|
||||||
|
assert result['message'] == 'Successfully downloaded model.safetensors'
|
||||||
|
assert result['already_existed'] is False
|
||||||
|
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||||
|
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_failure():
|
||||||
|
# Create a temporary directory for testing
|
||||||
|
model_directory = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create a mock session
|
||||||
|
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||||
|
|
||||||
|
# Mock the response with an error status code
|
||||||
|
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||||
|
mock_response.status = 500
|
||||||
|
session.get.return_value.__aenter__.return_value = mock_response
|
||||||
|
|
||||||
|
# Create a mock progress callback
|
||||||
|
progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# Assert the expected behavior
|
||||||
|
assert result['status'] == DownloadStatusType.ERROR
|
||||||
|
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
|
||||||
|
assert result['already_existed'] is False
|
||||||
|
|
||||||
|
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||||
|
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))
|
@ -1 +1,2 @@
|
|||||||
pytest>=7.8.0
|
pytest>=7.8.0
|
||||||
|
pytest-aiohttp
|
Loading…
Reference in New Issue
Block a user