This commit is contained in:
Robin Huang 2024-08-07 12:12:01 -07:00
parent c2cd09540d
commit 3881f03545
5 changed files with 288 additions and 95 deletions

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress

View File

@ -1,7 +1,9 @@
import aiohttp import aiohttp
import os import os
import traceback
import logging
from folder_paths import models_dir from folder_paths import models_dir
from typing import Callable, Any, Optional from typing import Callable, Any, Optional, Awaitable
from enum import Enum from enum import Enum
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@ -34,13 +36,13 @@ class DownloadModelResult():
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
async def download_model(session: aiohttp.ClientSession, async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_directory: str, model_directory: str,
progress_callback: Callable[[str, DownloadStatus], Any]) -> DownloadModelResult: progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
file_path, relative_path = create_model_path(model_name, model_directory)
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file: if existing_file:
return existing_file return existing_file
@ -49,8 +51,7 @@ async def download_model(session: aiohttp.ClientSession,
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
response = await session.get(model_url) response = await make_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
@ -60,17 +61,22 @@ async def download_model(session: aiohttp.ClientSession,
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_dir, model_directory) async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
return await session.get(url)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True) os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name) file_path = os.path.join(full_model_dir, model_name)
relative_path = '/'.join([model_directory, model_name]) relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path return file_path, relative_path
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> Optional[DownloadModelResult]: async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists") status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
@ -78,34 +84,43 @@ async def check_file_exists(file_path: str, model_name: str, progress_callback:
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str, interval: float = 1.0) -> DownloadModelResult: async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
total_size = int(response.headers.get('Content-Length', 0)) try:
downloaded = 0 total_size = int(response.headers.get('Content-Length', 0))
last_update_time = time.time() downloaded = 0
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
last_update_time = time.time() last_update_time = time.time()
with open(file_path, 'wb') as f: async def update_progress():
async for chunk in response.content.iter_chunked(8192): nonlocal last_update_time
f.write(chunk) progress = (downloaded / total_size) * 100 if total_size > 0 else 0
downloaded += len(chunk) status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
await progress_callback(relative_path, status)
# Check if it's time to update progress last_update_time = time.time()
if time.time() - last_update_time >= interval:
await update_progress()
# Ensure we send a final update with open(file_path, 'wb') as f:
await update_progress() chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") await update_progress()
await progress_callback(relative_path, status)
logging.info(f"Download completed. Total downloaded: {downloaded}")
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False) return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult: async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"

View File

@ -28,7 +28,7 @@ import node_helpers
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus from model_filemanager import download_model, DownloadStatus
from typing import Optional
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -76,7 +76,7 @@ class PromptServer():
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.client_session = None self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0 self.number = 0
middlewares = [cache_control] middlewares = [cache_control]
@ -579,7 +579,12 @@ class PromptServer():
if not url or not model_directory or not model_filename: if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress)) session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress))
await task await task
return web.Response(status=200) return web.Response(status=200)
@ -726,6 +731,3 @@ class PromptServer():
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
return json_data return json_data
def close_session(self):
self.client_session.close()

View File

