Validate that model subdirectory cannot contain relative paths.

This commit is contained in:
Robin Huang 2024-08-07 12:45:07 -07:00
parent 9632dded9e
commit 59933489bf
4 changed files with 81 additions and 16 deletions

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory

View File

@ -3,7 +3,8 @@ import os
import traceback
import logging
from folder_paths import models_dir
from typing import Callable, Any, Optional, Awaitable
import re
from typing import Callable, Any, Optional, Awaitable, Tuple
from enum import Enum
import time
from dataclasses import dataclass
@ -36,13 +37,38 @@ class DownloadModelResult():
self.message = message
self.already_existed = already_existed
async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_directory: str,
model_sub_directory: str,
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
"""
Download a model file from a given URL into the models directory.
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
Returns:
DownloadModelResult: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelResult(
DownloadStatusType.ERROR,
"Invalid model subdirectory",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file:
return existing_file
@ -51,9 +77,10 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status)
response = await make_request(model_url)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
@ -61,15 +88,11 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
return await session.get(url)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> Tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
@ -127,3 +150,24 @@ async def handle_download_error(e: Exception, model_name: str, progress_callback
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True

View File

@ -561,7 +561,7 @@ class PromptServer():
return web.Response(status=200)
@routes.post("/download")
@routes.post("/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus):
await self.send_json("download_progress", {

View File

@ -4,7 +4,7 @@ from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
class AsyncIteratorMock:
"""
@ -254,3 +254,24 @@ async def test_track_download_progress_interval():
last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False