From a57d635c5f36c28c59ea6513878acde80e4df180 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Feb 2025 21:48:11 -0500 Subject: [PATCH] Fix lumina 2 batches. --- comfy/ldm/lumina/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 4eb7164d..e4b0d34a 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -581,12 +581,13 @@ class NextDiT(nn.Module): flat_x.append(img) x = flat_x padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) + padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) for i in range(bsz): padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, :l_effective_img_len[i]] = True + padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max padded_img_embed = self.x_embedder(padded_img_embed) + padded_img_mask = padded_img_mask.unsqueeze(1) for layer in self.noise_refiner: padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)