"""aiohttp server exposing the web UI, the prompt queue API, and a websocket event feed."""

import asyncio
import glob
import json
import mimetypes
import os
import sys
import uuid

try:
    import aiohttp
    from aiohttp import web
except ImportError:
    print("Module 'aiohttp' not installed. Please install it via:")
    print("pip install aiohttp")
    print("or")
    print("pip install -r requirements.txt")
    # Non-zero status so launchers/scripts can tell startup failed.
    sys.exit(1)

import nodes
import execution


@web.middleware
async def cache_control(request: web.Request, handler):
    """Add ``Cache-Control: no-cache`` to .js/.css responses unless a handler already set one."""
    response: web.Response = await handler(request)
    if request.path.endswith(('.js', '.css')):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response


class PromptServer():
    """HTTP + websocket front end for the prompt queue.

    ``prompt_queue`` is ``None`` after construction; the routes that touch the
    queue (and ``get_queue_info``) assume the owner assigns it before serving.
    ``client_id`` / ``last_node_id`` are only read here — presumably updated by
    the executor elsewhere (TODO confirm).
    """

    def __init__(self, loop):
        mimetypes.init()
        # Force an explicit JavaScript MIME type for .js files rather than
        # trusting the platform's mimetypes table (presumably because some
        # systems map .js to something browsers reject — confirm).
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
        self.prompt_queue = None
        self.loop = loop
        # Messages handed over by send_sync(); drained by publish_loop().
        self.messages = asyncio.Queue()
        # Next queue number for prompts that don't supply their own.
        self.number = 0
        # 20 MiB max request body (covers image uploads).
        self.app = web.Application(client_max_size=20 * 1024 * 1024,
                                   middlewares=[cache_control])
        # sid -> WebSocketResponse for every connected client.
        self.sockets = {}
        base_dir = os.path.dirname(os.path.realpath(__file__))
        self.web_root = os.path.join(base_dir, "web")
        routes = web.RouteTableDef()
        self.last_node_id = None
        self.client_id = None

        @routes.get('/ws')
        async def websocket_handler(request):
            """Register a websocket client and keep it open until it disconnects.

            A ``clientId`` query parameter lets a reconnecting client keep its
            session id; otherwise a random one is generated and sent back in
            the initial "status" message.
            """
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", {"status": self.get_queue_info(), 'sid': sid}, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", {"node": self.last_node_id}, sid)

                # Incoming messages are not acted on; iterating just keeps the
                # connection alive until the client goes away.
                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        print('ws connection closed with exception %s' % ws.exception())
            finally:
                # Only unregister if the entry is still ours: a client that
                # reconnected with the same sid has already replaced it, and
                # popping blindly would drop the new connection.
                if self.sockets.get(sid) is ws:
                    del self.sockets[sid]
            return ws

        @routes.get("/")
        async def get_root(request):
            """Serve the web UI entry page."""
            return web.FileResponse(os.path.join(self.web_root, "index.html"))

        @routes.get("/embeddings")
        async def get_embeddings(request):
            """List embedding files under models/embeddings as lowercase names without extension."""
            embed_dir = os.path.join(base_dir, "models", "embeddings")
            embeddings = nodes.filter_files_extensions(nodes.recursive_search(embed_dir),
                                                       nodes.supported_pt_extensions)
            return web.json_response([os.path.splitext(a)[0].lower() for a in embeddings])

        @routes.get("/extensions")
        async def get_extensions(request):
            """List every .js file under web/extensions as a URL path rooted at "/"."""
            files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
            return web.json_response(
                ["/" + os.path.relpath(f, self.web_root).replace("\\", "/") for f in files])

        @routes.post("/upload/image")
        async def upload_image(request):
            """Save the multipart "image" field into the input directory.

            The file keeps its (sanitized) name; on collision " (1)", " (2)", ...
            is appended before the extension. Responds with the stored name,
            or 400 when no usable file was sent.
            """
            upload_dir = os.path.join(base_dir, "input")
            os.makedirs(upload_dir, exist_ok=True)

            post = await request.post()
            image = post.get("image")
            # A plain (non-file) form field has no .file attribute.
            if not image or not getattr(image, "file", None):
                return web.Response(status=400)

            # The filename is client-controlled: keep only its last path
            # component so names like "../x" cannot escape upload_dir.
            filename = os.path.basename(image.filename or "")
            if not filename:
                return web.Response(status=400)

            stem, ext = os.path.splitext(filename)
            i = 1
            while os.path.exists(os.path.join(upload_dir, filename)):
                filename = f"{stem} ({i}){ext}"
                i += 1

            filepath = os.path.join(upload_dir, filename)
            with open(filepath, "wb") as f:
                f.write(image.file.read())

            return web.json_response({"name": filename})

        @routes.get("/view/{file}")
        async def view_image(request):
            """Serve a file from the "output" (default) or "input" directory.

            ``?type=`` selects the directory; any other value gets 400. The
            requested name is reduced to its basename to block traversal.
            """
            if "file" in request.match_info:
                file_type = request.rel_url.query.get("type", "output")
                if file_type not in ("output", "input"):
                    return web.Response(status=400)

                output_dir = os.path.join(base_dir, file_type)
                file = os.path.basename(request.match_info["file"])
                file = os.path.join(output_dir, file)
                if os.path.isfile(file):
                    return web.FileResponse(file)

            return web.Response(status=404)

        @routes.get("/prompt")
        async def get_prompt(request):
            """Return the same queue status payload that is pushed over the websocket."""
            return web.json_response(self.get_queue_info())

        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every registered node class (inputs, outputs, category)."""
            out = {}
            for x, obj_class in nodes.NODE_CLASS_MAPPINGS.items():
                info = {}
                info['input'] = obj_class.INPUT_TYPES()
                info['output'] = obj_class.RETURN_TYPES
                info['name'] = x  # TODO
                info['description'] = ''
                info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
                out[x] = info
            return web.json_response(out)

        @routes.get("/history")
        async def get_history(request):
            """Return prompt_queue.get_history() as JSON."""
            return web.json_response(self.prompt_queue.get_history())

        @routes.get("/queue")
        async def get_queue(request):
            """Return the running and pending parts of the current queue."""
            queue_info = {}
            current_queue = self.prompt_queue.get_current_queue()
            queue_info['queue_running'] = current_queue[0]
            queue_info['queue_pending'] = current_queue[1]
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and enqueue it.

            Without an explicit "number", the server's counter is used; with
            a truthy "front" that number is negated (presumably so it sorts
            ahead of existing items — confirm against the queue ordering).
            Responds 400 with the validation message for invalid prompts.
            """
            print("got prompt")
            resp_code = 200
            out_string = ""
            json_data = await request.json()

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if "front" in json_data:
                    if json_data['front']:
                        number = -number

                self.number += 1

            if "prompt" in json_data:
                prompt = json_data["prompt"]
                valid = execution.validate_prompt(prompt)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]

                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]
                if valid[0]:
                    # Item layout: (number, id(prompt), prompt, extra_data);
                    # POST /queue "delete" matches on index 1.
                    self.prompt_queue.put((number, id(prompt), prompt, extra_data))
                else:
                    resp_code = 400
                    out_string = valid[1]
                    print("invalid prompt:", valid[1])

            return web.Response(body=out_string, status=resp_code)

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the pending queue and/or delete items by the id stored at index 1."""
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    # Used immediately within this iteration, so the closure's
                    # late binding of id_to_delete is harmless.
                    delete_func = lambda a: a[1] == int(id_to_delete)
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        @routes.post("/interrupt")
        async def post_interrupt(request):
            """Request that the currently running prompt be interrupted."""
            nodes.interrupt_processing()
            return web.Response(status=200)

        @routes.post("/history")
        async def post_history(request):
            """Clear the history and/or delete individual history entries."""
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_history()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

        self.app.add_routes(routes)
        # Registered after the explicit routes; everything else under "/" is
        # served as static files from web_root.
        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        """Return ``{"exec_info": {"queue_remaining": <tasks remaining>}}``."""
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` as JSON text.

        With ``sid=None`` the message goes to every connected socket;
        otherwise only to that client, silently skipping unknown sids.
        """
        message = json.dumps({"type": event, "data": data})

        if sid is None:
            # Iterate a snapshot: websocket handlers add/remove entries while
            # this coroutine is suspended in send_str().
            for ws in list(self.sockets.values()):
                await ws.send_str(message)
        elif sid in self.sockets:
            await self.sockets[sid].send_str(message)

    def send_sync(self, event, data, sid=None):
        """Schedule a send from outside the event loop's thread.

        Uses ``loop.call_soon_threadsafe`` to put the message on
        ``self.messages``; ``publish_loop`` performs the actual send.
        """
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        """Broadcast the current queue status to all clients."""
        self.send_sync("status", {"status": self.get_queue_info()})

    async def publish_loop(self):
        """Forever forward messages enqueued by ``send_sync`` to websocket clients."""
        while True:
            msg = await self.messages.get()
            try:
                await self.send(*msg)
            except Exception as e:
                # This loop is the only consumer of self.messages; letting one
                # failed send (e.g. to a closing socket) escape would end it
                # and silently stop all future events.
                print("error sending websocket message:", e)

    async def start(self, address, port, verbose=True):
        """Bind the app to ``address:port`` and start serving (returns once listening)."""
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, address, port)
        await site.start()

        if address == '':
            # Empty host binds all interfaces; print a concrete address instead.
            address = '0.0.0.0'
        if verbose:
            print("Starting server\n")
            print("To see the GUI go to: http://{}:{}".format(address, port))