Compare commits


5 Commits

Author          SHA1        Message                                                  Date
Rickard Edén    2bac46cd2a  Merge ceee378d86 into ff838657fa                         2025-01-09 09:12:29 -05:00
comfyanonymous  ff838657fa  Cleaner handling of attention mask in ltxv model code.   2025-01-09 07:12:03 -05:00
rickard         ceee378d86  replace 'diffusion_model' with 'transformer'             2024-12-25 18:04:56 +01:00
rickard         21f20638bd  move code to model_lora_keys_unet                        2024-12-23 09:20:22 +01:00
rickard         31b6852f19  load ltx loras trained with finetrainers                 2024-12-23 08:19:34 +01:00
2 changed files with 8 additions and 3 deletions


@@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0
-        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
-        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
-        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
+        if attention_mask is not None and not torch.is_floating_point(attention_mask):
+            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
         pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
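
Note on this hunk: the removed lines always built the bias and filled padded positions with a literal -inf (flagged "# not sure about this" in the code itself), while the replacement from ff838657fa converts only when the caller has not already supplied a floating-point bias, and scales to the dtype's finite maximum instead. A minimal standalone sketch of both conversions, assuming a 0/1 padding mask; the tensor values below are illustrative, not taken from the repository:

import torch

# Illustrative 0/1 padding mask (1 = attend, 0 = padding), batch of 1, 4 tokens.
attention_mask = torch.tensor([[1, 1, 1, 0]])
dtype = torch.float32

# Removed approach: invert the mask, then fill the padded slots with -inf.
old = 1.0 - attention_mask.to(dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
old = old.masked_fill(old.to(torch.bool), float("-inf"))

# Added approach: convert only if the mask is not already a float bias.
# (mask - 1) maps 1 -> 0 and 0 -> -1; scaling by torch.finfo(dtype).max
# turns padded slots into a finite large negative instead of -inf.
new = attention_mask
if new is not None and not torch.is_floating_point(new):
    new = (new - 1).to(dtype).reshape((new.shape[0], 1, -1, new.shape[-1])) * torch.finfo(dtype).max

print(old)  # padded slot becomes -inf
print(new)  # padded slot becomes -3.4028e+38 (finite)

A finite large negative behaves like -inf under softmax but avoids the NaNs that a fully masked row of pure -inf scores can produce, which is presumably what "Cleaner handling" in the commit message refers to.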


@@ -399,6 +399,12 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map["transformer.{}".format(key_lora)] = k
                 key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
 
+    if isinstance(model, comfy.model_base.LTXV):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"):
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["transformer.{}".format(key_lora)] = k
+
     return key_map
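
Note on this hunk: finetrainers exports LTX-Video LoRAs under a "transformer." prefix, while ComfyUI's LTXV state-dict keys start with "diffusion_model.", so the added block registers the translated name for every weight key. A minimal sketch of the mapping it produces, using a made-up stand-in for the model's sdk key list:

# Made-up stand-in for one entry of `sdk` (the model's state-dict keys);
# the block name is illustrative, not copied from the LTXV model.
sdk = ["diffusion_model.transformer_blocks.0.attn1.to_q.weight"]

key_map = {}
for k in sdk:
    if k.startswith("diffusion_model.") and k.endswith(".weight"):
        # Strip the "diffusion_model." prefix and ".weight" suffix ...
        key_lora = k[len("diffusion_model."):-len(".weight")]
        # ... and register the finetrainers-style name as an alias for it.
        key_map["transformer.{}".format(key_lora)] = k

print(key_map)
# {'transformer.transformer_blocks.0.attn1.to_q':
#  'diffusion_model.transformer_blocks.0.attn1.to_q.weight'}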