Removed socketio

This commit is contained in:
pythongosssss 2023-02-25 12:31:16 +00:00 committed by GitHub
parent a9c57849b7
commit 70b3311478
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 41 deletions

View File

@ -9,4 +9,3 @@ safetensors
pytorch_lightning pytorch_lightning
aiohttp aiohttp
accelerate accelerate
python-socketio

View File

@ -3,6 +3,8 @@ import sys
import asyncio import asyncio
import nodes import nodes
import main import main
import uuid
import json
try: try:
import aiohttp import aiohttp
@ -14,16 +16,6 @@ except ImportError:
print("pip install -r requirements.txt") print("pip install -r requirements.txt")
sys.exit() sys.exit()
try:
import socketio
except ImportError:
print("Module 'python-socketio' not installed. Please install it via:")
print("pip install python-socketio")
print("or")
print("pip install -r requirements.txt")
sys.exit()
class PromptServer(): class PromptServer():
def __init__(self, loop): def __init__(self, loop):
self.prompt_queue = None self.prompt_queue = None
@ -31,15 +23,26 @@ class PromptServer():
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.number = 0 self.number = 0
self.app = web.Application() self.app = web.Application()
self.sio = socketio.AsyncServer() self.sockets = dict()
self.sio.attach(self.app)
self.web_root = os.path.join(os.path.dirname( self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "webshit") os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef() routes = web.RouteTableDef()
@self.sio.event @routes.get('/ws')
async def connect(sid, environ): async def websocket_handler(request):
await self.sio.emit("status", self.get_queue_info(), sid) ws = web.WebSocketResponse()
await ws.prepare(request)
sid = uuid.uuid4().hex
self.sockets[sid] = ws
try:
# Send initial state to the new client
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
self.sockets.pop(sid)
return ws
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
@ -164,14 +167,23 @@ class PromptServer():
return prompt_info return prompt_info
async def send(self, event, data, sid=None): async def send(self, event, data, sid=None):
await self.sio.emit(event, data, to=sid) message = {"type": event, "data": data}
if isinstance(message, str) == False:
message = json.dumps(message)
if sid is None:
for ws in self.sockets.values():
await ws.send_str(message)
elif sid in self.sockets:
await self.sockets[sid].send_str(message)
def send_sync(self, event, data, sid=None): def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid)) self.messages.put_nowait, (event, data, sid))
def queue_updated(self): def queue_updated(self):
self.send_sync("status", self.get_queue_info()) self.send_sync("status", { "status": self.get_queue_info() })
async def publish_loop(self): async def publish_loop(self):
while True: while True:

View File

@ -2,7 +2,6 @@
<head> <head>
<link rel="stylesheet" type="text/css" href="litegraph.css"> <link rel="stylesheet" type="text/css" href="litegraph.css">
<script type="text/javascript" src="litegraph.core.js"></script> <script type="text/javascript" src="litegraph.core.js"></script>
<script type="text/javascript" src="socket.io.min.js"></script>
</head> </head>
<style> <style>
.customtext_input { .customtext_input {
@ -627,46 +626,77 @@ function setRunningNode(id) {
document.getElementById("queuesize").innerHTML = "Queue size: " + (data ? data.exec_info.queue_remaining : "ERR"); document.getElementById("queuesize").innerHTML = "Queue size: " + (data ? data.exec_info.queue_remaining : "ERR");
} }
//fix for colab and other things that don't support websockets.
function manually_fetch_queue() {
fetch('/prompt')
.then(response => response.json())
.then(data => {
updateStatus(data);
}).catch((response) => {updateStatus(null)});
}
let ws;
function createSocket(isReconnect) { function createSocket(isReconnect) {
if(ws) return;
let opened = false; let opened = false;
const ws = io(); ws = new WebSocket(`ws${window.location.protocol === "https:"? "s" : ""}://${location.host}/ws`);
ws.on("connect", () => { ws.addEventListener("open", () => {
clientId = ws.id;
if(opened) {
closeModal();
} else {
opened = true; opened = true;
if(isReconnect) {
closeModal();
} }
}); });
ws.on("disconnect", () => { ws.addEventListener("error", () => {
if(ws) ws.close();
manually_fetch_queue();
});
ws.addEventListener("close", () => {
setTimeout(() => {
ws = null;
createSocket(true);
}, 300);
if(opened) { if(opened) {
updateStatus(null); updateStatus(null);
showModal("Reconnecting..."); showModal("Reconnecting...");
} }
}); });
ws.on("status", (data) => { ws.addEventListener("message", (event) => {
updateStatus(data); try {
}); const msg = JSON.parse(event.data);
console.log(msg.type, msg.data);
ws.on("progress", (data) => { switch(msg.type) {
updateNodeProgress(data); case "status":
}); if(msg.data.sid) {
clientId = msg.data.sid;
ws.on("executing", (data) => { }
setRunningNode(data.node); updateStatus(msg.data.status);
}); break;
case "progress":
ws.on("executed", (data) => { updateNodeProgress(msg.data)
nodeOutputs[data.node] = data.output; break;
case "executing":
setRunningNode(msg.data.node);
break;
case "executed":
nodeOutputs[msg.data.node] = msg.data.output;
break;
default:
throw new Error("Unknown message type")
}
} catch (error) {
console.warn("Unhandled message:", event.data)
}
}); });
} }
createSocket(); createSocket();
})(); })();
function clearGraph() { function clearGraph() {
graph.clear(); graph.clear();
} }