From b02ec058ff65a5c36ecf9dcbdb5b68dabf702735 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Wed, 7 Aug 2024 12:47:56 -0700 Subject: [PATCH] Add download_model test checking for invalid paths. --- .../download_models_test.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index c88cf958..abf8e3f7 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -149,6 +149,28 @@ async def test_download_model_url_request_failure(): # Verify that the get method was called with the correct URL mock_get.assert_called_once_with('http://example.com/model.safetensors') +@pytest.mark.asyncio +async def test_download_model_invalid_model_subdirectory(): + + mock_make_request = AsyncMock() + mock_progress_callback = AsyncMock() + + + result = await download_model( + mock_make_request, + 'model.bin', + 'http://example.com/model.bin', + '../bad_path', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelResult) + assert result.message == 'Invalid model subdirectory' + assert result.status == 'error' + assert result.already_existed is False + + # For create_model_path function def test_create_model_path(tmp_path, monkeypatch): mock_models_dir = tmp_path / "models"