apply_rotary_pos_emb: torchprime/torchprime/torch_xla_models/mixtral/model.py, lines 113 to 138 in 13f3233
To reproduce, run the unit test: pytest torchprime/torch_xla_models/tests/test_mixtral.py::test_forward_torch_xla_against_native
We can see that query_states.shape changes from [2, 8, 4, 1] to [2, 8, 4, 2] after applying the embedding.
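One possible explanation, sketched below with shapes only, is broadcasting against a wider cos/sin table. This is an assumption, not something confirmed from the linked lines: it supposes the rotary table is built the common way, by concatenating two copies of the frequencies computed over arange(0, head_dim, 2), so that a head_dim of 1 yields a table of width 2. The helper names (broadcast_shape, rotary_table_width, shape_after_rope) are hypothetical and exist only for this illustration.

```python
def broadcast_shape(a, b):
    """Return the NumPy/PyTorch-style broadcast of two shapes, or raise."""
    out = []
    # Align from the trailing dimension; missing leading dims count as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1
        y = b[-i] if i <= len(b) else 1
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def rotary_table_width(head_dim):
    # ASSUMPTION: cos/sin come from cat((freqs, freqs), dim=-1), where freqs
    # has one entry per element of arange(0, head_dim, 2).
    num_freqs = len(range(0, head_dim, 2))
    return 2 * num_freqs


def shape_after_rope(q_shape):
    # q_shape is assumed to be (batch, num_heads, seq_len, head_dim).
    _, _, seq_len, head_dim = q_shape
    # ASSUMPTION: cos/sin are unsqueezed to broadcast over batch and heads.
    cos_shape = (1, 1, seq_len, rotary_table_width(head_dim))
    # q * cos takes the broadcast shape of its two operands.
    return broadcast_shape(q_shape, cos_shape)


print(shape_after_rope((2, 8, 4, 1)))  # (2, 8, 4, 2)
print(shape_after_rope((2, 8, 4, 2)))  # (2, 8, 4, 2)
```

If this is what happens, an even head_dim leaves the shape unchanged, while a head_dim of 1 is silently widened to 2 instead of raising an error.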