Compare commits

..

1 Commits

Author SHA1 Message Date
Konin Oleg
78d1cdd47a
Merge fcb5988c24 into d0f3752e33 2025-01-07 17:42:22 -05:00
3 changed files with 6 additions and 6 deletions

View File

@ -456,8 +456,9 @@ class LTXVModel(torch.nn.Module):
x = self.patchify_proj(x)
timestep = timestep * 1000.0
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

View File

@ -111,7 +111,7 @@ class CLIP:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
def clone(self):
n = CLIP(no_init=True)
@ -898,7 +898,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded diffusion model directly to GPU")
logging.info("loaded straight to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision)

View File

@ -4,8 +4,7 @@ lint.ignore = ["ALL"]
# Enable specific rules
lint.select = [
"S307", # suspicious-eval-usage
"S102", # exec
"T", # print-usage
"T201", # print-usage
"W",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f