
277 lines
11 KiB
Raw Normal View History

2024-08-07 01:07:32 +00:00
import pytest
import aiohttp
2024-08-07 19:12:01 +00:00
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
class AsyncIteratorMock:
    """
    A mock class that simulates an asynchronous iterator.
    This is used to mimic the behavior of aiohttp's content iterator.
    """

    def __init__(self, seq):
        # Convert the input sequence into an iterator so each item is yielded once.
        self.iter = iter(seq)

    def __aiter__(self):
        # This method is called when 'async for' is used
        return self

    async def __anext__(self):
        # This method is called for each iteration in an 'async for' loop
        try:
            return next(self.iter)
        except StopIteration:
            # Translate synchronous exhaustion into its asynchronous equivalent;
            # `from None` keeps the irrelevant StopIteration out of tracebacks.
            raise StopAsyncIteration from None
class ContentMock:
    """
    A mock class that simulates the content attribute of an aiohttp ClientResponse.
    This class provides the iter_chunked method which returns an async iterator of chunks.
    """

    def __init__(self, chunks):
        # Store the chunks that will be returned by the iterator
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        # This method mimics aiohttp's content.iter_chunked().
        # For simplicity in testing, chunk_size is ignored and the predefined
        # chunks are returned exactly as given, so tests can assert one write
        # per chunk. A fresh iterator is built on every call.
        return AsyncIteratorMock(self.chunks)
2024-08-07 01:07:32 +00:00
# Happy-path test for download_model: a 200 response advertising
# Content-Length 1000 is streamed as three chunks (500 + 300 + 200 bytes) and
# should yield a 'completed' DownloadModelResult, at least three progress
# callbacks, and one file write per chunk.
# NOTE(review): this view of the block is incomplete — the
# `result = await download_model(` call has lost its arguments and closing
# parenthesis, the two bare DownloadStatus(...) lines have lost the progress
# callback assertion call that wrapped them, and the code after
# "# Verify request was made" is missing. Restore them from version control.
# NOTE(review): async test with no visible @pytest.mark.asyncio (pytest is
# imported but otherwise unused in this view); unless pytest-asyncio runs in
# auto mode this test will not be executed — confirm.
async def test_download_model_success():
2024-08-07 19:12:01 +00:00
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
2024-08-07 01:07:32 +00:00
mock_response.status = 200
2024-08-07 19:12:01 +00:00
mock_response.headers = {'Content-Length': '1000'}
# Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks)
# Stand-in request function: every awaited call resolves to mock_response.
# Presumably passed to the truncated download_model(...) call — confirm.
mock_make_request = AsyncMock(return_value=mock_response)
2024-08-07 19:17:29 +00:00
mock_progress_callback = AsyncMock()
2024-08-07 19:12:01 +00:00
# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
# Code doing `with open(...) as f:` under the patch receives mock_file as f,
# which is what the mock_file.write assertions below inspect.
mock_open.return_value.__enter__.return_value = mock_file
# Fake clock: successive time.time() calls return 0, 0.1, 0.2, ... without end.
time_values = itertools.count(0, 0.1)
# NOTE(review): patch('') below has an empty target, which unittest.mock
# rejects with TypeError ("Need a valid target to patch"). Given the
# mock_open/mock_file setup, the intended target is presumably
# 'builtins.open' — confirm.
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.bin', 'checkpoints/model.bin')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('', mock_open), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
# Assert the result
assert isinstance(result, DownloadModelResult)
assert result.message == 'Successfully downloaded model.bin'
assert result.status == 'completed'
assert result.already_existed is False
# Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
# Check initial call
DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin")
# Check final call
DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin")
# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
# Verify request was made
# Failure-path test for download_model: a 404 response should produce an
# 'error' DownloadModelResult whose message includes the status code, with
# already_existed False.
# NOTE(review): this view of the block is incomplete — the
# `result = await download_model(` call has lost its arguments and closing
# parenthesis, the two bare `message=...` lines are fragments of a truncated
# progress-callback assertion, and the URL check announced by the final
# comment is missing. Restore them from version control.
# NOTE(review): async test with no visible @pytest.mark.asyncio (pytest is
# imported but otherwise unused in this view); unless pytest-asyncio runs in
# auto mode this test will not be executed — confirm.
async def test_download_model_url_request_failure():
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
# Stand-in request function: every awaited call resolves to the 404 response.
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
# Mock the check_file_exists function to return None (file doesn't exist)
with patch('model_filemanager.check_file_exists', return_value=None):
# Call the function
result = await download_model(
# Assert the expected behavior
assert isinstance(result, DownloadModelResult)
assert result.status == 'error'
assert result.message == 'Failed to download model.safetensors. Status code: 404'
assert result.already_existed is False
# Check that progress_callback was called with the correct arguments
message='Starting download of model.safetensors'
message='Failed to download model.safetensors. Status code: 404'
# Verify that the get method was called with the correct URL
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
    """create_model_path should return the absolute file path under the
    models directory plus the '/'-joined relative path, and the parent
    directory of the returned path must exist afterwards."""
    mock_models_dir = tmp_path / "models"
    # Also point folder_paths at the temporary tree, in case create_model_path
    # consults the module-level setting as well as its argument.
    monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))

    model_name = "test_model.bin"
    model_directory = "test_dir"

    file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)

    assert file_path == str(mock_models_dir / model_directory / model_name)
    assert relative_path == f"{model_directory}/{model_name}"
    # The parent must exist *as a directory*, not merely as some path entry.
    assert os.path.isdir(os.path.dirname(file_path))
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
# check_file_exists on a path that already exists should return a non-None
# 'completed' result with the message "existing_model.bin already exists"
# and already_existed=True.
# NOTE(review): the last line of this view is a bare DownloadStatus(...)
# expression — the mock_callback assertion call that wrapped it has been
# lost; restore it from version control.
# NOTE(review): async test with no visible @pytest.mark.asyncio (pytest is
# imported but otherwise unused in this view); unless pytest-asyncio runs in
# auto mode this test will not be executed — confirm.
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.bin"
file_path.touch() # Create an empty file
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
mock_callback = AsyncMock()
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
result = await check_file_exists(str(file_path), "existing_model.bin", mock_callback, "test/existing_model.bin")
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.bin already exists"
assert result.already_existed is True
2024-08-07 01:07:32 +00:00
2024-08-07 19:12:01 +00:00
DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists")
# Explicit marker: without it (and without pytest-asyncio auto mode) pytest
# skips async test functions instead of running them.
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    """check_file_exists should return None when the target path is absent."""
    file_path = tmp_path / "non_existing_model.bin"
    # Make the test's premise explicit: the file was never created.
    assert not file_path.exists()

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "non_existing_model.bin", mock_callback, "test/non_existing_model.bin")

    assert result is None
