diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index 5c71ae52..16b2165e 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress +from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index ed0c5a52..d4eb1173 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -3,7 +3,8 @@ import os import traceback import logging from folder_paths import models_dir -from typing import Callable, Any, Optional, Awaitable +import re +from typing import Callable, Any, Optional, Awaitable, Tuple from enum import Enum import time from dataclasses import dataclass @@ -36,13 +37,38 @@ class DownloadModelResult(): self.message = message self.already_existed = already_existed -async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], +async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, model_url: str, - model_directory: str, + model_sub_directory: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult: + """ + Download a model file from a given URL into the models directory. - file_path, relative_path = create_model_path(model_name, model_directory, models_dir) + Args: + model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): + A function that makes an HTTP request. This makes it easier to mock in unit tests. + model_name (str): + The name of the model file to be downloaded. This will be the filename on disk. + model_url (str): + The URL from which to download the model. + model_sub_directory (str): + The subdirectory within the main models directory where the model + should be saved (e.g., 'checkpoints', 'loras', etc.). + progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]): + An asynchronous function to call with progress updates. + + Returns: + DownloadModelResult: The result of the download operation. + """ + if not validate_model_subdirectory(model_sub_directory): + return DownloadModelResult( + DownloadStatusType.ERROR, + "Invalid model subdirectory", + False + ) + + file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) if existing_file: return existing_file @@ -51,9 +77,10 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") await progress_callback(relative_path, status) - response = await make_request(model_url) + response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" + logging.error(error_message) status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) await progress_callback(relative_path, status) return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) @@ -61,15 +88,11 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) except Exception as e: - logging.error(f"Error in track_download_progress: {e}") + logging.error(f"Error in downloading model: {e}") return await handle_download_error(e, model_name, progress_callback, relative_path) - -async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse: - return await session.get(url) - -def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: +def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> Tuple[str, str]: full_model_dir = os.path.join(models_base_dir, model_directory) os.makedirs(full_model_dir, exist_ok=True) file_path = os.path.join(full_model_dir, model_name) @@ -126,4 +149,25 @@ async def handle_download_error(e: Exception, model_name: str, progress_callback error_message = f"Error downloading {model_name}: {str(e)}" status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) await progress_callback(relative_path, status) - return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) \ No newline at end of file + return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) + +def validate_model_subdirectory(model_subdirectory: str) -> bool: + """ + Validate that the model subdirectory is safe. + + Args: + model_subdirectory (str): The subdirectory for the specific model type. + + Returns: + bool: True if the subdirectory is safe, False otherwise. + """ + if len(model_subdirectory) > 50: + return False + + if '..' in model_subdirectory or '/' in model_subdirectory: + return False + + if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory): + return False + + return True \ No newline at end of file diff --git a/server.py b/server.py index db34c998..79103318 100644 --- a/server.py +++ b/server.py @@ -561,7 +561,7 @@ class PromptServer(): return web.Response(status=200) - @routes.post("/download") + @routes.post("/models/download") async def download_handler(request): async def report_progress(filename: str, status: DownloadStatus): await self.send_json("download_progress", { diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 26142289..c88cf958 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -4,7 +4,7 @@ from aiohttp import ClientResponse import itertools import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType +from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType class AsyncIteratorMock: """ @@ -253,4 +253,25 @@ async def test_track_download_progress_interval(): last_call = mock_callback.call_args_list[-1] assert last_call[0][1].status == "completed" - assert last_call[0][1].progress_percentage == 100 \ No newline at end of file + assert last_call[0][1].progress_percentage == 100 + +def test_valid_subdirectory(): + assert validate_model_subdirectory("valid-model123") is True + +def test_subdirectory_too_long(): + assert validate_model_subdirectory("a" * 51) is False + +def test_subdirectory_with_double_dots(): + assert validate_model_subdirectory("model/../unsafe") is False + +def test_subdirectory_with_slash(): + assert validate_model_subdirectory("model/unsafe") is False + +def test_subdirectory_with_special_characters(): + assert validate_model_subdirectory("model@unsafe") is False + +def test_subdirectory_with_underscore_and_dash(): + assert validate_model_subdirectory("valid_model-name") is True + +def test_empty_subdirectory(): + assert validate_model_subdirectory("") is False \ No newline at end of file