Compare commits

3 Commits

Author            SHA1        Message                                                  Date
Alexander Piskun  fbada07f99  Merge 1f6ab7dbfb into ff838657fa                         2025-01-09 09:12:29 -05:00
comfyanonymous    ff838657fa  Cleaner handling of attention mask in ltxv model code.   2025-01-09 07:12:03 -05:00
bigcat88          1f6ab7dbfb  support for "unload_models" flag when creating a task    2024-12-24 17:41:49 +02:00
                              (Signed-off-by: bigcat88 <bigcat88@icloud.com>)
2 changed files with 10 additions and 3 deletions

@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0

-        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
-        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
-        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
+        if attention_mask is not None and not torch.is_floating_point(attention_mask):
+            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max

         pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
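
The removed lines flipped a 0/1 mask and filled the masked positions with float("-inf"); the new branch only converts masks that are not already floating point, and uses -torch.finfo(dtype).max as the additive bias instead of -inf, which avoids NaNs from -inf arithmetic in low-precision dtypes. A minimal standalone sketch of the new conversion (the toy mask and dtype below are illustrative, not the model's real tensors):

    import torch

    # Toy 0/1 key-padding mask: one sequence of four tokens, last two padded.
    attention_mask = torch.tensor([[1, 1, 0, 0]])
    x_dtype = torch.float16

    if attention_mask is not None and not torch.is_floating_point(attention_mask):
        # 1 -> 0.0 (keep), 0 -> -max (mask out); reshape to (B, 1, -1, T)
        # so the bias broadcasts over attention heads and query positions.
        attention_mask = (attention_mask - 1).to(x_dtype).reshape(
            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
        ) * torch.finfo(x_dtype).max

    print(attention_mask)
    # tensor([[[[0., 0., -65504., -65504.]]]], dtype=torch.float16)

Added to the attention logits before softmax, this bias drives the masked positions' weights to (numerically) zero.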

@@ -166,6 +166,14 @@ def prompt_worker(q, server_instance):
         queue_item = q.get(timeout=timeout)
         if queue_item is not None:
             item, item_id = queue_item
+            if item[3].get("unload_models"):
+                # The task was queued with the "unload_models" flag: unload
+                # all models and clear memory before executing it.
+                comfy.model_management.unload_all_models()
+                gc.collect()
+                comfy.model_management.soft_empty_cache()
+                last_gc_collect = time.perf_counter()
+
             execution_start_time = time.perf_counter()
             prompt_id = item[1]
             server_instance.last_prompt_id = prompt_id
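
In ComfyUI's queue tuple, item[3] is the extra_data dict taken from the request body of the /prompt endpoint, so the flag would be supplied per task when queuing a prompt. A hypothetical client-side sketch, assuming the standard /prompt API on the default port; the empty workflow dict is a placeholder for a real node graph:

    import json
    import urllib.request

    workflow = {}  # placeholder: a real ComfyUI node graph goes here

    payload = {
        "prompt": workflow,
        # Read as item[3].get("unload_models") in prompt_worker above.
        "extra_data": {"unload_models": True},
    }
    req = urllib.request.Request(
        "http://127.0.0.1:8188/prompt",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(resp.read().decode("utf-8"))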