Consolidate DownloadStatus and DownloadModelResult

This commit is contained in:
Robin Huang 2024-08-07 16:44:56 -07:00
parent a6d8a93fa1
commit c1d78d6890
3 changed files with 47 additions and 42 deletions

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory

View File

@ -16,32 +16,23 @@ class DownloadStatusType(Enum):
ERROR = "error" ERROR = "error"
@dataclass @dataclass
class DownloadStatus(): class DownloadModelStatus():
status: str status: str
progress_percentage: float progress_percentage: float
message: str message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str): def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage self.progress_percentage = progress_percentage
self.message = message self.message = message
@dataclass
class DownloadModelResult():
status: str
message: str
already_existed: bool
def __init__(self, status: DownloadStatusType, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.message = message
self.already_existed = already_existed self.already_existed = already_existed
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_sub_directory: str, model_sub_directory: str,
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult: progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus:
""" """
Download a model file from a given URL into the models directory. Download a model file from a given URL into the models directory.
@ -55,15 +46,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
model_sub_directory (str): model_sub_directory (str):
The subdirectory within the main models directory where the model The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.). should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]): progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates. An asynchronous function to call with progress updates.
Returns: Returns:
DownloadModelResult: The result of the download operation. DownloadModelStatus: The result of the download operation.
""" """
if not validate_model_subdirectory(model_sub_directory): if not validate_model_subdirectory(model_sub_directory):
return DownloadModelResult( return DownloadModelStatus(
DownloadStatusType.ERROR, DownloadStatusType.ERROR,
0,
"Invalid model subdirectory", "Invalid model subdirectory",
False False
) )
@ -74,16 +66,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
return existing_file return existing_file
try: try:
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
response = await model_download_request(model_url) response = await model_download_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message) logging.error(error_message)
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
@ -99,15 +91,23 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
relative_path = '/'.join([model_directory, model_name]) relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path return file_path, relative_path
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]: async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"{model_name} already exists", True) return status
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult: async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus:
try: try:
total_size = int(response.headers.get('Content-Length', 0)) total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0 downloaded = 0
@ -116,7 +116,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
async def update_progress(): async def update_progress():
nonlocal last_update_time nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0 progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}") status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
last_update_time = time.time() last_update_time = time.time()
@ -136,20 +136,23 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
await update_progress() await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False) return status
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}") logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult: async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) return status
def validate_model_subdirectory(model_subdirectory: str) -> bool: def validate_model_subdirectory(model_subdirectory: str) -> bool:
""" """

View File

@ -4,7 +4,7 @@ from aiohttp import ClientResponse
import itertools import itertools
import os import os
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus
class AsyncIteratorMock: class AsyncIteratorMock:
""" """
@ -73,7 +73,7 @@ async def test_download_model_success():
) )
# Assert the result # Assert the result
assert isinstance(result, DownloadModelResult) assert isinstance(result, DownloadModelStatus)
assert result.message == 'Successfully downloaded model.bin' assert result.message == 'Successfully downloaded model.bin'
assert result.status == 'completed' assert result.status == 'completed'
assert result.already_existed is False assert result.already_existed is False
@ -84,13 +84,13 @@ async def test_download_model_success():
# Check initial call # Check initial call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.bin', 'checkpoints/model.bin',
DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin") DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False)
) )
# Check final call # Check final call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.bin', 'checkpoints/model.bin',
DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin") DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False)
) )
# Verify file writing # Verify file writing
@ -123,7 +123,7 @@ async def test_download_model_url_request_failure():
) )
# Assert the expected behavior # Assert the expected behavior
assert isinstance(result, DownloadModelResult) assert isinstance(result, DownloadModelStatus)
assert result.status == 'error' assert result.status == 'error'
assert result.message == 'Failed to download model.safetensors. Status code: 404' assert result.message == 'Failed to download model.safetensors. Status code: 404'
assert result.already_existed is False assert result.already_existed is False
@ -131,18 +131,20 @@ async def test_download_model_url_request_failure():
# Check that progress_callback was called with the correct arguments # Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors', 'mock_directory/model.safetensors',
DownloadStatus( DownloadModelStatus(
status=DownloadStatusType.PENDING, status=DownloadStatusType.PENDING,
progress_percentage=0, progress_percentage=0,
message='Starting download of model.safetensors' message='Starting download of model.safetensors',
already_existed=False
) )
) )
mock_progress_callback.assert_called_with( mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors', 'mock_directory/model.safetensors',
DownloadStatus( DownloadModelStatus(
status=DownloadStatusType.ERROR, status=DownloadStatusType.ERROR,
progress_percentage=0, progress_percentage=0,
message='Failed to download model.safetensors. Status code: 404' message='Failed to download model.safetensors. Status code: 404',
already_existed=False
) )
) )
@ -165,7 +167,7 @@ async def test_download_model_invalid_model_subdirectory():
) )
# Assert the result # Assert the result
assert isinstance(result, DownloadModelResult) assert isinstance(result, DownloadModelStatus)
assert result.message == 'Invalid model subdirectory' assert result.message == 'Invalid model subdirectory'
assert result.status == 'error' assert result.status == 'error'
assert result.already_existed is False assert result.already_existed is False
@ -202,7 +204,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):
mock_callback.assert_called_once_with( mock_callback.assert_called_once_with(
"test/existing_model.bin", "test/existing_model.bin",
DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists") DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -235,7 +237,7 @@ async def test_track_download_progress_no_content_length():
# Check that progress was reported even without knowing the total size # Check that progress was reported even without knowing the total size
mock_callback.assert_any_call( mock_callback.assert_any_call(
'models/model.bin', 'models/model.bin',
DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin") DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False)
) )
@pytest.mark.asyncio @pytest.mark.asyncio