@ -1,68 +1,243 @@
import pytest import pytest
import aiohttp import aiohttp
import uuid from aiohttp import ClientResponse
from unittest.mock import AsyncMock, MagicMock import itertools
from model_filemanager import download_model, DownloadStatus, DownloadStatusType import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
class AsyncIteratorMock:
def __init__(self, seq):
self.iter = iter(seq)
async def async_iterator(chunks): def __aiter__(self):
for chunk in chunks: return self
yield chunk
async def __anext__(self):
try:
return next(self.iter)
except StopIteration:
raise StopAsyncIteration
class ContentMock:
def __init__(self, chunks):
self.chunks = chunks
def iter_chunked(self, chunk_size):
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_success(): async def test_download_model_success():
# Create a temporary directory for testing # Mock dependencies
model_directory = str(uuid.uuid4()) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
# Create a mock session
session = AsyncMock(spec=aiohttp.ClientSession)
# Mock the response
mock_response = MagicMock(spec=aiohttp.ClientResponse)
mock_response.status = 200 mock_response.status = 200
mock_response.headers = {'Content-Length': '100'} mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
session.get.return_value.__aenter__.return_value = mock_response # Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
# Create a mock progress callback mock_response.content = ContentMock(chunks)
progress_callback = AsyncMock()
mock_make_request = AsyncMock(return_value=mock_response)
# Call the function mock_progress_callback = MagicMock()
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
# Mock file operations
# Assert the expected behavior mock_open = MagicMock()
assert result['status'] == DownloadStatusType.COMPLETED mock_file = MagicMock()
assert result['message'] == 'Successfully downloaded model.safetensors' mock_open.return_value.__enter__.return_value = mock_file
assert result['already_existed'] is False time_values = itertools.count(0, 0.1)
relative_path = '/'.join([model_directory, 'model.safetensors'])
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors')) with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
mock_make_request,
'model.bin',
'http://example.com/model.bin',
'checkpoints',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelResult)
assert result.message == 'Successfully downloaded model.bin'
assert result.status == 'completed'
assert result.already_existed is False
# Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin")
)
# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin")
)
# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.bin')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_failure(): async def test_download_model_url_request_failure():
# Create a temporary directory for testing # Mock dependencies
model_directory = str(uuid.uuid4()) mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
# Create a mock session mock_get = AsyncMock(return_value=mock_response)
session = AsyncMock(spec=aiohttp.ClientSession) mock_progress_callback = AsyncMock()
# Mock the response with an error status code # Mock the create_model_path function
mock_response = MagicMock(spec=aiohttp.ClientResponse) with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
mock_response.status = 500 # Mock the check_file_exists function to return None (file doesn't exist)
session.get.return_value.__aenter__.return_value = mock_response with patch('model_filemanager.check_file_exists', return_value=None):
# Call the function
# Create a mock progress callback result = await download_model(
progress_callback = AsyncMock() mock_get,
'model.safetensors',
# Call the function 'http://example.com/model.safetensors',
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback) 'mock_directory',
print(result) mock_progress_callback
)
# Assert the expected behavior # Assert the expected behavior
assert result['status'] == DownloadStatusType.ERROR assert isinstance(result, DownloadModelResult)
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500' assert result.status == 'error'
assert result['already_existed'] is False assert result.message == 'Failed to download model.safetensors. Status code: 404'
assert result.already_existed is False
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors',
DownloadStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
message='Starting download of model.safetensors'
)
)
mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors',
DownloadStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
message='Failed to download model.safetensors. Status code: 404'
)
)
# Verify that the get method was called with the correct URL
mock_get.assert_called_once_with('http://example.com/model.safetensors')
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
relative_path = '/'.join([model_directory, 'model.safetensors']) model_name = "test_model.bin"
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500')) model_directory = "test_dir"
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
assert os.path.exists(os.path.dirname(file_path))
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.bin"
file_path.touch() # Create an empty file
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin")
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.bin already exists"
assert result.already_existed is True
mock_callback.assert_called_once_with(
"test/existing_model.bin",
DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists")
)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.bin"
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin")
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=0.1
)
assert result.status == "completed"
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.bin',
DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin")
)
@pytest.mark.asyncio
async def test_track_download_progress_interval():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
# Create a mock time function that returns incremental float values
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=1.0
)
# Print out the actual call count and the arguments of each call for debugging
print(f"mock_callback was called {mock_callback.call_count} times")
for i, call in enumerate(mock_callback.call_args_list):
args, kwargs = call
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
# Verify the first and last calls
first_call = mock_callback.call_args_list[0]
assert first_call[0][1].status == "in_progress"
# Allow for some initial progress, but it should be less than 50%
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100

View File

@ -1,2 +1,3 @@
pytest>=7.8.0 pytest>=7.8.0
pytest-aiohttp pytest-aiohttp
pytest-asyncio