
[BUG] Embedding dim changed after applying apply_rotary_pos_emb #195

Open

zpcore opened this issue Apr 3, 2025 · 0 comments
Labels
bug Something isn't working

Comments

@zpcore (Collaborator) commented Apr 3, 2025

The `apply_rotary_pos_emb` function:

```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids.long()].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids.long()].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
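For reference, `rotate_half` is not shown above; this sketch assumes torchprime uses the standard definition from the Hugging Face transformers code this module derives from:

```python
import torch

def rotate_half(x):
    # Rotate half the hidden dims: split the last dim into (x1, x2) and return (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```

Note that under this definition `rotate_half` preserves the input shape (with head_dim=1 it degenerates to `-x`), so the rotation itself cannot change the embedding dim; any shape change would have to come from broadcasting against `cos`/`sin`.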
Applying `apply_rotary_pos_emb` changed the embedding dim of the query and key tensors: I saw the embedding dim change after applying the embedding.
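For contrast, a quick sanity check (a sketch, using the `rotate_half` assumed above and random `cos`/`sin` tables) shows that shapes are preserved when the head dims agree:

```python
q = torch.randn(2, 8, 4, 4)                 # [batch, heads, seq_len, head_dim]
k = torch.randn(2, 8, 4, 4)
cos = torch.randn(16, 4)                    # [max_positions, head_dim]
sin = torch.randn(16, 4)
position_ids = torch.arange(4).unsqueeze(0).expand(2, -1)  # [batch, seq_len]

q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_out.shape, k_out.shape)             # both torch.Size([2, 8, 4, 4])
```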

To reproduce, run the unit test `pytest torchprime/torch_xla_models/tests/test_mixtral.py::test_forward_torch_xla_against_native`. You can see `query_states.shape` change from `[2, 8, 4, 1]` to `[2, 8, 4, 2]` after applying the embedding.
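A minimal standalone sketch of how broadcasting can produce exactly this symptom: if the precomputed `cos`/`sin` tables carry a larger last dimension than `q` (hypothetically head_dim 2 vs. 1 here), `q * cos` silently broadcasts instead of raising an error. This is plain PyTorch illustrating the mechanism, not the confirmed root cause in torchprime:

```python
import torch

q = torch.randn(2, 8, 4, 1)                  # [batch, heads, seq_len, head_dim=1]
cos = torch.randn(4, 2)                      # [seq_len, head_dim=2]: hypothetical mismatch
position_ids = torch.arange(4).unsqueeze(0).expand(2, -1)

cos = cos[position_ids.long()].unsqueeze(1)  # [2, 1, 4, 2]
q_embed = q * cos                            # last dim broadcasts 1 -> 2
print(q_embed.shape)                         # torch.Size([2, 8, 4, 2])
```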

zpcore added the bug label on Apr 3, 2025