diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 003fcb71..eb55f5d9 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -4,7 +4,7 @@ import traceback import logging from folder_paths import models_dir import re -from typing import Callable, Any, Optional, Awaitable, Tuple +from typing import Callable, Any, Optional, Awaitable, Tuple, Dict from enum import Enum import time from dataclasses import dataclass @@ -27,12 +27,21 @@ class DownloadModelStatus(): self.progress_percentage = progress_percentage self.message = message self.already_existed = already_existed + + def to_dict(self) -> Dict[str, Any]: + return { + "status": self.status, + "progress_percentage": self.progress_percentage, + "message": self.message, + "already_existed": self.already_existed + } async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, model_url: str, model_sub_directory: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus: + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + progress_interval: float = 1.0) -> DownloadModelStatus: """ Download a model file from a given URL into the models directory. @@ -77,7 +86,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht await progress_callback(relative_path, status) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) + return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) except Exception as e: logging.error(f"Error in downloading model: {e}") diff --git a/server.py b/server.py index 79103318..c52db621 100644 --- a/server.py +++ b/server.py @@ -27,7 +27,7 @@ import comfy.model_management import node_helpers from app.frontend_management import FrontendManager from app.user_manager import UserManager -from model_filemanager import download_model, DownloadStatus +from model_filemanager import download_model, DownloadModelStatus from typing import Optional class BinaryEventTypes: @@ -563,18 +563,14 @@ class PromptServer(): @routes.post("/models/download") async def download_handler(request): - async def report_progress(filename: str, status: DownloadStatus): - await self.send_json("download_progress", { - "filename": filename, - "progress_percentage": status.progress_percentage, - "status": status.status, - "message": status.message - }) + async def report_progress(filename: str, status: DownloadModelStatus): + await self.send_json("download_progress", status.to_dict()) data = await request.json() url = data.get('url') model_directory = data.get('model_directory') model_filename = data.get('model_filename') + progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. if not url or not model_directory or not model_filename: return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) @@ -584,10 +580,10 @@ class PromptServer(): logging.error("Client session is not initialized") return web.Response(status=500) - task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress)) + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) await task - return web.Response(status=200) + return web.json_response(task.result().to_dict()) async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout