From d6edc2c53c34b83e81cfe8b9ad735362ea958109 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 20 Dec 2024 09:41:28 +0200 Subject: [PATCH] refactor: moved logic of "post_prompt" endpoint to sub-function Signed-off-by: bigcat88 --- server.py | 71 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/server.py b/server.py index ddd71e06..dfee6887 100644 --- a/server.py +++ b/server.py @@ -613,39 +613,12 @@ class PromptServer(): @routes.post("/prompt") async def post_prompt(request): logging.info("got prompt") - json_data = await request.json() - json_data = self.trigger_on_prompt(json_data) - - if "number" in json_data: - number = float(json_data['number']) - else: - number = self.number - if "front" in json_data: - if json_data['front']: - number = -number - - self.number += 1 - - if "prompt" in json_data: - prompt = json_data["prompt"] - valid = execution.validate_prompt(prompt) - extra_data = {} - if "extra_data" in json_data: - extra_data = json_data["extra_data"] - - if "client_id" in json_data: - extra_data["client_id"] = json_data["client_id"] - if valid[0]: - prompt_id = str(uuid.uuid4()) - outputs_to_execute = valid[2] - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) - response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} - return web.json_response(response) - else: - logging.warning("invalid prompt: {}".format(valid[1])) - return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) - else: - return web.json_response({"error": "no prompt", "node_errors": []}, status=400) + result = self.put_prompt_in_queue(await request.json()) + if "error" in result: + if result.get("node_errors"): + logging.warning("invalid prompt: {}".format(result["error"])) + return web.json_response(result, status=400) + return web.json_response(result) @routes.post("/queue") async def post_queue(request): @@ -690,6 +663,38 @@ class PromptServer(): return web.Response(status=200) + def put_prompt_in_queue(self, json_data): + json_data = self.trigger_on_prompt(json_data) + + if "number" in json_data: + number = float(json_data['number']) + else: + number = self.number + if "front" in json_data: + if json_data['front']: + number = -number + + self.number += 1 + + if "prompt" in json_data: + prompt = json_data["prompt"] + valid = execution.validate_prompt(prompt) + extra_data = {} + if "extra_data" in json_data: + extra_data = json_data["extra_data"] + + if "client_id" in json_data: + extra_data["client_id"] = json_data["client_id"] + if valid[0]: + prompt_id = str(uuid.uuid4()) + outputs_to_execute = valid[2] + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + return {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} + else: + return {"error": valid[1], "node_errors": valid[3]} + else: + return {"error": "no prompt", "node_errors": []} + async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout)