From ff838657fac787a0e7aef6bbbf92ea4411bdbd15 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 9 Jan 2025 07:12:03 -0500
Subject: [PATCH] Cleaner handling of attention mask in ltxv model code.

---
 comfy/ldm/lightricks/model.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index eeeeaea0..2a02acd6 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0

-        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
-        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
-        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
+        if attention_mask is not None and not torch.is_floating_point(attention_mask):
+            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max

         pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
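
For readers unfamiliar with the additive-mask idiom, the minimal standalone sketch below illustrates what the replacement lines compute: an integer/boolean mask (1 = attend, 0 = ignore) is converted into a float bias that is 0 for allowed positions and a very large negative number for masked ones, guarded so that masks which are already floating point are left untouched. The 2D mask shape, float16 dtype, and example values here are illustrative assumptions, not taken from the patch.

# Minimal sketch (not part of the patch) of the new mask handling,
# assuming a 2D integer attention_mask with 1 = attend, 0 = ignore,
# and float16 activations.
import torch

x_dtype = torch.float16  # stands in for x.dtype in the model

# Batch of 2 sequences of length 8; the first has 2 padded positions.
attention_mask = torch.tensor(
    [[1, 1, 1, 1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)

if attention_mask is not None and not torch.is_floating_point(attention_mask):
    # (attention_mask - 1) is 0 where attention is allowed and -1 where it is
    # masked; scaling by torch.finfo(x_dtype).max turns the -1 entries into the
    # largest negative finite bias the dtype can represent, so the mask can be
    # added directly to the attention logits.
    attention_mask = (attention_mask - 1).to(x_dtype).reshape(
        (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
    ) * torch.finfo(x_dtype).max

print(attention_mask.shape)     # torch.Size([2, 1, 1, 8])
print(attention_mask[0, 0, 0])  # 0.0 for valid tokens, -65504.0 for padding

Compared with the removed code, which inverted the mask and then applied masked_fill with float("-inf"), the new single expression produces a finite bias representable in the activation dtype and only runs when the incoming mask is not already a floating-point bias.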