[BUG] Embedding dim changed after applying apply_rotary_pos_emb
#195
Labels: bug
apply_rotary_pos_emb: torchprime/torchprime/torch_xla_models/mixtral/model.py, lines 113 to 138 in 13f3233
To reproduce, run the unit test:
pytest torchprime/torch_xla_models/tests/test_mixtral.py::test_forward_torch_xla_against_native
The shape of query_states changes from [2, 8, 4, 1] to [2, 8, 4, 2] after the rotary embedding is applied.
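One possible explanation, offered only as a guess, is silent broadcasting when the head dimension is 1. The sketch below is shape arithmetic in plain Python, not the torchprime code. It assumes the common rotary layout in which one frequency is created per index in arange(0, head_dim, 2) and cos/sin are built by concatenating those frequencies twice. It also assumes the shape reads as [batch, heads, seq, head_dim]. The helper names (rotary_table_width, broadcast_shape, shape_after_rotary) are hypothetical and do not exist in the repository.

```python
def rotary_table_width(head_dim: int) -> int:
    """Last-dim size of cos/sin when they are built as cat(freqs, freqs).

    Assumption: one frequency per index in arange(0, head_dim, 2).
    """
    num_freqs = len(range(0, head_dim, 2))
    return 2 * num_freqs


def broadcast_shape(a, b):
    """Result shape of an elementwise op on shapes a and b (PyTorch/NumPy rules)."""
    a, b = list(a), list(b)
    # Left-pad the shorter shape with 1s so both have the same rank.
    while len(a) < len(b):
        a.insert(0, 1)
    while len(b) < len(a):
        b.insert(0, 1)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {x} against {y}")
        # A size-1 axis stretches to match the other operand.
        out.append(y if x == 1 else x)
    return out


def shape_after_rotary(q_shape):
    """Shape of q * cos, assuming q is [batch, heads, seq, head_dim]
    and cos is [batch, 1, seq, table_width] (hypothetical layout)."""
    batch, _, seq, head_dim = q_shape
    cos_shape = [batch, 1, seq, rotary_table_width(head_dim)]
    return broadcast_shape(q_shape, cos_shape)


if __name__ == "__main__":
    print(shape_after_rotary([2, 8, 4, 1]))  # [2, 8, 4, 2]
    print(shape_after_rotary([2, 8, 4, 4]))  # [2, 8, 4, 4]
```

Under these assumptions, an even head_dim leaves the shape unchanged. An odd head_dim greater than 1 would raise a broadcast error. Only head_dim = 1 would widen silently, which matches the shapes reported above. Whether this is the actual cause needs to be checked against the linked lines.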