diff --git a/main.py b/main.py index 666193b6..ba204059 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import json import threading import heapq import traceback +import asyncio if __name__ == "__main__": if '--help' in sys.argv: @@ -25,7 +26,6 @@ if '--dont-upcast-attention' in sys.argv: os.environ['ATTN_PRECISION'] = "fp16" import torch - import nodes def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): @@ -286,16 +286,19 @@ def prompt_worker(q): q.task_done(item_id) class PromptQueue: - def __init__(self): + def __init__(self, socket_handler): + self.socket_handler = socket_handler self.mutex = threading.RLock() self.not_empty = threading.Condition(self.mutex) self.task_counter = 0 self.queue = [] self.currently_running = {} + socket_handler.prompt_queue = self def put(self, item): with self.mutex: heapq.heappush(self.queue, item) + self.socket_handler.queue_updated(self) self.not_empty.notify() def get(self): @@ -306,11 +309,13 @@ class PromptQueue: i = self.task_counter self.currently_running[i] = copy.deepcopy(item) self.task_counter += 1 + self.socket_handler.queue_updated(self) return (item, i) def task_done(self, item_id): with self.mutex: self.currently_running.pop(item_id) + self.socket_handler.queue_updated(self) def get_current_queue(self): with self.mutex: @@ -326,6 +331,7 @@ class PromptQueue: def wipe_queue(self): with self.mutex: self.queue = [] + self.socket_handler.queue_updated(self) def delete_queue_item(self, function): with self.mutex: @@ -336,35 +342,82 @@ class PromptQueue: else: self.queue.pop(x) heapq.heapify(self.queue) + self.socket_handler.queue_updated(self) return True return False -from http.server import BaseHTTPRequestHandler, HTTPServer +import aiohttp +from aiohttp import web -class PromptServer(BaseHTTPRequestHandler): - def _set_headers(self, code=200, ct='text/html'): - self.send_response(code) - self.send_header('Content-type', ct) - self.end_headers() - def log_message(self, format, *args): - pass - def do_GET(self): - if self.path == "/prompt": - self._set_headers(ct='application/json') - prompt_info = {} - exec_info = {} - exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining() - prompt_info['exec_info'] = exec_info - self.wfile.write(json.dumps(prompt_info).encode('utf-8')) - elif self.path == "/queue": - self._set_headers(ct='application/json') - queue_info = {} - current_queue = self.server.prompt_queue.get_current_queue() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] - self.wfile.write(json.dumps(queue_info).encode('utf-8')) - elif self.path == "/object_info": - self._set_headers(ct='application/json') +def get_queue_info(prompt_queue): + prompt_info = {} + exec_info = {} + exec_info['queue_remaining'] = prompt_queue.get_tasks_remaining() + prompt_info['exec_info'] = exec_info + return prompt_info + +class SocketHandler(): + def __init__(self, loop): + self.connected = set() + self.messages = asyncio.Queue() + self.loop = loop + + async def publish_loop(self): + while True: + msg = await self.messages.get() + await self.send(msg) + + def queue_updated(self, queue): + # This is called by the queue processing thread so we need to make it thread safe + loop.call_soon_threadsafe(self.messages.put_nowait, { 'type': 'status', 'status': get_queue_info(queue) }) + + async def send(self, message, socket = None): + if isinstance(message, str) == False: + message = json.dumps(message) + + if socket is None: + for ws in self.connected: + await ws.send_str(message) + else: + await socket.send_str(message) + + async def process(self, request): + ws = web.WebSocketResponse() + await ws.prepare(request) + self.connected.add(ws) + try: + # Send initial state to the new client + await self.send({ 'type': 'status', 'status': get_queue_info(self.prompt_queue) }, ws) + async for msg in ws: + if msg.type == aiohttp.WSMsgType.ERROR: + print('ws connection closed with exception %s' % ws.exception()) + finally: + self.connected.remove(ws) + + return ws + +class PromptServer(): + def __init__(self, prompt_queue, socket_handler): + self.prompt_queue = prompt_queue + self.socket_handler = socket_handler + self.number = 0 + self.app = web.Application() + routes = web.RouteTableDef() + + @routes.get('/ws') + async def websocket_handler(request): + return await self.socket_handler.process(request) + + @routes.get("/") + async def get_root(request): + return aiohttp.web.HTTPFound('/index.html') + + @routes.get("/prompt") + async def get_prompt(request): + return web.json_response(get_queue_info(self.prompt_queue)) + + @routes.get("/object_info") + async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: obj_class = nodes.NODE_CLASS_MAPPINGS[x] @@ -377,40 +430,32 @@ class PromptServer(BaseHTTPRequestHandler): if hasattr(obj_class, 'CATEGORY'): info['category'] = obj_class.CATEGORY out[x] = info - self.wfile.write(json.dumps(out).encode('utf-8')) - elif self.path[1:] in os.listdir(self.server.server_dir): - if self.path[1:].endswith('.css'): - self._set_headers(ct='text/css') - elif self.path[1:].endswith('.js'): - self._set_headers(ct='text/javascript') - else: - self._set_headers() - with open(os.path.join(self.server.server_dir, self.path[1:]), "rb") as f: - self.wfile.write(f.read()) - else: - self._set_headers() - with open(os.path.join(self.server.server_dir, "index.html"), "rb") as f: - self.wfile.write(f.read()) - - def do_HEAD(self): - self._set_headers() - - def do_POST(self): - resp_code = 200 - out_string = "" - if self.path == "/prompt": + return web.json_response(out) + + @routes.get("/queue") + async def get_queue(request): + queue_info = {} + current_queue = self.prompt_queue.get_current_queue() + queue_info['queue_running'] = current_queue[0] + queue_info['queue_pending'] = current_queue[1] + return web.json_response(queue_info) + + @routes.post("/prompt") + async def post_prompt(request): print("got prompt") - data_string = self.rfile.read(int(self.headers['Content-Length'])) - json_data = json.loads(data_string) + resp_code = 200 + out_string = "" + json_data = await request.json() + if "number" in json_data: number = float(json_data['number']) else: - number = self.server.number + number = self.number if "front" in json_data: if json_data['front']: number = -number - self.server.number += 1 + self.number += 1 if "prompt" in json_data: prompt = json_data["prompt"] valid = validate_prompt(prompt) @@ -418,46 +463,54 @@ class PromptServer(BaseHTTPRequestHandler): if "extra_data" in json_data: extra_data = json_data["extra_data"] if valid[0]: - self.server.prompt_queue.put((number, id(prompt), prompt, extra_data)) + self.prompt_queue.put((number, id(prompt), prompt, extra_data)) else: resp_code = 400 out_string = valid[1] print("invalid prompt:", valid[1]) - elif self.path == "/queue": - data_string = self.rfile.read(int(self.headers['Content-Length'])) - json_data = json.loads(data_string) + + return web.Response(body=out_string, status=resp_code) + + @routes.post("/queue") + async def post_queue(request): + json_data = await request.json() if "clear" in json_data: if json_data["clear"]: - self.server.prompt_queue.wipe_queue() + self.prompt_queue.wipe_queue() if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: delete_func = lambda a: a[1] == int(id_to_delete) - self.server.prompt_queue.delete_queue_item(delete_func) + self.prompt_queue.delete_queue_item(delete_func) + + return web.Response(status=200) - self._set_headers(code=resp_code) - self.end_headers() - self.wfile.write(out_string.encode('utf8')) - return + self.app.add_routes(routes) + self.app.add_routes([ + web.static('/', os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit")), + ]) - -def run(prompt_queue, address='', port=8188): - server_address = (address, port) - httpd = HTTPServer(server_address, PromptServer) - httpd.server_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit") - httpd.prompt_queue = prompt_queue - httpd.number = 0 - if server_address[0] == '': - addr = '0.0.0.0' - else: - addr = server_address[0] +async def start_server(server, address, port): + runner = web.AppRunner(server.app) + await runner.setup() + site = web.TCPSite(runner, address, port) + await site.start() + + if address == '': + address = '0.0.0.0' print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(addr, server_address[1])) - httpd.serve_forever() + print("To see the GUI go to: http://{}:{}".format(address, port)) +async def run(prompt_queue, socket_handler, address='', port=8188): + server = PromptServer(prompt_queue, socket_handler) + await asyncio.gather(start_server(server, address, port), socket_handler.publish_loop()) if __name__ == "__main__": - q = PromptQueue() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + socket_handler = SocketHandler(loop) + q = PromptQueue(socket_handler) threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start() if '--listen' in sys.argv: address = '0.0.0.0' @@ -471,6 +524,9 @@ if __name__ == "__main__": except: pass - run(q, address=address, port=port) + try: + loop.run_until_complete(run(q, socket_handler, address=address, port=port)) + except KeyboardInterrupt: + pass