Fix lumina 2 batches.

This commit is contained in:
comfyanonymous 2025-02-04 21:48:11 -05:00
parent 016b219dcc
commit a57d635c5f

View File

@@ -581,12 +581,13 @@ class NextDiT(nn.Module):
             flat_x.append(img)
         x = flat_x
         padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device)
+        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
         for i in range(bsz):
             padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-            padded_img_mask[i, :l_effective_img_len[i]] = True
+            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
         padded_img_embed = self.x_embedder(padded_img_embed)
+        padded_img_mask = padded_img_mask.unsqueeze(1)
         for layer in self.noise_refiner:
             padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)