mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Validate that model subdirectory cannot contain relative paths.
This commit is contained in:
parent
9632dded9e
commit
59933489bf
@ -1,2 +1,2 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress
|
||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory
|
||||
|
@ -3,7 +3,8 @@ import os
|
||||
import traceback
|
||||
import logging
|
||||
from folder_paths import models_dir
|
||||
from typing import Callable, Any, Optional, Awaitable
|
||||
import re
|
||||
from typing import Callable, Any, Optional, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@ -36,13 +37,38 @@ class DownloadModelResult():
|
||||
self.message = message
|
||||
self.already_existed = already_existed
|
||||
|
||||
async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_directory: str,
|
||||
model_sub_directory: str,
|
||||
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
|
||||
"""
|
||||
Download a model file from a given URL into the models directory.
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
|
||||
Args:
|
||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||
model_name (str):
|
||||
The name of the model file to be downloaded. This will be the filename on disk.
|
||||
model_url (str):
|
||||
The URL from which to download the model.
|
||||
model_sub_directory (str):
|
||||
The subdirectory within the main models directory where the model
|
||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
|
||||
An asynchronous function to call with progress updates.
|
||||
|
||||
Returns:
|
||||
DownloadModelResult: The result of the download operation.
|
||||
"""
|
||||
if not validate_model_subdirectory(model_sub_directory):
|
||||
return DownloadModelResult(
|
||||
DownloadStatusType.ERROR,
|
||||
"Invalid model subdirectory",
|
||||
False
|
||||
)
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||
if existing_file:
|
||||
return existing_file
|
||||
@ -51,9 +77,10 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
|
||||
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
response = await make_request(model_url)
|
||||
response = await model_download_request(model_url)
|
||||
if response.status != 200:
|
||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||
logging.error(error_message)
|
||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||
await progress_callback(relative_path, status)
|
||||
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
||||
@ -61,15 +88,11 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
|
||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in track_download_progress: {e}")
|
||||
logging.error(f"Error in downloading model: {e}")
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
|
||||
|
||||
async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
|
||||
return await session.get(url)
|
||||
|
||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> Tuple[str, str]:
|
||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||
os.makedirs(full_model_dir, exist_ok=True)
|
||||
file_path = os.path.join(full_model_dir, model_name)
|
||||
@ -127,3 +150,24 @@ async def handle_download_error(e: Exception, model_name: str, progress_callback
|
||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||
await progress_callback(relative_path, status)
|
||||
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
||||
|
||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||
"""
|
||||
Validate that the model subdirectory is safe.
|
||||
|
||||
Args:
|
||||
model_subdirectory (str): The subdirectory for the specific model type.
|
||||
|
||||
Returns:
|
||||
bool: True if the subdirectory is safe, False otherwise.
|
||||
"""
|
||||
if len(model_subdirectory) > 50:
|
||||
return False
|
||||
|
||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
||||
return False
|
||||
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
||||
return False
|
||||
|
||||
return True
|
@ -561,7 +561,7 @@ class PromptServer():
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/download")
|
||||
@routes.post("/models/download")
|
||||
async def download_handler(request):
|
||||
async def report_progress(filename: str, status: DownloadStatus):
|
||||
await self.send_json("download_progress", {
|
||||
|
@ -4,7 +4,7 @@ from aiohttp import ClientResponse
|
||||
import itertools
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
||||
|
||||
class AsyncIteratorMock:
|
||||
"""
|
||||
@ -254,3 +254,24 @@ async def test_track_download_progress_interval():
|
||||
last_call = mock_callback.call_args_list[-1]
|
||||
assert last_call[0][1].status == "completed"
|
||||
assert last_call[0][1].progress_percentage == 100
|
||||
|
||||
def test_valid_subdirectory():
|
||||
assert validate_model_subdirectory("valid-model123") is True
|
||||
|
||||
def test_subdirectory_too_long():
|
||||
assert validate_model_subdirectory("a" * 51) is False
|
||||
|
||||
def test_subdirectory_with_double_dots():
|
||||
assert validate_model_subdirectory("model/../unsafe") is False
|
||||
|
||||
def test_subdirectory_with_slash():
|
||||
assert validate_model_subdirectory("model/unsafe") is False
|
||||
|
||||
def test_subdirectory_with_special_characters():
|
||||
assert validate_model_subdirectory("model@unsafe") is False
|
||||
|
||||
def test_subdirectory_with_underscore_and_dash():
|
||||
assert validate_model_subdirectory("valid_model-name") is True
|
||||
|
||||
def test_empty_subdirectory():
|
||||
assert validate_model_subdirectory("") is False
|
Loading…
Reference in New Issue
Block a user