import io
import math

import av
import numpy as np
import torch

import comfy.model_management
import comfy.model_sampling
import comfy.utils
import node_helpers
import nodes
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
class EmptyLTXVLatentVideo:
    """Node that produces an all-zero LTXV video latent of the requested size."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video/ltxv"

    def generate(self, width, height, length, batch_size=1):
        # Latent grid: 8 pixel frames per latent frame (plus one), 32x spatial downscale,
        # 128 channels — mirrors the shapes used by the other LTXV nodes in this file.
        temporal = ((length - 1) // 8) + 1
        shape = [batch_size, 128, temporal, height // 32, width // 32]
        samples = torch.zeros(shape, device=comfy.model_management.intermediate_device())
        return ({"samples": samples},)
class LTXVImgToVideo:
    """Condition an LTXV video latent on a reference image.

    The image is VAE-encoded into the first latent frames; the returned noise mask
    keeps those frames at (1 - strength) noise while the rest stays fully noised.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "vae": ("VAE",),
                             "image": ("IMAGE",),
                             "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
        # Resize to the target resolution and drop any alpha channel before encoding.
        resized = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encoded = vae.encode(resized[:, :, :, :3])

        samples = torch.zeros(
            [batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32],
            device=comfy.model_management.intermediate_device(),
        )
        samples[:, :, :encoded.shape[2]] = encoded

        # Ones = full noise; the image-conditioned frames are reduced to 1 - strength.
        conditioning_latent_frames_mask = torch.ones(
            (batch_size, 1, samples.shape[2], 1, 1),
            dtype=torch.float32,
            device=samples.device,
        )
        conditioning_latent_frames_mask[:, :, :encoded.shape[2]] = 1.0 - strength

        return (positive, negative, {"samples": samples, "noise_mask": conditioning_latent_frames_mask},)
def conditioning_get_any_value(conditioning, key, default=None):
    """Scan conditioning entries and return the first value stored under `key`.

    Each entry is indexed as `entry[1]` for its options dict; `default` is
    returned when no entry carries the key.
    """
    for entry in conditioning:
        options = entry[1]
        if key in options:
            return options[key]
    return default
def get_noise_mask(latent):
    """Return a mutable noise mask for a latent dict.

    If the latent already carries a "noise_mask" it is cloned (so callers can
    write into it safely); otherwise a fresh all-ones mask matching the latent's
    batch and temporal length is created on the samples' device.
    """
    existing = latent.get("noise_mask", None)
    if existing is not None:
        return existing.clone()

    samples = latent["samples"]
    batch_size, _, latent_length, _, _ = samples.shape
    return torch.ones(
        (batch_size, 1, latent_length, 1, 1),
        dtype=torch.float32,
        device=samples.device,
    )
def get_keyframe_idxs(cond):
    """Fetch stored keyframe coordinates from conditioning.

    Returns (keyframe_idxs, num_keyframes) where num_keyframes counts the
    distinct values along the first coordinate row, or (None, 0) when the
    conditioning has no "keyframe_idxs" entry.
    """
    keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if keyframe_idxs is None:
        return None, 0
    num_keyframes = len(torch.unique(keyframe_idxs[:, 0]))
    return keyframe_idxs, num_keyframes
