mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-12 02:45:16 +00:00
68 lines
2.7 KiB
Python
68 lines
2.7 KiB
Python
import pytest
|
|
import aiohttp
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
|
|
|
|
|
|
async def async_iterator(chunks):
|
|
for chunk in chunks:
|
|
yield chunk
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_download_model_success():
|
|
# Create a temporary directory for testing
|
|
model_directory = str(uuid.uuid4())
|
|
|
|
# Create a mock session
|
|
session = AsyncMock(spec=aiohttp.ClientSession)
|
|
|
|
# Mock the response
|
|
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
|
mock_response.status = 200
|
|
mock_response.headers = {'Content-Length': '100'}
|
|
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
|
|
|
|
session.get.return_value.__aenter__.return_value = mock_response
|
|
|
|
# Create a mock progress callback
|
|
progress_callback = AsyncMock()
|
|
|
|
# Call the function
|
|
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
|
|
|
# Assert the expected behavior
|
|
assert result['status'] == DownloadStatusType.COMPLETED
|
|
assert result['message'] == 'Successfully downloaded model.safetensors'
|
|
assert result['already_existed'] is False
|
|
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
|
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_download_model_failure():
|
|
# Create a temporary directory for testing
|
|
model_directory = str(uuid.uuid4())
|
|
|
|
# Create a mock session
|
|
session = AsyncMock(spec=aiohttp.ClientSession)
|
|
|
|
# Mock the response with an error status code
|
|
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
|
mock_response.status = 500
|
|
session.get.return_value.__aenter__.return_value = mock_response
|
|
|
|
# Create a mock progress callback
|
|
progress_callback = AsyncMock()
|
|
|
|
# Call the function
|
|
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
|
print(result)
|
|
|
|
# Assert the expected behavior
|
|
assert result['status'] == DownloadStatusType.ERROR
|
|
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
|
|
assert result['already_existed'] is False
|
|
|
|
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
|
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500')) |