mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-01 12:57:11 +08:00
Add progress_interval as an optional parameter.
This commit is contained in:
parent
c1d78d6890
commit
5537d25e72
@ -4,7 +4,7 @@ import traceback
|
|||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import models_dir
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Tuple
|
from typing import Callable, Any, Optional, Awaitable, Tuple, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -27,12 +27,21 @@ class DownloadModelStatus():
|
|||||||
self.progress_percentage = progress_percentage
|
self.progress_percentage = progress_percentage
|
||||||
self.message = message
|
self.message = message
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"status": self.status,
|
||||||
|
"progress_percentage": self.progress_percentage,
|
||||||
|
"message": self.message,
|
||||||
|
"already_existed": self.already_existed
|
||||||
|
}
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_sub_directory: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus:
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
Download a model file from a given URL into the models directory.
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
@ -77,7 +86,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
|
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
|
16
server.py
16
server.py
@ -27,7 +27,7 @@ import comfy.model_management
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from model_filemanager import download_model, DownloadStatus
|
from model_filemanager import download_model, DownloadModelStatus
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
@ -563,18 +563,14 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.post("/models/download")
|
@routes.post("/models/download")
|
||||||
async def download_handler(request):
|
async def download_handler(request):
|
||||||
async def report_progress(filename: str, status: DownloadStatus):
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
await self.send_json("download_progress", {
|
await self.send_json("download_progress", status.to_dict())
|
||||||
"filename": filename,
|
|
||||||
"progress_percentage": status.progress_percentage,
|
|
||||||
"status": status.status,
|
|
||||||
"message": status.message
|
|
||||||
})
|
|
||||||
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
@ -584,10 +580,10 @@ class PromptServer():
|
|||||||
logging.error("Client session is not initialized")
|
logging.error("Client session is not initialized")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress))
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.json_response(task.result().to_dict())
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
|
Loading…
x
Reference in New Issue
Block a user