2024-08-07 01:07:32 +00:00
|
|
|
import pytest
|
|
|
|
import aiohttp
|
2024-08-07 19:12:01 +00:00
|
|
|
from aiohttp import ClientResponse
|
|
|
|
import itertools
|
|
|
|
import os
|
|
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
|
|
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
2024-08-07 01:07:32 +00:00
|
|
|
|
2024-08-07 19:12:01 +00:00
|
|
|
class AsyncIteratorMock:
    """Asynchronous iterator over a plain (synchronous) sequence.

    Stands in for the async iterator returned by aiohttp's
    ``response.content.iter_chunked`` so tests can drive ``async for`` loops
    with predefined data.
    """

    def __init__(self, seq):
        # Wrap the sequence once; iterating consumes it (not restartable).
        self.iter = iter(seq)

    def __aiter__(self):
        # `async for` requests an iterator; this object is its own iterator.
        return self

    async def __anext__(self):
        # A fresh sentinel distinguishes "exhausted" from any real element.
        done = object()
        item = next(self.iter, done)
        if item is done:
            # Async iteration signals the end with StopAsyncIteration.
            raise StopAsyncIteration
        return item
|
|
|
|
|
|
|
|
class ContentMock:
    """Stand-in for the ``content`` attribute of an aiohttp ``ClientResponse``.

    Only ``iter_chunked`` is provided; it replays the chunks given to the
    constructor through an ``AsyncIteratorMock``.
    """

    def __init__(self, chunks):
        # Predefined chunks; every iter_chunked() call replays this same object.
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        """Mimic ``content.iter_chunked(chunk_size)``.

        ``chunk_size`` is accepted only for signature compatibility and is
        ignored: the predefined chunks are yielded exactly as supplied.
        """
        predefined = self.chunks
        return AsyncIteratorMock(predefined)
