Validate that model subdirectory cannot contain relative paths.

This commit is contained in:
Robin Huang 2024-08-07 12:45:07 -07:00
parent 9632dded9e
commit 59933489bf
4 changed files with 81 additions and 16 deletions

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory

View File

@ -3,7 +3,8 @@ import os
import traceback import traceback
import logging import logging
from folder_paths import models_dir from folder_paths import models_dir
from typing import Callable, Any, Optional, Awaitable import re
from typing import Callable, Any, Optional, Awaitable, Tuple
from enum import Enum from enum import Enum
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@ -36,13 +37,38 @@ class DownloadModelResult():
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_directory: str, model_sub_directory: str,
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult: progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
"""
Download a model file from a given URL into the models directory.
file_path, relative_path = create_model_path(model_name, model_directory, models_dir) Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
Returns:
DownloadModelResult: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelResult(
DownloadStatusType.ERROR,
"Invalid model subdirectory",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file: if existing_file:
return existing_file return existing_file
@ -51,9 +77,10 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
response = await make_request(model_url) response = await model_download_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
@ -61,15 +88,11 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}") logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> Tuple[str, str]:
async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
return await session.get(url)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory) full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True) os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name) file_path = os.path.join(full_model_dir, model_name)
@ -126,4 +149,25 @@ async def handle_download_error(e: Exception, model_name: str, progress_callback
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True

View File

@ -561,7 +561,7 @@ class PromptServer():
return web.Response(status=200) return web.Response(status=200)
@routes.post("/download") @routes.post("/models/download")
async def download_handler(request): async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus): async def report_progress(filename: str, status: DownloadStatus):
await self.send_json("download_progress", { await self.send_json("download_progress", {

View File

@ -4,7 +4,7 @@ from aiohttp import ClientResponse
import itertools import itertools
import os import os
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
class AsyncIteratorMock: class AsyncIteratorMock:
""" """
@ -253,4 +253,25 @@ async def test_track_download_progress_interval():
last_call = mock_callback.call_args_list[-1] last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed" assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100 assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False