diff --git a/comfy/model_base.py b/comfy/model_base.py index 66cd0ded1..cddc4663e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -185,6 +185,11 @@ class BaseModel(torch.nn.Module): if concat_latent_image.shape[1:] != noise.shape[1:]: concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + if noise.ndim == 5: + if concat_latent_image.shape[-3] < noise.shape[-3]: + concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0) + else: + concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]] concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) @@ -213,6 +218,11 @@ class BaseModel(torch.nn.Module): cond_concat.append(self.blank_inpaint_image_like(noise)) elif ck == "mask_inverted": cond_concat.append(torch.zeros_like(noise)[:, :1]) + if ck == "concat_image": + if concat_latent_image is not None: + cond_concat.append(concat_latent_image.to(device)) + else: + cond_concat.append(torch.zeros_like(noise)) data = torch.cat(cond_concat, dim=1) return data return None @@ -872,20 +882,17 @@ class HunyuanVideo(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - image = kwargs.get("concat_latent_image", None) - noise = kwargs.get("noise", None) - - if image is not None: - padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4]) - latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype) - image_latents = torch.cat([image.to(noise), latent_padding], dim=2) - out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents)) - guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class HunyuanVideoSkyreelsI2V(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + self.concat_keys = ("concat_image",) + + class CosmosVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a8212c1fa..26340900b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -826,6 +826,16 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) +class HunyuanVideoSkyreelsI2V(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "in_channels": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideoSkyreelsI2V(self, device=device) + return out + class CosmosT2V(supported_models_base.BASE): unet_config = { "image_model": "cosmos", @@ -939,6 +949,6 @@ class WAN21_I2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] models += [SVD_img2vid]