diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 40234b64..a50a70ae 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -3,6 +3,7 @@ import torch class LatentFormat: scale_factor = 1.0 latent_channels = 4 + latent_dimensions = 2 latent_rgb_factors = None latent_rgb_factors_bias = None taesd_decoder_name = None @@ -143,6 +144,7 @@ class SD3(LatentFormat): class StableAudio1(LatentFormat): latent_channels = 64 + latent_dimensions = 1 class Flux(SD3): latent_channels = 16 @@ -178,6 +180,7 @@ class Flux(SD3): class Mochi(LatentFormat): latent_channels = 12 + latent_dimensions = 3 def __init__(self): self.scale_factor = 1.0 @@ -219,6 +222,8 @@ class Mochi(LatentFormat): class LTXV(LatentFormat): latent_channels = 128 + latent_dimensions = 3 + def __init__(self): self.latent_rgb_factors = [ [ 1.1202e-02, -6.3815e-04, -1.0021e-02], @@ -355,6 +360,7 @@ class LTXV(LatentFormat): class HunyuanVideo(LatentFormat): latent_channels = 16 + latent_dimensions = 3 scale_factor = 0.476986 latent_rgb_factors = [ [-0.0395, -0.0331, 0.0445], diff --git a/comfy/sample.py b/comfy/sample.py index 9974e065..be5a7e24 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -25,9 +25,11 @@ def prepare_noise(latent_image, seed, noise_inds=None): return noises def fix_empty_latent_channels(model, latent_image): - latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels - if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1) + latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels + if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if latent_format.latent_dimensions == 3 and latent_image.ndim == 4: + latent_image = latent_image.unsqueeze(2) return latent_image def prepare_sampling(model, noise_shape, positive, negative, noise_mask):