mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Consolidate DownloadStatus and DownloadModelResult
This commit is contained in:
parent
a6d8a93fa1
commit
c1d78d6890
@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory
|
||||||
|
@ -16,32 +16,23 @@ class DownloadStatusType(Enum):
|
|||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
progress_percentage: float
|
progress_percentage: float
|
||||||
message: str
|
message: str
|
||||||
|
already_existed: bool = False
|
||||||
|
|
||||||
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str):
|
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
||||||
self.status = status.value # Store the string value of the Enum
|
self.status = status.value # Store the string value of the Enum
|
||||||
self.progress_percentage = progress_percentage
|
self.progress_percentage = progress_percentage
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DownloadModelResult():
|
|
||||||
status: str
|
|
||||||
message: str
|
|
||||||
already_existed: bool
|
|
||||||
|
|
||||||
def __init__(self, status: DownloadStatusType, message: str, already_existed: bool):
|
|
||||||
self.status = status.value # Store the string value of the Enum
|
|
||||||
self.message = message
|
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_sub_directory: str,
|
||||||
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
Download a model file from a given URL into the models directory.
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
@ -55,15 +46,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
model_sub_directory (str):
|
model_sub_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelResult: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
if not validate_model_subdirectory(model_sub_directory):
|
||||||
return DownloadModelResult(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
"Invalid model subdirectory",
|
"Invalid model subdirectory",
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
@ -74,16 +66,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
@ -99,15 +91,23 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
|
|||||||
relative_path = '/'.join([model_directory, model_name])
|
relative_path = '/'.join([model_directory, model_name])
|
||||||
return file_path, relative_path
|
return file_path, relative_path
|
||||||
|
|
||||||
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]:
|
async def check_file_exists(file_path: str,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelResult(DownloadStatusType.COMPLETED, f"{model_name} already exists", True)
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
|
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||||
|
file_path: str,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
relative_path: str,
|
||||||
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
downloaded = 0
|
downloaded = 0
|
||||||
@ -116,7 +116,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
|
|||||||
async def update_progress():
|
async def update_progress():
|
||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
@ -136,20 +136,23 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
|
|||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
|
async def handle_download_error(e: Exception,
|
||||||
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||||
|
relative_path: str) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
return status
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -4,7 +4,7 @@ from aiohttp import ClientResponse
|
|||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@ -73,7 +73,7 @@ async def test_download_model_success():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, DownloadModelResult)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.message == 'Successfully downloaded model.bin'
|
assert result.message == 'Successfully downloaded model.bin'
|
||||||
assert result.status == 'completed'
|
assert result.status == 'completed'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
@ -84,13 +84,13 @@ async def test_download_model_success():
|
|||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.bin',
|
'checkpoints/model.bin',
|
||||||
DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin")
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check final call
|
# Check final call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.bin',
|
'checkpoints/model.bin',
|
||||||
DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin")
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify file writing
|
# Verify file writing
|
||||||
@ -123,7 +123,7 @@ async def test_download_model_url_request_failure():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert the expected behavior
|
# Assert the expected behavior
|
||||||
assert isinstance(result, DownloadModelResult)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.status == 'error'
|
assert result.status == 'error'
|
||||||
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
@ -131,18 +131,20 @@ async def test_download_model_url_request_failure():
|
|||||||
# Check that progress_callback was called with the correct arguments
|
# Check that progress_callback was called with the correct arguments
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'mock_directory/model.safetensors',
|
'mock_directory/model.safetensors',
|
||||||
DownloadStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.PENDING,
|
status=DownloadStatusType.PENDING,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
message='Starting download of model.safetensors'
|
message='Starting download of model.safetensors',
|
||||||
|
already_existed=False
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mock_progress_callback.assert_called_with(
|
mock_progress_callback.assert_called_with(
|
||||||
'mock_directory/model.safetensors',
|
'mock_directory/model.safetensors',
|
||||||
DownloadStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.ERROR,
|
status=DownloadStatusType.ERROR,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
message='Failed to download model.safetensors. Status code: 404'
|
message='Failed to download model.safetensors. Status code: 404',
|
||||||
|
already_existed=False
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -165,7 +167,7 @@ async def test_download_model_invalid_model_subdirectory():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, DownloadModelResult)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.message == 'Invalid model subdirectory'
|
assert result.message == 'Invalid model subdirectory'
|
||||||
assert result.status == 'error'
|
assert result.status == 'error'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
@ -202,7 +204,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):
|
|||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.bin",
|
"test/existing_model.bin",
|
||||||
DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists")
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -235,7 +237,7 @@ async def test_track_download_progress_no_content_length():
|
|||||||
# Check that progress was reported even without knowing the total size
|
# Check that progress was reported even without knowing the total size
|
||||||
mock_callback.assert_any_call(
|
mock_callback.assert_any_call(
|
||||||
'models/model.bin',
|
'models/model.bin',
|
||||||
DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin")
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user