Compare commits

...

4 Commits

Author SHA1 Message Date
Sergii Dymchenko
a2fc55e467
Merge 76b7c71b6e into ff838657fa 2025-01-09 20:13:10 +08:00
comfyanonymous
ff838657fa Cleaner handling of attention mask in ltxv model code. 2025-01-09 07:12:03 -05:00
Sergii Dymchenko
76b7c71b6e
Use non-alias 2025-01-07 15:51:19 -08:00
Sergii Dymchenko
e4e2d82b0f
Use torch.special.expm1
This function provides greater precision than `exp(x) - 1` for small values of `x`.

Found with TorchFix https://github.com/pytorch-labs/torchfix/
2025-01-07 15:39:37 -08:00
2 changed files with 3 additions and 4 deletions

View File

@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
return append_zero(sigmas)

View File

@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)