class LTXVAddGuide:
    """Add an image/video guide to an LTXV latent as conditioning keyframes.

    The first latent frames of the encoded guide are appended to the end of the
    latent sequence as keyframes (tracked via "keyframe_idxs" in conditioning);
    any remaining guide frames are written in place over the latent at the
    resolved index.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "latent": ("LATENT",),
                             "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames. "
                                                 "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
                             "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
                                                   "tooltip": "Frame index to start the conditioning at. For single-frame images or "
                                                   "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
                                                   "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
                                                   "the nearest multiple of 8. Negative values are counted from the end of the video."}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                }
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def __init__(self):
        # Number of leading encoded guide frames appended as keyframes;
        # the remainder of the guide is written in place (see generate()).
        self._num_prefix_frames = 2
        self._patchifier = SymmetricPatchifier(1)

    def encode(self, vae, latent_width, latent_height, images, scale_factors):
        """Resize guide frames to match the latent grid and VAE-encode them.

        Crops the frame count to time_scale_factor*n + 1 frames, resizes to the
        pixel size implied by the latent dimensions, strips any alpha channel,
        and returns (resized_pixels, encoded_latent).
        """
        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
        # Keep only time_scale_factor*n + 1 frames so the temporal downscale is exact.
        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        return encode_pixels, t

    def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
        """Resolve a (possibly negative) pixel frame index to (frame_idx, latent_idx).

        Negative frame_idx counts back from the end of the non-keyframe part of
        the video; multi-frame guides are snapped so (frame_idx - 1) is a
        multiple of the temporal scale factor (unless frame_idx == 0).
        """
        time_scale_factor, _, _ = scale_factors
        _, num_keyframes = get_keyframe_idxs(cond)
        # Exclude already-appended keyframes from the usable latent length.
        latent_count = latent_length - num_keyframes
        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)

        if guide_length > 1 and frame_idx != 0:
            frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1  # frame index - 1 must be divisible by 8 or frame_idx == 0

        # Ceiling division: first latent frame covering the pixel frame index.
        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor

        return frame_idx, latent_idx

    def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
        """Record the guide latent's pixel coordinates under "keyframe_idxs" in cond.

        Coordinates are produced by patchifying the guide latent, mapped to pixel
        space, and offset temporally by frame_idx; new coordinates are appended to
        any existing ones.
        """
        keyframe_idxs, _ = get_keyframe_idxs(cond)
        _, latent_coords = self._patchifier.patchify(guiding_latent)
        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0)  # we need the causal fix only if we're placing the new latents at index 0
        pixel_coords[:, 0] += frame_idx
        if keyframe_idxs is None:
            keyframe_idxs = pixel_coords
        else:
            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
        """Append guide latent frames to the end of the sequence as keyframes.

        Mutates noise_mask in place at the resolved target slice (set to 1.0,
        i.e. full noise — presumably because the in-sequence frames are handled
        separately from the appended keyframes; TODO confirm), records keyframe
        coordinates on both conditionings, and concatenates the guide latent and
        its (1 - strength) mask onto the temporal axis.
        """
        _, latent_idx = self.get_latent_index(
            cond=positive,
            latent_length=latent_image.shape[2],
            guide_length=guiding_latent.shape[2],
            frame_idx=frame_idx,
            scale_factors=scale_factors,
        )
        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0

        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

        mask = torch.full(
            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
        noise_mask = torch.cat([noise_mask, mask], dim=2)

        return positive, negative, latent_image, noise_mask

    def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
        """Overwrite latent frames starting at latent_idx with the guide latent.

        Returns cloned tensors (inputs are not mutated); the replaced frames'
        noise mask is set to 1 - strength.
        """
        cond_length = guiding_latent.shape[2]
        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."

        mask = torch.full(
            (noise_mask.shape[0], 1, cond_length, 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = latent_image.clone()
        noise_mask = noise_mask.clone()

        latent_image[:, :, latent_idx:latent_idx + cond_length] = guiding_latent
        noise_mask[:, :, latent_idx:latent_idx + cond_length] = mask

        return latent_image, noise_mask

    def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
        """Encode the guide, append its prefix frames as keyframes, and write
        the remaining frames in place. Returns (positive, negative, latent_dict).
        """
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)

        _, _, latent_length, latent_height, latent_width = latent_image.shape
        image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)

        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

        num_prefix_frames = min(self._num_prefix_frames, t.shape[2])

        positive, negative, latent_image, noise_mask = self.append_keyframe(
            positive,
            negative,
            frame_idx,
            latent_image,
            noise_mask,
            t[:, :, :num_prefix_frames],
            strength,
            scale_factors,
        )

        latent_idx += num_prefix_frames

        # Whatever is left after the keyframe prefix replaces latent frames in place.
        t = t[:, :, num_prefix_frames:]
        if t.shape[2] == 0:
            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

        latent_image, noise_mask = self.replace_latent_frames(
            latent_image,
            noise_mask,
            t,
            latent_idx,
            strength,
        )

        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVCropGuides:
    """Remove appended keyframe guide latents and clear their conditioning indices."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "latent": ("LATENT",),
                             }
                }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    CATEGORY = "conditioning/video_models"
    FUNCTION = "crop"

    def __init__(self):
        self._patchifier = SymmetricPatchifier(1)

    def crop(self, positive, negative, latent):
        samples = latent["samples"].clone()
        mask = get_noise_mask(latent)
        _, num_keyframes = get_keyframe_idxs(positive)

        if num_keyframes > 0:
            # Trailing guide frames were appended on the temporal axis; drop them
            # from both the latent and its mask, then forget the keyframe indices.
            samples = samples[:, :, :-num_keyframes]
            mask = mask[:, :, :-num_keyframes]
            positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
            negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})

        return (positive, negative, {"samples": samples, "noise_mask": mask},)
