From 4a9014e201b71dfcf97d59d1086cb451cd40873f Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 13 May 2025 12:53:47 -0700
Subject: [PATCH] Hunyuan Custom initial untested implementation. (#8101)

---
 comfy/ldm/hunyuan_video/model.py | 21 ++++++++++++++++++---
 comfy/model_base.py              |  4 ++++
 comfy_extras/nodes_hunyuan.py    |  6 ++++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 72af3d5bb..fbd8d4196 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
         y: Tensor,
         guidance: Tensor = None,
         guiding_frame_index=None,
+        ref_latent=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if ref_latent is not None:
+            ref_latent_ids = self.img_ids(ref_latent)
+            ref_latent = self.img_in(ref_latent)
+            img = torch.cat([ref_latent, img], dim=-2)
+            ref_latent_ids[..., 0] = -1
+            ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
+            img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
+
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
             vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
             img[:, : img_len] += add
 
         img = img[:, : img_len]
+        if ref_latent is not None:
+            img = img[:, ref_latent.shape[1]:]
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
+    def img_ids(self, x):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+        return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs, c, t, h, w = x.shape
+        img_ids = self.img_ids(x)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 6d27930dc..047861593 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -924,6 +924,10 @@ class HunyuanVideo(BaseModel):
         if guiding_frame_index is not None:
             out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
 
+        ref_latent = kwargs.get("ref_latent", None)
+        if ref_latent is not None:
+            out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
+
         return out
 
     def scale_latent_inpaint(self, latent_image, **kwargs):
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index 504010ad0..d7278e7a7 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -77,7 +77,7 @@ class HunyuanImageToVideo:
                     "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                     "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                     "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                    "guidance_type": (["v1 (concat)", "v2 (replace)"], )
+                    "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
                 },
                 "optional": {"start_image": ("IMAGE", ),
                              }}
@@ -101,10 +101,12 @@ class HunyuanImageToVideo:
 
         if guidance_type == "v1 (concat)":
             cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
-        else:
+        elif guidance_type == "v2 (replace)":
            cond = {'guiding_frame_index': 0}
             latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
             out_latent["noise_mask"] = mask
+        elif guidance_type == "custom":
+            cond = {"ref_latent": concat_latent_image}
 
         positive = node_helpers.conditioning_set_values(positive, cond)
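
Note on how the pieces fit together: the new "custom" guidance type on the HunyuanImageToVideo node sets `cond = {"ref_latent": concat_latent_image}`, `model_base.py` wraps that latent in a `CONDRegular` (after `process_latent_in`), and it reaches `forward_orig` as `ref_latent`, where the reference image is patchified, prepended to the video token sequence, given position ids at time index -1 with a width offset, and sliced off again before the final layer. Below is a minimal standalone sketch of just that position-id trick; it is not part of the patch, and all shapes are illustrative assumptions:

```python
import torch

def build_ids(t_len, h_len, w_len, bs=1):
    # (t, h, w, 3) grid of [time, height, width] positions,
    # flattened to (bs, t*h*w, 3) like img_ids in the patch.
    ids = torch.zeros(t_len, h_len, w_len, 3)
    ids[..., 0] += torch.arange(t_len).reshape(-1, 1, 1)
    ids[..., 1] += torch.arange(h_len).reshape(1, -1, 1)
    ids[..., 2] += torch.arange(w_len).reshape(1, 1, -1)
    return ids.reshape(1, -1, 3).repeat(bs, 1, 1)

video_ids = build_ids(t_len=4, h_len=6, w_len=8)  # video patch tokens
ref_ids = build_ids(t_len=1, h_len=6, w_len=8)    # single reference frame

ref_ids[..., 0] = -1   # park the reference at time index -1
ref_ids[..., 2] += 8   # shift width ids past the video grid (w_len),
                       # so no reference position collides with a video one

ids = torch.cat([ref_ids, video_ids], dim=-2)  # reference tokens come first,
                                               # mirroring torch.cat([ref_latent, img], dim=-2)
# Attention then runs over the concatenated sequence; afterwards the
# reference tokens are dropped: img = img[:, ref_latent.shape[1]:]
print(ids.shape)  # torch.Size([1, 240, 3]) -> 48 reference + 192 video tokens
```

Because RoPE only sees these ids, parking the reference at an out-of-range position lets the video tokens attend to it without the model treating it as part of the output frames.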