"""Logging utilities: rich console output plus an in-memory ring buffer of
recent stdout/stderr writes, with flush callbacks for live log streaming."""

from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading

from rich.console import Console
from rich.logging import RichHandler
from rich.theme import Theme  # noqa: F401 — kept: other modules may import it from here
from rich.traceback import install

# Ring buffer of recent output entries; created by setup_logger().
logs = None
# Interceptors installed over sys.stdout / sys.stderr by setup_logger().
stdout_interceptor = None
stderr_interceptor = None

# Enable rich tracebacks globally (once — the original imported/called this
# three times with identical effect).
install()

# Console and handler wired into logging.basicConfig by setup_logger().
console = Console(force_terminal=True)
rich_handler = RichHandler(console=console, rich_tracebacks=True, markup=True)


class LogInterceptor(io.TextIOWrapper):
    """Wraps a text stream (stdout/stderr), mirroring every write into the
    module-level ``logs`` ring buffer and batching entries between flushes
    for registered callbacks."""

    def __init__(self, stream, *args, **kwargs):
        # Re-wrap the stream's underlying binary buffer, preserving its
        # encoding and line-buffering behavior.
        super().__init__(
            stream.buffer,
            *args,
            **kwargs,
            encoding=stream.encoding,
            line_buffering=stream.line_buffering,
        )
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record *data* with a timestamp, then pass it through to the stream."""
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)
            # Simple handling for CR: overwrite the last entry when it isn't a
            # completed line, so progress-bar updates don't flood the buffer.
            # The `logs` guard fixes an IndexError the original raised when the
            # very first write began with "\r" and the deque was still empty.
            if (
                isinstance(data, str)
                and data.startswith("\r")
                and logs
                and not logs[-1]["m"].endswith("\n")
            ):
                logs.pop()
            logs.append(entry)
            super().write(data)

    def flush(self):
        """Flush the underlying stream, then hand the batch of entries written
        since the previous flush to every registered callback."""
        super().flush()
        # Swap the batch out under the lock: the original read and reset
        # _logs_since_flush unlocked, racing against concurrent write() calls.
        with self._lock:
            batch, self._logs_since_flush = self._logs_since_flush, []
        for cb in self._flush_callbacks:
            cb(batch)

    def on_flush(self, callback):
        """Register *callback* to receive each flushed batch of log entries."""
        self._flush_callbacks.append(callback)


def get_logs():
    """Return the ring buffer of recent log entries (None before setup_logger)."""
    return logs


def on_flush(callback):
    """Register *callback* on both stream interceptors, if installed."""
    if stdout_interceptor is not None:
        stdout_interceptor.on_flush(callback)
    if stderr_interceptor is not None:
        stderr_interceptor.on_flush(callback)


def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    """Configure rich logging and install stdout/stderr interceptors.

    Idempotent: subsequent calls are no-ops.  ``capacity`` bounds the
    in-memory ring buffer.  With ``use_stdout``, records below ERROR go to
    stdout and only ERROR+ go to stderr; otherwise everything goes to stderr.
    """
    global logs, stdout_interceptor, stderr_interceptor
    # "is not None", not truthiness: an empty deque is falsy, so the original
    # check re-ran setup — and double-wrapped sys.stdout/sys.stderr — when
    # called a second time before anything had been logged.
    if logs is not None:
        return

    logging.basicConfig(
        level=log_level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[rich_handler],
    )

    # Override output streams and mirror everything into the ring buffer.
    logs = deque(maxlen=capacity)
    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Set up the default (root) logger.
    logger = logging.getLogger()
    logger.setLevel(log_level)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    if use_stdout:
        # Only ERROR and above to stderr ("not levelno < ERROR" rewritten).
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)
        # Everything below ERROR to stdout.
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)
    logger.addHandler(stream_handler)


# Warnings recorded during early startup, for later replay.
STARTUP_WARNINGS = []


def log_startup_warning(msg):
    """Log *msg* immediately and remember it for print_startup_warnings()."""
    logging.warning(msg)
    STARTUP_WARNINGS.append(msg)


def print_startup_warnings():
    """Re-emit every recorded startup warning, then clear the record."""
    for msg in STARTUP_WARNINGS:
        logging.warning(msg)
    # Clear after the loop: clearing inside would mutate the list mid-iteration.
    STARTUP_WARNINGS.clear()