|
2024-08-07 01:07:32 +00:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_download_model_success():
    """Download a three-chunk model and check the result, progress reports, writes and request."""
    payload = [b'a' * 500, b'b' * 300, b'c' * 200]

    response = AsyncMock(spec=aiohttp.ClientResponse)
    response.status = 200
    response.headers = {'Content-Length': '1000'}
    # ContentMock replays the chunks regardless of the requested chunk size.
    response.content = ContentMock(payload)

    request_fn = AsyncMock(return_value=response)
    progress_cb = AsyncMock()

    # Fake file handle produced by `with open(...) as f`.
    written = MagicMock()
    fake_open = MagicMock()
    fake_open.return_value.__enter__.return_value = written

    # Each time.time() call advances the fake clock by 0.1 s.
    clock = itertools.count(0, 0.1)

    with patch('model_filemanager.create_model_path',
               return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \
         patch('model_filemanager.check_file_exists', return_value=None), \
         patch('builtins.open', fake_open), \
         patch('time.time', side_effect=clock):
        result = await download_model(
            request_fn,
            'model.bin',
            'http://example.com/model.bin',
            'checkpoints',
            progress_cb,
        )

    assert isinstance(result, DownloadModelResult)
    assert result.message == 'Successfully downloaded model.bin'
    assert result.status == 'completed'
    assert result.already_existed is False

    # Expect at least: start, one progress update, and completion.
    assert progress_cb.call_count >= 3

    expected_reports = (
        DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin"),
        DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin"),
    )
    for report in expected_reports:
        progress_cb.assert_any_call('checkpoints/model.bin', report)

    # Every chunk must have been written to the (mocked) file.
    for chunk in payload:
        written.write.assert_any_call(chunk)

    request_fn.assert_called_once_with('http://example.com/model.bin')
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
    """A non-200 response yields an error result and an ERROR progress report."""
    # Simulate a "Not Found" response from the download request.
    mock_response = AsyncMock(spec=ClientResponse)
    mock_response.status = 404
    mock_get = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # The mocked relative path must match the path asserted in the progress
    # callbacks below; otherwise the fixture contradicts its own assertions
    # whenever download_model reports the relative path from create_model_path.
    # check_file_exists returns None: the file does not exist yet.
    with patch('model_filemanager.create_model_path',
               return_value=('/mock/path/model.safetensors', 'mock_directory/model.safetensors')), \
         patch('model_filemanager.check_file_exists', return_value=None):
        result = await download_model(
            mock_get,
            'model.safetensors',
            'http://example.com/model.safetensors',
            'mock_directory',
            mock_progress_callback
        )

    # The result reports the failure without claiming a pre-existing file.
    assert isinstance(result, DownloadModelResult)
    assert result.status == 'error'
    assert result.message == 'Failed to download model.safetensors. Status code: 404'
    assert result.already_existed is False

    # A PENDING report is sent before the request...
    mock_progress_callback.assert_any_call(
        'mock_directory/model.safetensors',
        DownloadStatus(
            status=DownloadStatusType.PENDING,
            progress_percentage=0,
            message='Starting download of model.safetensors'
        )
    )
    # ...and the last report is the ERROR status.
    mock_progress_callback.assert_called_with(
        'mock_directory/model.safetensors',
        DownloadStatus(
            status=DownloadStatusType.ERROR,
            progress_percentage=0,
            message='Failed to download model.safetensors. Status code: 404'
        )
    )

    # The request callable was invoked exactly once with the model URL.
    mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
|
|
|
|
|
|
|
# Tests for create_model_path
def test_create_model_path(tmp_path, monkeypatch):
    """create_model_path returns the absolute and relative paths and the parent dir exists afterwards."""
    models_root = tmp_path / "models"
    monkeypatch.setattr('folder_paths.models_dir', str(models_root))

    name, subdir = "test_model.bin", "test_dir"
    file_path, relative_path = create_model_path(name, subdir, models_root)

    assert file_path == str(models_root / subdir / name)
    assert relative_path == f"{subdir}/{name}"
    # The containing directory must exist on disk once the call returns.
    assert os.path.exists(os.path.dirname(file_path))
|
|
|
|
|
2024-08-07 01:07:32 +00:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
    """An existing file yields a 'completed' result and exactly one COMPLETED callback."""
    existing = tmp_path / "existing_model.bin"
    existing.touch()  # an empty placeholder file is enough
    callback = AsyncMock()

    result = await check_file_exists(
        str(existing), "existing_model.bin", callback, "test/existing_model.bin"
    )

    assert result is not None
    assert result.status == "completed"
    assert result.message == "existing_model.bin already exists"
    assert result.already_existed is True

    callback.assert_called_once_with(
        "test/existing_model.bin",
        DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists"),
    )
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    """A missing file yields None and never triggers the progress callback."""
    missing = tmp_path / "non_existing_model.bin"  # deliberately not created
    callback = AsyncMock()

    outcome = await check_file_exists(
        str(missing), "non_existing_model.bin", callback, "test/non_existing_model.bin"
    )

    assert outcome is None
    callback.assert_not_called()
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
    """Progress is still reported when the response has no Content-Length header."""
    response = AsyncMock(spec=aiohttp.ClientResponse)
    response.headers = {}  # total size unknown
    response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])

    callback = AsyncMock()
    fake_open = MagicMock(return_value=MagicMock())

    with patch('builtins.open', fake_open):
        result = await track_download_progress(
            response,
            '/mock/path/model.bin',
            'model.bin',
            callback,
            'models/model.bin',
            interval=0.1,
        )

    assert result.status == "completed"
    # Expect an IN_PROGRESS report at 0% even though the total size is unknown.
    callback.assert_any_call(
        'models/model.bin',
        DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin"),
    )
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_download_progress_interval():
    """Progress reports are throttled by ``interval`` and bracketed by start/completion reports."""
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    # Each time.time() call advances the fake clock by 0.5 s (0, 0.5, 1.0, ...).
    # Use an unbounded counter rather than a fixed-length list: a list
    # side_effect raises StopIteration once exhausted, which would break the
    # test if time.time() is called more often than expected while patched
    # (e.g. by logging).
    mock_time = MagicMock(side_effect=itertools.count(0, 0.5))

    with patch('builtins.open', mock_open), \
         patch('time.time', mock_time):
        await track_download_progress(
            mock_response, '/mock/path/model.bin', 'model.bin',
            mock_callback, 'models/model.bin', interval=1.0
        )

    # Summarise every report so a failing assertion shows the full sequence
    # (replaces the former debug print loop).
    reports = [
        f"{args[1].status} @ {args[1].progress_percentage:.2f}%"
        for args, _kwargs in mock_callback.call_args_list
    ]

    # At least: start, one interval update, and completion.
    assert mock_callback.call_count >= 3, (
        f"Expected at least 3 calls, but got {mock_callback.call_count}: {reports}"
    )

    first_status = mock_callback.call_args_list[0][0][1]
    assert first_status.status == "in_progress"
    # Allow some initial progress, but it should be below 50%.
    assert 0 <= first_status.progress_percentage < 50, (
        f"First call progress was {first_status.progress_percentage}%"
    )

    last_status = mock_callback.call_args_list[-1][0][1]
    assert last_status.status == "completed"
    assert last_status.progress_percentage == 100
|