From 3881f035455cc1092fc273255f879ce4137c4c18 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Wed, 7 Aug 2024 12:12:01 -0700 Subject: [PATCH] Fixed. --- model_filemanager/__init__.py | 2 +- model_filemanager/download_models.py | 81 +++-- server.py | 14 +- .../download_models_test.py | 283 ++++++++++++++---- tests-unit/requirements.txt | 3 +- 5 files changed, 288 insertions(+), 95 deletions(-) diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index 2b2334ed..5c71ae52 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType +from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 60a657a5..5173a2a6 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -1,7 +1,9 @@ import aiohttp import os +import traceback +import logging from folder_paths import models_dir -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Awaitable from enum import Enum import time from dataclasses import dataclass @@ -34,13 +36,13 @@ class DownloadModelResult(): self.message = message self.already_existed = already_existed -async def download_model(session: aiohttp.ClientSession, +async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, model_url: str, model_directory: str, - progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult: - file_path, relative_path = create_model_path(model_name, model_directory) - + progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult: + + file_path, relative_path = create_model_path(model_name, model_directory, models_dir) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) if existing_file: return existing_file @@ -49,8 +51,7 @@ async def download_model(session: aiohttp.ClientSession, status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") await progress_callback(relative_path, status) - response = await session.get(model_url) - + response = await make_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) @@ -60,17 +61,22 @@ async def download_model(session: aiohttp.ClientSession, return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) except Exception as e: + logging.error(f"Error in track_download_progress: {e}") return await handle_download_error(e, model_name, progress_callback, relative_path) -def create_model_path(model_name: str, model_directory: str) -> tuple[str, str]: - full_model_dir = os.path.join(models_dir, model_directory) + +async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse: + return await session.get(url) + +def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: + full_model_dir = os.path.join(models_base_dir, model_directory) os.makedirs(full_model_dir, exist_ok=True) file_path = os.path.join(full_model_dir, model_name) relative_path = '/'.join([model_directory, model_name]) return file_path, relative_path -async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> Optional[DownloadModelResult]: +async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]: if os.path.exists(file_path): status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists") await progress_callback(relative_path, status) @@ -78,34 +84,43 @@ async def check_file_exists(file_path: str, model_name: str, progress_callback: return None -async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str, interval: float = 1.0) -> DownloadModelResult: - total_size = int(response.headers.get('Content-Length', 0)) - downloaded = 0 - last_update_time = time.time() - - async def update_progress(): - nonlocal last_update_time - progress = (downloaded / total_size) * 100 if total_size > 0 else 0 - status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}") - await progress_callback(relative_path, status) +async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult: + try: + total_size = int(response.headers.get('Content-Length', 0)) + downloaded = 0 last_update_time = time.time() - with open(file_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - f.write(chunk) - downloaded += len(chunk) - - # Check if it's time to update progress - if time.time() - last_update_time >= interval: - await update_progress() + async def update_progress(): + nonlocal last_update_time + progress = (downloaded / total_size) * 100 if total_size > 0 else 0 + status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}") + await progress_callback(relative_path, status) + last_update_time = time.time() - # Ensure we send a final update - await update_progress() + with open(file_path, 'wb') as f: + chunk_iterator = response.content.iter_chunked(8192) + while True: + try: + chunk = await chunk_iterator.__anext__() + except StopAsyncIteration: + break + f.write(chunk) + downloaded += len(chunk) + + if time.time() - last_update_time >= interval: + await update_progress() - status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") - await progress_callback(relative_path, status) + await update_progress() + + logging.info(f"Download completed. Total downloaded: {downloaded}") + status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") + await progress_callback(relative_path, status) - return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False) + return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False) + except Exception as e: + logging.error(f"Error in track_download_progress: {e}") + logging.error(traceback.format_exc()) + return await handle_download_error(e, model_name, progress_callback, relative_path) async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult: error_message = f"Error downloading {model_name}: {str(e)}" diff --git a/server.py b/server.py index 93423c5b..db34c998 100644 --- a/server.py +++ b/server.py @@ -28,7 +28,7 @@ import node_helpers from app.frontend_management import FrontendManager from app.user_manager import UserManager from model_filemanager import download_model, DownloadStatus - +from typing import Optional class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -76,7 +76,7 @@ class PromptServer(): self.prompt_queue = None self.loop = loop self.messages = asyncio.Queue() - self.client_session = None + self.client_session:Optional[aiohttp.ClientSession] = None self.number = 0 middlewares = [cache_control] @@ -579,7 +579,12 @@ class PromptServer(): if not url or not model_directory or not model_filename: return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) - task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress)) + session = self.client_session + if session is None: + logging.error("Client session is not initialized") + return web.Response(status=500) + + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress)) await task return web.Response(status=200) @@ -726,6 +731,3 @@ class PromptServer(): logging.warning(traceback.format_exc()) return json_data - - def close_session(self): - self.client_session.close() diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index d5a96cb2..9f0ac45b 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -1,68 +1,243 @@ import pytest import aiohttp -import uuid -from unittest.mock import AsyncMock, MagicMock -from model_filemanager import download_model, DownloadStatus, DownloadStatusType +from aiohttp import ClientResponse +import itertools +import os +from unittest.mock import AsyncMock, patch, MagicMock +from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType +class AsyncIteratorMock: + def __init__(self, seq): + self.iter = iter(seq) -async def async_iterator(chunks): - for chunk in chunks: - yield chunk + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + +class ContentMock: + def __init__(self, chunks): + self.chunks = chunks + + def iter_chunked(self, chunk_size): + return AsyncIteratorMock(self.chunks) @pytest.mark.asyncio async def test_download_model_success(): - # Create a temporary directory for testing - model_directory = str(uuid.uuid4()) - - # Create a mock session - session = AsyncMock(spec=aiohttp.ClientSession) - - # Mock the response - mock_response = MagicMock(spec=aiohttp.ClientResponse) + # Mock dependencies + mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.status = 200 - mock_response.headers = {'Content-Length': '100'} - mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2']) + mock_response.headers = {'Content-Length': '1000'} - session.get.return_value.__aenter__.return_value = mock_response - - # Create a mock progress callback - progress_callback = AsyncMock() - - # Call the function - result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) - - # Assert the expected behavior - assert result['status'] == DownloadStatusType.COMPLETED - assert result['message'] == 'Successfully downloaded model.safetensors' - assert result['already_existed'] is False - relative_path = '/'.join([model_directory, 'model.safetensors']) - progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors')) + # Create a mock for content that returns an async iterator directly + chunks = [b'a' * 500, b'b' * 300, b'c' * 200] + mock_response.content = ContentMock(chunks) + + mock_make_request = AsyncMock(return_value=mock_response) + mock_progress_callback = MagicMock() + + # Mock file operations + mock_open = MagicMock() + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file + time_values = itertools.count(0, 0.1) + + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \ + patch('model_filemanager.check_file_exists', return_value=None), \ + patch('builtins.open', mock_open), \ + patch('time.time', side_effect=time_values): # Simulate time passing + + result = await download_model( + mock_make_request, + 'model.bin', + 'http://example.com/model.bin', + 'checkpoints', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelResult) + assert result.message == 'Successfully downloaded model.bin' + assert result.status == 'completed' + assert result.already_existed is False + + # Check progress callback calls + assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion + # Check initial call + mock_progress_callback.assert_any_call( + 'checkpoints/model.bin', + DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin") + ) + + # Check final call + mock_progress_callback.assert_any_call( + 'checkpoints/model.bin', + DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin") + ) + + # Verify file writing + mock_file.write.assert_any_call(b'a' * 500) + mock_file.write.assert_any_call(b'b' * 300) + mock_file.write.assert_any_call(b'c' * 200) + + # Verify request was made + mock_make_request.assert_called_once_with('http://example.com/model.bin') @pytest.mark.asyncio -async def test_download_model_failure(): - # Create a temporary directory for testing - model_directory = str(uuid.uuid4()) - - # Create a mock session - session = AsyncMock(spec=aiohttp.ClientSession) - - # Mock the response with an error status code - mock_response = MagicMock(spec=aiohttp.ClientResponse) - mock_response.status = 500 - session.get.return_value.__aenter__.return_value = mock_response - - # Create a mock progress callback - progress_callback = AsyncMock() - - # Call the function - result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) - print(result) - +async def test_download_model_url_request_failure(): + # Mock dependencies + mock_response = AsyncMock(spec=ClientResponse) + mock_response.status = 404 # Simulate a "Not Found" error + mock_get = AsyncMock(return_value=mock_response) + mock_progress_callback = AsyncMock() + + # Mock the create_model_path function + with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): + # Mock the check_file_exists function to return None (file doesn't exist) + with patch('model_filemanager.check_file_exists', return_value=None): + # Call the function + result = await download_model( + mock_get, + 'model.safetensors', + 'http://example.com/model.safetensors', + 'mock_directory', + mock_progress_callback + ) + # Assert the expected behavior - assert result['status'] == DownloadStatusType.ERROR - assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500' - assert result['already_existed'] is False + assert isinstance(result, DownloadModelResult) + assert result.status == 'error' + assert result.message == 'Failed to download model.safetensors. Status code: 404' + assert result.already_existed is False + + # Check that progress_callback was called with the correct arguments + mock_progress_callback.assert_any_call( + 'mock_directory/model.safetensors', + DownloadStatus( + status=DownloadStatusType.PENDING, + progress_percentage=0, + message='Starting download of model.safetensors' + ) + ) + mock_progress_callback.assert_called_with( + 'mock_directory/model.safetensors', + DownloadStatus( + status=DownloadStatusType.ERROR, + progress_percentage=0, + message='Failed to download model.safetensors. Status code: 404' + ) + ) + + # Verify that the get method was called with the correct URL + mock_get.assert_called_once_with('http://example.com/model.safetensors') + +# For create_model_path function +def test_create_model_path(tmp_path, monkeypatch): + mock_models_dir = tmp_path / "models" + monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) - relative_path = '/'.join([model_directory, 'model.safetensors']) - progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500')) \ No newline at end of file + model_name = "test_model.bin" + model_directory = "test_dir" + + file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) + + assert file_path == str(mock_models_dir / model_directory / model_name) + assert relative_path == f"{model_directory}/{model_name}" + assert os.path.exists(os.path.dirname(file_path)) + + +@pytest.mark.asyncio +async def test_check_file_exists_when_file_exists(tmp_path): + file_path = tmp_path / "existing_model.bin" + file_path.touch() # Create an empty file + + mock_callback = AsyncMock() + + result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin") + + assert result is not None + assert result.status == "completed" + assert result.message == "existing_model.bin already exists" + assert result.already_existed is True + + mock_callback.assert_called_once_with( + "test/existing_model.bin", + DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists") + ) + +@pytest.mark.asyncio +async def test_check_file_exists_when_file_does_not_exist(tmp_path): + file_path = tmp_path / "non_existing_model.bin" + + mock_callback = AsyncMock() + + result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin") + + assert result is None + mock_callback.assert_not_called() + +@pytest.mark.asyncio +async def test_track_download_progress_no_content_length(): + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.headers = {} # No Content-Length header + mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) + + mock_callback = AsyncMock() + mock_open = MagicMock(return_value=MagicMock()) + + with patch('builtins.open', mock_open): + result = await track_download_progress( + mock_response, '/mock/path/model.bin', 'model.bin', + mock_callback, 'models/model.bin', interval=0.1 + ) + + assert result.status == "completed" + # Check that progress was reported even without knowing the total size + mock_callback.assert_any_call( + 'models/model.bin', + DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin") + ) + +@pytest.mark.asyncio +async def test_track_download_progress_interval(): + mock_response = AsyncMock(spec=aiohttp.ClientResponse) + mock_response.headers = {'Content-Length': '1000'} + mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) + + mock_callback = AsyncMock() + mock_open = MagicMock(return_value=MagicMock()) + + # Create a mock time function that returns incremental float values + mock_time = MagicMock() + mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks + + with patch('builtins.open', mock_open), \ + patch('time.time', mock_time): + await track_download_progress( + mock_response, '/mock/path/model.bin', 'model.bin', + mock_callback, 'models/model.bin', interval=1.0 + ) + + # Print out the actual call count and the arguments of each call for debugging + print(f"mock_callback was called {mock_callback.call_count} times") + for i, call in enumerate(mock_callback.call_args_list): + args, kwargs = call + print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") + + # Assert that progress was updated at least 3 times (start, at least one interval, and end) + assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" + + # Verify the first and last calls + first_call = mock_callback.call_args_list[0] + assert first_call[0][1].status == "in_progress" + # Allow for some initial progress, but it should be less than 50% + assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%" + + last_call = mock_callback.call_args_list[-1] + assert last_call[0][1].status == "completed" + assert last_call[0][1].progress_percentage == 100 \ No newline at end of file diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 1bac12eb..4533a0e3 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -1,2 +1,3 @@ pytest>=7.8.0 -pytest-aiohttp \ No newline at end of file +pytest-aiohttp +pytest-asyncio \ No newline at end of file