class LTXVConditioning:
    """Attach a frame rate to both the positive and negative conditioning."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"
    CATEGORY = "conditioning/video_models"

    def append(self, positive, negative, frame_rate):
        values = {"frame_rate": frame_rate}
        return (node_helpers.conditioning_set_values(positive, values),
                node_helpers.conditioning_set_values(negative, values))
class ModelSamplingLTXV:
    """Patch a model's sampling with an LTXV flow shift derived from latent size."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01}),
                             },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "advanced/model"

    def patch(self, model, max_shift, base_shift, latent=None):
        m = model.clone()

        # Token count = product of the latent's temporal and spatial dims;
        # falls back to 4096 when no latent is provided.
        tokens = 4096 if latent is None else math.prod(latent["samples"].shape[2:])

        # Linear interpolation of shift between (1024, base_shift) and (4096, max_shift).
        x1, x2 = 1024, 4096
        slope = (max_shift - base_shift) / (x2 - x1)
        intercept = base_shift - slope * x1
        shift = tokens * slope + intercept

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, comfy.model_sampling.CONST):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)

        return (m,)
class LTXVScheduler:
    """Produce a shifted (and optionally terminal-stretched) sigma schedule."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                 "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step": 0.01}),
                 "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01}),
                 "stretch": ("BOOLEAN", {
                     "default": True,
                     "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
                 }),
                 "terminal": (
                     "FLOAT",
                     {
                         "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
                         "tooltip": "The terminal value of the sigmas after stretching."
                     },
                 ),
                 },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"
    FUNCTION = "get_sigmas"

    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        # Token count = product of latent temporal/spatial dims, default 4096.
        tokens = 4096 if latent is None else math.prod(latent["samples"].shape[2:])

        # Interpolate the shift between (1024, base_shift) and (4096, max_shift).
        x1, x2 = 1024, 4096
        slope = (max_shift - base_shift) / (x2 - x1)
        intercept = base_shift - slope * x1
        sigma_shift = tokens * slope + intercept

        power = 1
        shift_scale = math.exp(sigma_shift)
        sigmas = torch.linspace(1.0, 0.0, steps + 1)
        # Apply the exponential shift to every non-zero sigma; the final 0 stays 0.
        sigmas = torch.where(
            sigmas != 0,
            shift_scale / (shift_scale + (1 / sigmas - 1) ** power),
            0,
        )

        if stretch:
            # Rescale the non-zero sigmas so the last one lands exactly on `terminal`.
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            sigmas[non_zero_mask] = 1.0 - (one_minus_z / scale_factor)

        return (sigmas,)
def encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Encode one RGB frame into `output_file` as a single-frame H.264 MP4.

    `crf` controls the compression level passed to libx264.
    """
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        height, width = image_array.shape[0], image_array.shape[1]
        stream.height = height
        stream.width = width
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())  # flush any buffered packets
    finally:
        container.close()
def decode_single_frame(video_file):
    """Decode the first frame of the first video stream in `video_file`.

    Returns the frame as an RGB uint8 ndarray.
    """
    container = av.open(video_file)
    try:
        video_stream = next(s for s in container.streams if s.type == "video")
        first_frame = next(container.decode(video_stream))
    finally:
        container.close()
    return first_frame.to_ndarray(format="rgb24")
def preprocess(image: torch.Tensor, crf=29):
    """Round-trip one image tensor through H.264 compression to add codec artifacts.

    `crf == 0` disables the round trip and returns the input unchanged. The image
    is cropped to even height/width before encoding (the frame is reformatted to
    yuv420p by the encoder helper).
    """
    if crf == 0:
        return image

    even_h = (image.shape[0] // 2) * 2
    even_w = (image.shape[1] // 2) * 2
    image_array = (image[:even_h, :even_w] * 255.0).byte().cpu().numpy()

    with io.BytesIO() as output_file:
        encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()

    with io.BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)

    return torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
class LTXVPreprocess:
    """Apply H.264 compression artifacts to each image in a batch.

    Delegates per-frame work to the module-level `preprocess` function;
    `img_compression` is passed through as the CRF value (0 = no-op).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "img_compression": (
                    "INT",
                    {
                        "default": 35,
                        "min": 0,
                        "max": 100,
                        "tooltip": "Amount of compression to apply on image.",
                    },
                ),
            }
        }

    FUNCTION = "preprocess"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("output_image",)
    CATEGORY = "image"

    def preprocess(self, image, img_compression):
        # Compress each frame independently, then restack along the batch axis.
        # (Comprehension replaces the original manual append loop.)
        output_images = [preprocess(image[i], img_compression) for i in range(image.shape[0])]
        return (torch.stack(output_images),)
# Registry mapping node identifiers to their implementing classes; ComfyUI reads
# this at import time to discover the nodes defined in this module.
NODE_CLASS_MAPPINGS = {
    "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
    "LTXVImgToVideo": LTXVImgToVideo,
    "ModelSamplingLTXV": ModelSamplingLTXV,
    "LTXVConditioning": LTXVConditioning,
    "LTXVScheduler": LTXVScheduler,
    "LTXVAddGuide": LTXVAddGuide,
    "LTXVPreprocess": LTXVPreprocess,
    "LTXVCropGuides": LTXVCropGuides,
}