Compare commits

6 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Dr.Lt.Data | 8e93bafb1d | Merge 1670612b90 into ff838657fa | 2025-01-09 10:56:00 -05:00 |
| comfyanonymous | ff838657fa | Cleaner handling of attention mask in ltxv model code. | 2025-01-09 07:12:03 -05:00 |
| Dr.Lt.Data | 1670612b90 | resolve conflict | 2024-09-24 21:25:16 +09:00 |
| Dr.Lt.Data | 565d67478a | Merge branch 'master' into improve/extra_model_paths_template | 2024-09-24 21:22:31 +09:00 |
| Dr.Lt.Data | 128075b781 | Update extra_model_paths.yaml.example | 2024-08-22 13:41:38 +09:00 |
| Dr.Lt.Data | b86567092b | add missing folder paths in extra_model_paths template | 2024-08-15 10:03:50 +09:00 |
2 changed files with 7 additions and 3 deletions

comfy/ldm/lightricks/model.py

```diff
@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0

-        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
-        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
-        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
+        if attention_mask is not None and not torch.is_floating_point(attention_mask):
+            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max

         pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
```
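The new code folds the earlier two-step mask construction into a single conversion: an integer 0/1 keep-mask becomes an additive bias, and masks that are already floating point pass through untouched. One likely motivation for using the large finite value `torch.finfo(x.dtype).max` rather than `float("-inf")` is numerical safety: a max-subtracting softmax turns a fully masked row into a uniform distribution instead of NaN. Below is a minimal, self-contained sketch of the same pattern; the helper `to_additive_bias` and the toy tensors are illustrative, not part of the ComfyUI code:

```python
import torch

def to_additive_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Same pattern as the new LTXV code: kept positions (1) map to 0.0,
    # masked positions (0) map to -finfo(dtype).max, so adding the bias to
    # pre-softmax attention scores gives masked positions ~zero weight.
    if not torch.is_floating_point(mask):
        mask = (mask - 1).to(dtype) * torch.finfo(dtype).max
    return mask

mask = torch.tensor([[1, 1, 1, 0]])           # last position is padding
bias = to_additive_bias(mask, torch.float32)  # [[0., 0., 0., -3.4028e+38]]
scores = torch.zeros(1, 4) + bias             # biased pre-softmax scores
print(torch.softmax(scores, dim=-1))          # ~[[0.3333, 0.3333, 0.3333, 0.0000]]
```

The actual model code additionally reshapes the bias to `(batch, 1, -1, seq_len)` so it broadcasts across attention heads.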

extra_model_paths.yaml.example

```diff
@@ -36,9 +36,14 @@ a111:
 #                  models/diffusion_models
 #                  models/unet
 #     embeddings: models/embeddings/
+#     gligen: models/gligen/
+#     hypernetworks: models/hypernetworks/
 #     loras: models/loras/
+#     photomaker: models/photomaker/
+#     style_models: models/style_models/
 #     upscale_models: models/upscale_models/
 #     vae: models/vae/
+#     vae_approx: models/vae_approx/

 #other_ui:
 #    base_path: path/to/ui
```
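For context, these commented template entries take effect once copied into an `extra_model_paths.yaml` in the ComfyUI base directory (or passed with the `--extra-model-paths-config` launch argument) and uncommented; the right-hand paths are resolved relative to `base_path`. A minimal sketch with a made-up location, covering the folder types this branch adds:

```yaml
# Hypothetical extra_model_paths.yaml: base_path below is illustrative.
comfyui:
    base_path: /srv/shared-models/comfyui/
    gligen: models/gligen/
    hypernetworks: models/hypernetworks/
    photomaker: models/photomaker/
    style_models: models/style_models/
    vae_approx: models/vae_approx/
```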