Add progress_interval as an optional parameter.

This commit is contained in:
Robin Huang 2024-08-07 17:04:26 -07:00
parent c1d78d6890
commit 5537d25e72
2 changed files with 18 additions and 13 deletions

View File

@ -4,7 +4,7 @@ import traceback
import logging import logging
from folder_paths import models_dir from folder_paths import models_dir
import re import re
from typing import Callable, Any, Optional, Awaitable, Tuple from typing import Callable, Any, Optional, Awaitable, Tuple, Dict
from enum import Enum from enum import Enum
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@ -27,12 +27,21 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage self.progress_percentage = progress_percentage
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_sub_directory: str, model_sub_directory: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus: progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
""" """
Download a model file from a given URL into the models directory. Download a model file from a given URL into the models directory.
@ -77,7 +86,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
except Exception as e: except Exception as e:
logging.error(f"Error in downloading model: {e}") logging.error(f"Error in downloading model: {e}")

View File

@ -27,7 +27,7 @@ import comfy.model_management
import node_helpers import node_helpers
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus from model_filemanager import download_model, DownloadModelStatus
from typing import Optional from typing import Optional
class BinaryEventTypes: class BinaryEventTypes:
@ -563,18 +563,14 @@ class PromptServer():
@routes.post("/models/download") @routes.post("/models/download")
async def download_handler(request): async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus): async def report_progress(filename: str, status: DownloadModelStatus):
await self.send_json("download_progress", { await self.send_json("download_progress", status.to_dict())
"filename": filename,
"progress_percentage": status.progress_percentage,
"status": status.status,
"message": status.message
})
data = await request.json() data = await request.json()
url = data.get('url') url = data.get('url')
model_directory = data.get('model_directory') model_directory = data.get('model_directory')
model_filename = data.get('model_filename') model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename: if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
@ -584,10 +580,10 @@ class PromptServer():
logging.error("Client session is not initialized") logging.error("Client session is not initialized")
return web.Response(status=500) return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress)) task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
await task await task
return web.Response(status=200) return web.json_response(task.result().to_dict())
async def setup(self): async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout timeout = aiohttp.ClientTimeout(total=None) # no timeout