Small cosmos attention code refactor. (#8530)

This commit is contained in:
comfyanonymous 2025-06-14 02:02:05 -07:00 committed by GitHub
parent 803af1e0c3
commit 29596bd53f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -70,11 +70,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
)
return result_B_S_HD
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):