mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Speed up HunyuanDiT inference a bit.
This commit is contained in:
parent
1b5b8ca81a
commit
ae197f651b
@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
|
|||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||||
|
|
||||||
|
|
||||||
@ -78,10 +78,9 @@ def apply_rotary_emb(
|
|||||||
xk_out = None
|
xk_out = None
|
||||||
if isinstance(freqs_cis, tuple):
|
if isinstance(freqs_cis, tuple):
|
||||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
xq_out = (xq * cos + rotate_half(xq) * sin)
|
||||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
|
||||||
if xk is not None:
|
if xk is not None:
|
||||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
xk_out = (xk * cos + rotate_half(xk) * sin)
|
||||||
else:
|
else:
|
||||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
||||||
|
@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
|
|||||||
sub_args = [start, stop, (th, tw)]
|
sub_args = [start, stop, (th, tw)]
|
||||||
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
||||||
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||||
|
rope = (rope[0].to(x), rope[1].to(x))
|
||||||
return rope
|
return rope
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user