From ae197f651b07389bfb778b690575043205a9a5c5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Aug 2024 07:36:27 -0400 Subject: [PATCH] Speed up hunyuan dit inference a bit. --- comfy/ldm/hydit/attn_layers.py | 7 +++---- comfy/ldm/hydit/models.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hydit/attn_layers.py b/comfy/ldm/hydit/attn_layers.py index 920b8428..e2801f71 100644 --- a/comfy/ldm/hydit/attn_layers.py +++ b/comfy/ldm/hydit/attn_layers.py @@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x def rotate_half(x): - x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] return torch.stack([-x_imag, x_real], dim=-1).flatten(3) @@ -78,10 +78,9 @@ def apply_rotary_emb( xk_out = None if isinstance(freqs_cis, tuple): cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] - cos, sin = cos.to(xq.device), sin.to(xq.device) - xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xq_out = (xq * cos + rotate_half(xq) * sin) if xk is not None: - xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + xk_out = (xk * cos + rotate_half(xk) * sin) else: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index 9a1f3733..f3afaad3 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size): sub_args = [start, stop, (th, tw)] # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] rope = get_2d_rotary_pos_embed(head_size, *sub_args) + rope = (rope[0].to(x), rope[1].to(x)) return rope