mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
a52aa9f4b5
Reworked sockets to use socketio. Added progress to nodes. Added highlight to active node. Added preview to SaveImage node.
173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
import os
|
|
import sys
|
|
import asyncio
|
|
import nodes
|
|
import main
|
|
|
|
# Optional-dependency guards: fail fast with install instructions when a
# required third-party package is missing. Exit with a non-zero status so
# that launchers/supervisors see the startup as failed (sys.exit() with no
# argument would report success).
try:
    import aiohttp
    from aiohttp import web
except ImportError:
    print("Module 'aiohttp' not installed. Please install it via:", file=sys.stderr)
    print("pip install aiohttp", file=sys.stderr)
    print("or", file=sys.stderr)
    print("pip install -r requirements.txt", file=sys.stderr)
    sys.exit(1)


try:
    import socketio
except ImportError:
    print("Module 'python-socketio' not installed. Please install it via:", file=sys.stderr)
    print("pip install python-socketio", file=sys.stderr)
    print("or", file=sys.stderr)
    print("pip install -r requirements.txt", file=sys.stderr)
    sys.exit(1)
|
|
|
|
|
|
class PromptServer():
    """aiohttp + python-socketio server backing the web UI.

    Serves the static frontend from the ``webshit`` directory located next to
    this file (after resolving symlinks). It exposes JSON endpoints for node
    metadata and the prompt queue, and pushes events to connected socket.io
    clients.

    ``prompt_queue`` starts as ``None`` and must be assigned by the caller
    before any request that touches the queue is handled. The methods below
    call ``get_current_queue``, ``get_tasks_remaining``, ``put``,
    ``wipe_queue`` and ``delete_queue_item`` on it.
    """

    def __init__(self, loop):
        """Build the aiohttp application and register every route.

        Args:
            loop: the asyncio event loop the server runs on. ``send_sync``
                schedules work onto it with ``call_soon_threadsafe``.
        """
        self.prompt_queue = None
        self.loop = loop
        # Outgoing (event, data, sid) tuples, drained by publish_loop().
        self.messages = asyncio.Queue()
        # Counter used as the queue number for prompts that do not supply
        # their own "number".
        self.number = 0
        self.app = web.Application()
        self.sio = socketio.AsyncServer()
        self.sio.attach(self.app)

        server_dir = os.path.dirname(os.path.realpath(__file__))
        self.web_root = os.path.join(server_dir, "webshit")
        # Constant for the lifetime of the server; resolve it once instead of
        # on every /view request.
        output_dir = os.path.join(server_dir, "output")

        routes = web.RouteTableDef()

        @self.sio.event
        async def connect(sid, environ):
            # Give every newly connected client the current queue status.
            await self.send("status", self.get_queue_info(), sid)

        @routes.get("/")
        async def get_root(request):
            """Serve the frontend entry page."""
            return web.FileResponse(os.path.join(self.web_root, "index.html"))

        @routes.get("/view/{file}")
        async def view_image(request):
            """Serve a PNG from the output directory, or 404.

            Only the final path component of ``{file}`` is used and its
            extension is forced to ``.png``. As a result, the request cannot
            name a path outside ``output/`` or a non-PNG file.
            """
            if "file" in request.match_info:
                file = request.match_info["file"]
                file = os.path.splitext(os.path.basename(file))[0] + ".png"
                file = os.path.join(output_dir, file)
                if os.path.isfile(file):
                    return web.FileResponse(file)

            return web.Response(status=404)

        @routes.get("/prompt")
        async def get_prompt(request):
            """Return the same status payload that is pushed over socket.io."""
            return web.json_response(self.get_queue_info())

        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every registered node class for the frontend.

            For each key of ``nodes.NODE_CLASS_MAPPINGS`` the result holds the
            class's ``INPUT_TYPES()``, its ``RETURN_TYPES``, its name, an empty
            description, and its ``CATEGORY`` (``'sd'`` when the class
            defines none).
            """
            out = {}
            for x in nodes.NODE_CLASS_MAPPINGS:
                obj_class = nodes.NODE_CLASS_MAPPINGS[x]
                info = {}
                info['input'] = obj_class.INPUT_TYPES()
                info['output'] = obj_class.RETURN_TYPES
                info['name'] = x #TODO
                info['description'] = ''
                info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
                out[x] = info
            return web.json_response(out)

        @routes.get("/queue")
        async def get_queue(request):
            """Return the running and pending items of the prompt queue."""
            queue_info = {}
            current_queue = self.prompt_queue.get_current_queue()
            queue_info['queue_running'] = current_queue[0]
            queue_info['queue_pending'] = current_queue[1]
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and enqueue it.

            Recognised JSON body keys:
                prompt: the graph to execute (required).
                number: explicit queue number. When absent, the server
                    counter is used and then incremented.
                front: when truthy and "number" is absent, the counter value
                    is negated. Presumably this makes the prompt sort ahead of
                    the others; confirm against the queue implementation.
                extra_data: dict stored alongside the prompt.
                client_id: copied into ``extra_data["client_id"]``.

            Responds 200 with an empty body when the prompt was queued. It
            responds 400 with a message when "prompt" is missing or when
            ``main.validate_prompt`` rejects it.
            """
            print("got prompt")
            resp_code = 200
            out_string = ""
            json_data = await request.json()

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if "front" in json_data:
                    if json_data['front']:
                        number = -number

                self.number += 1

            if "prompt" in json_data:
                prompt = json_data["prompt"]
                valid = main.validate_prompt(prompt)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]

                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]
                if valid[0]:
                    # id(prompt) is the item's handle; the "delete" branch of
                    # POST /queue matches against element [1] of this tuple.
                    self.prompt_queue.put((number, id(prompt), prompt, extra_data))
                else:
                    resp_code = 400
                    out_string = valid[1]
                    print("invalid prompt:", valid[1])
            else:
                # Nothing was queued, so report a client error rather than a
                # silent success.
                resp_code = 400
                out_string = "no prompt"

            return web.Response(body=out_string, status=resp_code)

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the queue and/or delete specific items from it.

            JSON body keys:
                clear: when truthy, ``prompt_queue.wipe_queue()`` is called.
                delete: iterable of item ids. Each is passed through ``int()``
                    and removed via ``delete_queue_item`` with a predicate
                    that matches element [1] of a queue item.
            """
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    # The predicate is consumed within this iteration, so the
                    # closure over id_to_delete sees the current value.
                    delete_func = lambda a: a[1] == int(id_to_delete)
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        self.app.add_routes(routes)
        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        """Return ``{'exec_info': {'queue_remaining': <tasks remaining>}}``."""
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Emit *event* with *data* over socket.io.

        The message goes to *sid*, or to all connected clients when *sid* is
        ``None`` (python-socketio's ``emit`` semantics for ``to=None``).
        """
        await self.sio.emit(event, data, to=sid)

    def send_sync(self, event, data, sid=None):
        """Queue a message for publish_loop() from any thread.

        The ``put_nowait`` is scheduled on ``self.loop`` with
        ``call_soon_threadsafe``, so the asyncio queue is only touched from
        the loop's own thread.
        """
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        """Broadcast the current queue status to all clients."""
        self.send_sync("status", self.get_queue_info())

    async def publish_loop(self):
        """Forever drain ``self.messages`` and emit each (event, data, sid).

        A failure while emitting one message is printed and skipped. Without
        this, the first exception would end this coroutine and every later
        message would silently never be delivered.
        """
        while True:
            msg = await self.messages.get()
            try:
                await self.send(*msg)
            except asyncio.CancelledError:
                # Let task cancellation through (CancelledError is an
                # Exception subclass before Python 3.8).
                raise
            except Exception as e:
                print("error sending message", msg[0], e)

    async def start(self, address, port):
        """Bind the app to (address, port) and start serving.

        An empty *address* is shown as 0.0.0.0 in the startup message only;
        the value passed to TCPSite is left unchanged.
        """
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, address, port)
        await site.start()

        if address == '':
            address = '0.0.0.0'
        print("Starting server\n")
        print("To see the GUI go to: http://{}:{}".format(address, port))