diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 63704f13..8893500a 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -2,6 +2,7 @@ from aiohttp import web from typing import Optional from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths from api_server.services.file_service import FileService +from api_server.services.terminal_service import TerminalService import app.logger class InternalRoutes: @@ -11,7 +12,8 @@ class InternalRoutes: Check README.md for more information. ''' - def __init__(self): + + def __init__(self, prompt_server): self.routes: web.RouteTableDef = web.RouteTableDef() self._app: Optional[web.Application] = None self.file_service = FileService({ @@ -19,6 +21,8 @@ class InternalRoutes: "user": user_directory, "output": output_directory }) + self.prompt_server = prompt_server + self.terminal_service = TerminalService(prompt_server) def setup_routes(self): @self.routes.get('/files') @@ -34,7 +38,28 @@ class InternalRoutes: @self.routes.get('/logs') async def get_logs(request): - return web.json_response(app.logger.get_logs()) + return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()])) + + @self.routes.get('/logs/raw') + async def get_logs(request): + self.terminal_service.update_size() + return web.json_response({ + "entries": list(app.logger.get_logs()), + "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows} + }) + + @self.routes.patch('/logs/subscribe') + async def subscribe_logs(request): + json_data = await request.json() + client_id = json_data["clientId"] + enabled = json_data["enabled"] + if enabled: + self.terminal_service.subscribe(client_id) + else: + self.terminal_service.unsubscribe(client_id) + + return web.Response(status=200) + @self.routes.get('/folder_paths') async def get_folder_paths(request): diff --git a/api_server/services/terminal_service.py b/api_server/services/terminal_service.py new file mode 100644 index 00000000..284afab5 --- /dev/null +++ b/api_server/services/terminal_service.py @@ -0,0 +1,47 @@ +from app.logger import on_flush +import os + + +class TerminalService: + def __init__(self, server): + self.server = server + self.cols = None + self.rows = None + self.subscriptions = set() + on_flush(self.send_messages) + + def update_size(self): + sz = os.get_terminal_size() + changed = False + if sz.columns != self.cols: + self.cols = sz.columns + changed = True + + if sz.lines != self.rows: + self.rows = sz.lines + changed = True + + if changed: + return {"cols": self.cols, "rows": self.rows} + + return None + + def subscribe(self, client_id): + self.subscriptions.add(client_id) + + def unsubscribe(self, client_id): + self.subscriptions.discard(client_id) + + def send_messages(self, entries): + if not len(entries) or not len(self.subscriptions): + return + + new_size = self.update_size() + + for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration + if client_id not in self.server.sockets: + # Automatically unsub if the socket has disconnected + self.unsubscribe(client_id) + continue + + self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id) diff --git a/app/logger.py b/app/logger.py index 4ca0ea88..527be9fe 100644 --- a/app/logger.py +++ b/app/logger.py @@ -1,20 +1,69 @@ -import logging -from logging.handlers import MemoryHandler from collections import deque +from datetime import datetime +import io +import logging +import sys +import threading logs = None -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +stdout_interceptor = None +stderr_interceptor = None + + +class LogInterceptor(io.TextIOWrapper): + def __init__(self, stream, *args, **kwargs): + buffer = stream.buffer + encoding = stream.encoding + super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering) + self._lock = threading.Lock() + self._flush_callbacks = [] + self._logs_since_flush = [] + + def write(self, data): + entry = {"t": datetime.now().isoformat(), "m": data} + with self._lock: + self._logs_since_flush.append(entry) + + # Simple handling for cr to overwrite the last output if it isnt a full line + # else logs just get full of progress messages + if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"): + logs.pop() + logs.append(entry) + super().write(data) + + def flush(self): + super().flush() + for cb in self._flush_callbacks: + cb(self._logs_since_flush) + self._logs_since_flush = [] + + def on_flush(self, callback): + self._flush_callbacks.append(callback) def get_logs(): - return "\n".join([formatter.format(x) for x in logs]) + return logs +def on_flush(callback): + if stdout_interceptor is not None: + stdout_interceptor.on_flush(callback) + if stderr_interceptor is not None: + stderr_interceptor.on_flush(callback) + def setup_logger(log_level: str = 'INFO', capacity: int = 300): global logs if logs: return + # Override output streams and log to buffer + logs = deque(maxlen=capacity) + + global stdout_interceptor + global stderr_interceptor + stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout) + stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr) + # Setup default global logger logger = logging.getLogger() logger.setLevel(log_level) @@ -22,10 +71,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300): stream_handler = logging.StreamHandler() stream_handler.setFormatter(logging.Formatter("%(message)s")) logger.addHandler(stream_handler) - - # Create a memory handler with a deque as its buffer - logs = deque(maxlen=capacity) - memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO) - memory_handler.buffer = logs - memory_handler.setFormatter(formatter) - logger.addHandler(memory_handler) diff --git a/server.py b/server.py index ada6d90c..e663095b 100644 --- a/server.py +++ b/server.py @@ -152,7 +152,7 @@ class PromptServer(): mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' self.user_manager = UserManager() - self.internal_routes = InternalRoutes() + self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop diff --git a/tests-unit/server/routes/internal_routes_test.py b/tests-unit/server/routes/internal_routes_test.py index 2d2b43bd..4fe54424 100644 --- a/tests-unit/server/routes/internal_routes_test.py +++ b/tests-unit/server/routes/internal_routes_test.py @@ -8,7 +8,7 @@ from folder_paths import models_dir, user_directory, output_directory @pytest.fixture def internal_routes(): - return InternalRoutes() + return InternalRoutes(None) @pytest.fixture def aiohttp_client_factory(aiohttp_client, internal_routes): @@ -102,7 +102,7 @@ async def test_file_service_initialization(): # Create a mock instance mock_file_service_instance = MagicMock(spec=FileService) MockFileService.return_value = mock_file_service_instance - internal_routes = InternalRoutes() + internal_routes = InternalRoutes(None) # Check if FileService was initialized with the correct parameters MockFileService.assert_called_once_with({ @@ -112,4 +112,4 @@ async def test_file_service_initialization(): }) # Verify that the file_service attribute of InternalRoutes is set - assert internal_routes.file_service == mock_file_service_instance \ No newline at end of file + assert internal_routes.file_service == mock_file_service_instance