Compare commits

...

4 Commits

Author SHA1 Message Date
Alexander Piskun
91b7bd74bf
Merge 1741ad55ac into ff838657fa 2025-01-09 09:12:29 -05:00
comfyanonymous
ff838657fa Cleaner handling of attention mask in ltxv model code. 2025-01-09 07:12:03 -05:00
bigcat88
1741ad55ac
added ability to specify custom "prompt_id"
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-24 18:54:43 +02:00
bigcat88
d6edc2c53c
refactor: moved logic of "post_prompt" endpoint to sub-function
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-20 09:41:28 +02:00
2 changed files with 43 additions and 36 deletions

View File

@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

View File

@@ -615,39 +615,12 @@ class PromptServer():
@routes.post("/prompt")
async def post_prompt(request):
logging.info("got prompt")
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = execution.validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
result = self.put_prompt_in_queue(await request.json())
if "error" in result:
if result.get("node_errors"):
logging.warning("invalid prompt: {}".format(result["error"]))
return web.json_response(result, status=400)
return web.json_response(result)
@routes.post("/queue")
async def post_queue(request):
@@ -692,6 +665,41 @@ class PromptServer():
return web.Response(status=200)
def put_prompt_in_queue(self, json_data):
    """Validate a prompt payload and enqueue it for execution.

    Runs the on-prompt hooks over *json_data*, derives the queue priority
    (a caller-supplied "number", negated when a truthy "front" flag asks
    for head-of-queue placement), validates the prompt graph, and places
    the job on the prompt queue.

    Returns a dict with "prompt_id", "number" and "node_errors" on
    success, or a dict carrying an "error" key (plus "node_errors")
    when validation fails or no prompt was supplied.
    """
    json_data = self.trigger_on_prompt(json_data)

    # Priority: explicit "number" from the caller, else the server's
    # running counter.
    number = float(json_data["number"]) if "number" in json_data else self.number
    # A truthy "front" flag jumps the queue by negating the priority.
    if json_data.get("front"):
        number = -number
    # The counter advances on every call, whichever branch was taken.
    self.number += 1

    if "prompt" not in json_data:
        return {"error": "no prompt", "node_errors": []}

    prompt = json_data["prompt"]
    valid = execution.validate_prompt(prompt)

    extra_data = json_data.get("extra_data", {})
    if "client_id" in json_data:
        extra_data["client_id"] = json_data["client_id"]

    if not valid[0]:
        return {"error": valid[1], "node_errors": valid[3]}

    # Honor a caller-provided "prompt_id"; otherwise mint a fresh UUID.
    prompt_id = json_data["prompt_id"] if "prompt_id" in json_data else str(uuid.uuid4())
    outputs_to_execute = valid[2]
    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
    return {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)