Refactor skyreels i2v code.

This commit is contained in:
comfyanonymous 2025-03-04 00:15:45 -05:00
parent 8362199ee7
commit 7c7c70c400
2 changed files with 27 additions and 10 deletions

View File

@ -185,6 +185,11 @@ class BaseModel(torch.nn.Module):
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5:
if concat_latent_image.shape[-3] < noise.shape[-3]:
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
else:
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
@ -213,6 +218,11 @@ class BaseModel(torch.nn.Module):
cond_concat.append(self.blank_inpaint_image_like(noise))
elif ck == "mask_inverted":
cond_concat.append(torch.zeros_like(noise)[:, :1])
if ck == "concat_image":
if concat_latent_image is not None:
cond_concat.append(concat_latent_image.to(device))
else:
cond_concat.append(torch.zeros_like(noise))
data = torch.cat(cond_concat, dim=1)
return data
return None
@ -872,20 +882,17 @@ class HunyuanVideo(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if image is not None:
padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image",)
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)

View File

@ -826,6 +826,16 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
return out
class CosmosT2V(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
@ -939,6 +949,6 @@ class WAN21_I2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
models += [SVD_img2vid]