# track_download_progress without a Content-Length header should still finish
# with status 'completed' after streaming two 500-byte chunks.
# NOTE(review): this view of the block is incomplete — the
# track_download_progress(...) call is missing its closing parenthesis and
# the bare DownloadStatus(...) line has lost the mock_callback assertion call
# that wrapped it. Restore them from version control.
# NOTE(review): async test with no visible @pytest.mark.asyncio (pytest is
# imported but otherwise unused in this view); unless pytest-asyncio runs in
# auto mode this test will not be executed — confirm.
async def test_track_download_progress_no_content_length():
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
# Whatever chunk size is requested, iteration yields two 500-byte chunks.
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
mock_callback = AsyncMock()
# Replacement for open(): each call returns a MagicMock, which also supports
# the context-manager protocol.
mock_open = MagicMock(return_value=MagicMock())
# NOTE(review): patch('') has an empty target, which unittest.mock rejects
# with TypeError ("Need a valid target to patch"); the intended target is
# presumably 'builtins.open' — confirm.
with patch('', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.bin', 'model.bin',
mock_callback, 'models/model.bin', interval=0.1
assert result.status == "completed"
# Check that progress was reported even without knowing the total size
DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin")
# Explicit marker: without it (and without pytest-asyncio auto mode) pytest
# skips async test functions instead of running them.
@pytest.mark.asyncio
async def test_track_download_progress_interval():
    """Progress is reported at the start, at least once more while
    streaming, and on completion; the final report is 'completed' at 100%."""
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    # Ten 100-byte chunks add up to the advertised Content-Length of 1000.
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    # Stand-in for builtins.open so nothing is written under /mock/path.
    mock_open = MagicMock(return_value=MagicMock())

    # Fake clock advancing 0.5s per time.time() call. An unbounded counter
    # cannot run dry (a fixed list raises StopIteration once exhausted, e.g.
    # if logging or the implementation reads the clock more often).
    mock_time = MagicMock(side_effect=itertools.count(0, 0.5))

    with patch('builtins.open', mock_open), \
         patch('time.time', mock_time):
        await track_download_progress(
            mock_response, '/mock/path/model.bin', 'model.bin',
            mock_callback, 'models/model.bin', interval=1.0
        )

    # Diagnostic output; pytest displays captured stdout when the test fails.
    print(f"mock_callback was called {mock_callback.call_count} times")
    for index, reported in enumerate(mock_callback.call_args_list, start=1):
        status = reported.args[1]
        print(f"Call {index}: {status.status}, Progress: {status.progress_percentage:.2f}%")

    # Assert that progress was updated at least 3 times (start, at least one interval, and end)
    assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"

    # Verify the first and last calls
    first_status = mock_callback.call_args_list[0].args[1]
    assert first_status.status == "in_progress"
    # Allow for some initial progress, but it should be less than 50%
    assert 0 <= first_status.progress_percentage < 50, f"First call progress was {first_status.progress_percentage}%"

    last_status = mock_callback.call_args_list[-1].args[1]
    assert last_status.status == "completed"
    assert last_status.progress_percentage == 100
def test_valid_subdirectory():
    """A name made of letters, digits and a dash is accepted."""
    candidate = "valid-model123"
    assert validate_model_subdirectory(candidate) is True


def test_subdirectory_too_long():
    """A 51-character name is rejected as too long."""
    candidate = "a" * 51
    assert validate_model_subdirectory(candidate) is False


def test_subdirectory_with_double_dots():
    """A name containing a '..' path segment is rejected."""
    candidate = "model/../unsafe"
    assert validate_model_subdirectory(candidate) is False


def test_subdirectory_with_slash():
    """A name containing a '/' separator is rejected."""
    candidate = "model/unsafe"
    assert validate_model_subdirectory(candidate) is False


def test_subdirectory_with_special_characters():
    """A name containing '@' is rejected."""
    candidate = "model@unsafe"
    assert validate_model_subdirectory(candidate) is False


def test_subdirectory_with_underscore_and_dash():
    """Underscores and dashes together are accepted."""
    candidate = "valid_model-name"
    assert validate_model_subdirectory(candidate) is True


def test_empty_subdirectory():
    """The empty string is rejected."""
    candidate = ""
    assert validate_model_subdirectory(candidate) is False