Use pydantic.

This commit is contained in:
Robin Huang 2024-08-08 13:06:26 -07:00
parent db1ce51fdf
commit 7461e8eb00
3 changed files with 57 additions and 34 deletions

View File

@ -8,35 +8,25 @@ import re
from typing import Callable, Any, Optional, Awaitable, Dict from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum from enum import Enum
import time import time
from dataclasses import dataclass from pydantic import BaseModel, Field
class DownloadStatusType(str, Enum):
class DownloadStatusType(Enum):
PENDING = "pending" PENDING = "pending"
IN_PROGRESS = "in_progress" IN_PROGRESS = "in_progress"
COMPLETED = "completed" COMPLETED = "completed"
ERROR = "error" ERROR = "error"
@dataclass class DownloadModelStatus(BaseModel):
class DownloadModelStatus(): status: DownloadStatusType
status: str progress_percentage: float = Field(ge=0, le=100)
progress_percentage: float
message: str message: str
already_existed: bool = False already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool): class Config:
self.status = status.value # Store the string value of the Enum use_enum_values = True
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return self.model_dump()
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
@ -65,10 +55,10 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
""" """
if not validate_model_subdirectory(model_sub_directory): if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus( return DownloadModelStatus(
DownloadStatusType.ERROR, status=DownloadStatusType.ERROR,
0, progress_percentage=0,
"Invalid model subdirectory", message="Invalid model subdirectory",
False already_existed=False
) )
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
@ -77,16 +67,25 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
return existing_file return existing_file
try: try:
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) status = DownloadModelStatus(status=DownloadStatusType.PENDING,
progress_percentage=0,
message=f"Starting download of {model_name}",
already_existed=False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
response = await model_download_request(model_url) response = await model_download_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message) logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(status=DownloadStatusType.ERROR,
progress_percentage= 0,
message=error_message,
already_existed= False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return DownloadModelStatus(status=DownloadStatusType.ERROR,
progress_percentage=0,
message= error_message,
already_existed=False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
@ -107,7 +106,11 @@ async def check_file_exists(file_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]: relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) status = DownloadModelStatus(
status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message= f"{model_name} already exists",
already_existed=True)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return status return status
return None return None
@ -127,7 +130,10 @@ async def track_download_progress(response: aiohttp.ClientResponse,
async def update_progress(): async def update_progress():
nonlocal last_update_time nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0 progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) status = DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS,
progress_percentage=progress,
message=f"Downloading {model_name}",
already_existed=False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
last_update_time = time.time() last_update_time = time.time()
@ -147,7 +153,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
await update_progress() await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) status = DownloadModelStatus(
status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message=f"Successfully downloaded {model_name}",
already_existed=False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return status return status
@ -161,7 +171,10 @@ async def handle_download_error(e: Exception,
progress_callback: Callable[[str, DownloadModelStatus], Any], progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus: relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(status=DownloadStatusType.ERROR,
progress_percentage=0,
message=error_message,
already_existed=False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return status return status

View File

@ -13,6 +13,7 @@ Pillow
scipy scipy
tqdm tqdm
psutil psutil
pydantic~=2.8
#non essential dependencies: #non essential dependencies:
kornia>=0.7.1 kornia>=0.7.1

View File

@ -84,13 +84,19 @@ async def test_download_model_success():
# Check initial call # Check initial call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.bin', 'checkpoints/model.bin',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False) DownloadModelStatus(status=DownloadStatusType.PENDING,
progress_percentage= 0,
message="Starting download of model.bin",
already_existed= False)
) )
# Check final call # Check final call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.bin', 'checkpoints/model.bin',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False) DownloadModelStatus(status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message="Successfully downloaded model.bin",
already_existed= False)
) )
# Verify file writing # Verify file writing
@ -204,7 +210,10 @@ async def test_check_file_exists_when_file_exists(tmp_path):
mock_callback.assert_called_once_with( mock_callback.assert_called_once_with(
"test/existing_model.bin", "test/existing_model.bin",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True) DownloadModelStatus(status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message="existing_model.bin already exists",
already_existed=True)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -237,7 +246,7 @@ async def test_track_download_progress_no_content_length():
# Check that progress was reported even without knowing the total size # Check that progress was reported even without knowing the total size
mock_callback.assert_any_call( mock_callback.assert_any_call(
'models/model.bin', 'models/model.bin',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False) DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, progress_percentage= 0, message="Downloading model.bin", already_existed=False)
) )
@pytest.mark.asyncio @pytest.mark.asyncio