Make the Lumina model work with any latent resolution.

comfyanonymous 2025-02-10 00:24:20 -05:00
parent 095d867147
commit 4027466c80


@@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import comfy.ldm.common_dit
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
 from comfy.ldm.modules.attention import optimized_attention_masked
@@ -594,6 +595,8 @@ class NextDiT(nn.Module):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
+        bs, c, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         """
         Forward pass of NextDiT.
         t: (N,) tensor of diffusion timesteps
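
The two added lines record the incoming latent's height and width, then pad the tensor so both spatial dimensions are multiples of the patch size, which is what lets the model accept any latent resolution. The body of comfy.ldm.common_dit.pad_to_patch_size is not part of this diff; the following is a minimal sketch of the idea, assuming the helper pads the trailing spatial dims with torch.nn.functional.pad:

import torch
import torch.nn.functional as F

def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    # Sketch only: pad each trailing spatial dim up to the next
    # multiple of the patch size; the real helper may differ.
    pad = ()
    for i in range(img.ndim - 2):
        extra = (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]
        pad = (0, extra) + pad  # F.pad expects pads for the last dim first
    return F.pad(img, pad, mode=padding_mode)

x = torch.randn(1, 16, 31, 45)      # latent with odd spatial dims
print(pad_to_patch_size(x).shape)   # torch.Size([1, 16, 32, 46])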
@@ -613,7 +616,7 @@ class NextDiT(nn.Module):
             x = layer(x, mask, freqs_cis, adaln_input)

         x = self.final_layer(x, adaln_input)
-        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)
+        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
         return -x
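
Because the model now runs at the padded resolution, unpatchify returns a tensor at that resolution, and the new [:,:,:h,:w] slice crops the output back to the caller's original size. A toy illustration of the round trip, using a stand-in tensor in place of the real unpatchify output:

import torch

x = torch.randn(1, 16, 31, 45)     # original latent, any resolution
bs, c, h, w = x.shape              # saved before padding, as in the diff

out = torch.randn(bs, c, 32, 46)   # stand-in for the padded-size model output
out = out[:, :, :h, :w]            # crop back to 31x45
assert out.shape == x.shape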