mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
Make lumina model work with any latent resolution.
This commit is contained in:
parent
095d867147
commit
4027466c80
@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
@ -594,6 +595,8 @@ class NextDiT(nn.Module):
|
|||||||
t = 1.0 - timesteps
|
t = 1.0 - timesteps
|
||||||
cap_feats = context
|
cap_feats = context
|
||||||
cap_mask = attention_mask
|
cap_mask = attention_mask
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
"""
|
"""
|
||||||
Forward pass of NextDiT.
|
Forward pass of NextDiT.
|
||||||
t: (N,) tensor of diffusion timesteps
|
t: (N,) tensor of diffusion timesteps
|
||||||
@ -613,7 +616,7 @@ class NextDiT(nn.Module):
|
|||||||
x = layer(x, mask, freqs_cis, adaln_input)
|
x = layer(x, mask, freqs_cis, adaln_input)
|
||||||
|
|
||||||
x = self.final_layer(x, adaln_input)
|
x = self.final_layer(x, adaln_input)
|
||||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)
|
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||||
|
|
||||||
return -x
|
return -x
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user