diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 59a62e0d..1b3e9f31 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
         self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
 
     def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
 
         return (
             ModulationOut(*out[:3]),
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
         )
 
 
+def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
+    if modulation_dims is None:
+        if m_add is not None:
+            return tensor * m_mult + m_add
+        else:
+            return tensor * m_mult
+    else:
+        for d in modulation_dims:
+            tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+            if m_add is not None:
+                tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+        return tensor
+
+
 class DoubleStreamBlock(nn.Module):
     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)
 
         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
         self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
 
-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
-        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
+        x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
         x = self.linear(x)
         return x
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index f3f44584..001e302b 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        guiding_frame_index=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -237,7 +238,15 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if guiding_frame_index is not None:
+            token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
+            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
+            modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+        else:
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            modulation_dims = None
 
         if self.params.guidance_embed:
             if guidance is not None:
@@ -271,7 +280,7 @@ class HunyuanVideo(nn.Module):
                     txt = out["txt"]
                     img = out["img"]
                 else:
-                    img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
+                    img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
 
                 if control is not None: # Controlnet
                     control_i = control.get("input")
@@ -292,7 +301,7 @@ class HunyuanVideo(nn.Module):
                     out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
                     img = out["img"]
                 else:
-                    img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                    img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
 
                 if control is not None: # Controlnet
                     control_o = control.get("output")
@@ -303,7 +312,7 @@ class HunyuanVideo(nn.Module):
 
         img = img[:, : img_len]
 
-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
         shape = initial_shape[-3:]
         for i in range(len(shape)):
@@ -313,7 +322,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -325,5 +334,5 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
         return out
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 2fa1ee91..bf4ebefa 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -898,20 +898,31 @@ class HunyuanVideo(BaseModel):
         guidance = kwargs.get("guidance", 6.0)
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        guiding_frame_index = kwargs.get("guiding_frame_index", None)
+        if guiding_frame_index is not None:
+            out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
+
         return out
 
+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return latent_image
 
 class HunyuanVideoI2V(HunyuanVideo):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("concat_image", "mask_inverted")
 
+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
 
 class HunyuanVideoSkyreelsI2V(HunyuanVideo):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("concat_image",)
 
+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
 
 class CosmosVideo(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index 56aef9b0..504010ad 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -68,7 +68,6 @@ class TextEncodeHunyuanVideo_ImageToVideo:
         tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
         return (clip.encode_from_tokens_scheduled(tokens), )
 
-
 class HunyuanImageToVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -78,6 +77,7 @@ class HunyuanImageToVideo:
                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             "guidance_type": (["v1 (concat)", "v2 (replace)"], )
                              },
                 "optional": {"start_image": ("IMAGE", ),
                              }}
@@ -88,8 +88,10 @@ class HunyuanImageToVideo:
 
     CATEGORY = "conditioning/video_models"
 
-    def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
+    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        out_latent = {}
+
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
 
@@ -97,13 +99,20 @@ class HunyuanImageToVideo:
             mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
             mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
 
-            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            if guidance_type == "v1 (concat)":
+                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+            else:
+                cond = {'guiding_frame_index': 0}
+                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+                out_latent["noise_mask"] = mask
+
+            positive = node_helpers.conditioning_set_values(positive, cond)
 
-        out_latent = {}
         out_latent["samples"] = latent
         return (positive, out_latent)
 
 
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
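
Note (illustration only, not part of the patch): the new apply_mod() helper together with the modulation_dims list built in HunyuanVideo.forward_orig is what implements the "v2 (replace)" guidance path. Each entry of modulation_dims is a (token_start, token_end, vector_index) triple, and vec now stacks one modulation vector per token group along dim 1, so the guiding-frame tokens are modulated with vector 0 and all remaining tokens with vector 1. A minimal self-contained sketch of that slicing; the shapes and the tokens/scale/shift/frame_tokens names below are made up for the example:

    import torch

    # Copy of apply_mod() from comfy/ldm/flux/layers.py as added by this patch.
    def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
        if modulation_dims is None:
            if m_add is not None:
                return tensor * m_mult + m_add
            else:
                return tensor * m_mult
        else:
            for d in modulation_dims:
                tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
                if m_add is not None:
                    tensor[:, d[0]:d[1]] += m_add[:, d[2]]
            return tensor

    # Illustrative shapes: batch 1, 6 tokens, hidden size 4; the first 2 tokens
    # belong to the guiding frame (modulation vector 0), the rest use vector 1.
    tokens = torch.zeros(1, 6, 4)
    scale = torch.ones(1, 2, 4)                     # one modulation vector per group on dim 1
    shift = torch.tensor([[[1.0] * 4, [2.0] * 4]])  # group 0 shifts by 1.0, group 1 by 2.0
    frame_tokens = 2
    modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
    out = apply_mod(tokens, scale, shift, modulation_dims)
    print(out[0, 0], out[0, -1])  # guiding-frame token -> all 1.0, later token -> all 2.0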