mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Validate that model subdirectory cannot contain relative paths.
This commit is contained in:
parent
9632dded9e
commit
59933489bf
@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress
|
from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory
|
||||||
|
@ -3,7 +3,8 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import models_dir
|
||||||
from typing import Callable, Any, Optional, Awaitable
|
import re
|
||||||
|
from typing import Callable, Any, Optional, Awaitable, Tuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -36,13 +37,38 @@ class DownloadModelResult():
|
|||||||
self.message = message
|
self.message = message
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_directory: str,
|
model_sub_directory: str,
|
||||||
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
|
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
|
||||||
|
"""
|
||||||
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
|
Args:
|
||||||
|
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||||
|
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||||
|
model_name (str):
|
||||||
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
|
model_url (str):
|
||||||
|
The URL from which to download the model.
|
||||||
|
model_sub_directory (str):
|
||||||
|
The subdirectory within the main models directory where the model
|
||||||
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
|
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
|
||||||
|
An asynchronous function to call with progress updates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DownloadModelResult: The result of the download operation.
|
||||||
|
"""
|
||||||
|
if not validate_model_subdirectory(model_sub_directory):
|
||||||
|
return DownloadModelResult(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
"Invalid model subdirectory",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
@ -51,9 +77,10 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
|
|||||||
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
|
|
||||||
response = await make_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
|
logging.error(error_message)
|
||||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
||||||
@ -61,15 +88,11 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
|
|||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> Tuple[str, str]:
|
||||||
async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
|
|
||||||
return await session.get(url)
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||||
os.makedirs(full_model_dir, exist_ok=True)
|
os.makedirs(full_model_dir, exist_ok=True)
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
file_path = os.path.join(full_model_dir, model_name)
|
||||||
@ -126,4 +149,25 @@ async def handle_download_error(e: Exception, model_name: str, progress_callback
|
|||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
|
||||||
|
|
||||||
|
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that the model subdirectory is safe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_subdirectory (str): The subdirectory for the specific model type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the subdirectory is safe, False otherwise.
|
||||||
|
"""
|
||||||
|
if len(model_subdirectory) > 50:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if '..' in model_subdirectory or '/' in model_subdirectory:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
@ -561,7 +561,7 @@ class PromptServer():
|
|||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
@routes.post("/download")
|
@routes.post("/models/download")
|
||||||
async def download_handler(request):
|
async def download_handler(request):
|
||||||
async def report_progress(filename: str, status: DownloadStatus):
|
async def report_progress(filename: str, status: DownloadStatus):
|
||||||
await self.send_json("download_progress", {
|
await self.send_json("download_progress", {
|
||||||
|
@ -4,7 +4,7 @@ from aiohttp import ClientResponse
|
|||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@ -253,4 +253,25 @@ async def test_track_download_progress_interval():
|
|||||||
|
|
||||||
last_call = mock_callback.call_args_list[-1]
|
last_call = mock_callback.call_args_list[-1]
|
||||||
assert last_call[0][1].status == "completed"
|
assert last_call[0][1].status == "completed"
|
||||||
assert last_call[0][1].progress_percentage == 100
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
|
def test_valid_subdirectory():
|
||||||
|
assert validate_model_subdirectory("valid-model123") is True
|
||||||
|
|
||||||
|
def test_subdirectory_too_long():
|
||||||
|
assert validate_model_subdirectory("a" * 51) is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_double_dots():
|
||||||
|
assert validate_model_subdirectory("model/../unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_slash():
|
||||||
|
assert validate_model_subdirectory("model/unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_special_characters():
|
||||||
|
assert validate_model_subdirectory("model@unsafe") is False
|
||||||
|
|
||||||
|
def test_subdirectory_with_underscore_and_dash():
|
||||||
|
assert validate_model_subdirectory("valid_model-name") is True
|
||||||
|
|
||||||
|
def test_empty_subdirectory():
|
||||||
|
assert validate_model_subdirectory("") is False
|
Loading…
Reference in New Issue
Block a user