diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index 91c2ad40df..4c2b2aed4f 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -1,3 +1,5 @@
+import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa
+
 import math
 import unittest
 from unittest.mock import MagicMock, patch
@@ -71,10 +73,10 @@ class TestRopeForwardOot(unittest.TestCase):
 
     def setUp(self):
         # Common setup for tests
-        self.positions = torch.tensor([1, 2, 3])
+        self.positions = torch.tensor([0, 1, 2])
         self.query = torch.randn(3, 4, dtype=torch.float16)
         self.key = torch.randn(3, 4, dtype=torch.float16)
-        self.head_size = 32
+        self.head_size = 4
         self.cos_sin_cache = torch.randn(3, 4)
 
         # Mock self object for rope_forward_oot
@@ -85,18 +87,18 @@ def setUp(self):
         self.mock_self.forward_native.return_value = (self.query, self.key)
 
     @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
-    def test_rope_forward_oot_torchair_enabled_base(self,
+    @patch('torch_npu.npu_apply_rotary_pos_emb')
+    def test_rope_forward_oot_torchair_enabled_base(self, mock_rotary_emb,
                                                     mock_get_ascend_config):
         # Setup mock for torchair enabled
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = True
         mock_get_ascend_config.return_value = mock_config
+        mock_rotary_emb.return_value = self.query, self.key
 
         result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
                                               self.query, self.key)
 
-        self.mock_self.forward_native.assert_called_once_with(
-            self.positions, self.query, self.key, None)
         self.assertTrue(torch.equal(result_q, self.query))
         self.assertTrue(torch.equal(result_k, self.key))
 
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 3dd91ea63f..ea6377e584 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -39,16 +39,27 @@ def rope_forward_oot(
         offsets: Optional[torch.Tensor] = None,
         is_neox_style_override: Optional[bool] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config().torchair_graph_config.enabled:
-        return self.forward_native(
-            positions,
-            query,
-            key,
-            offsets,
-        )
-    import torch_npu
     query_shape, key_shape = query.shape, key.shape
+    if get_ascend_config().torchair_graph_config.enabled:
+        positions = positions.flatten()
+        cos_cache, sin_cache = self.cos_sin_cache.chunk(2, dim=-1)
+
+        cos_part = torch.index_select(cos_cache, 0, positions)
+        sin_part = torch.index_select(sin_cache, 0, positions)
+
+        cos_sin = torch.cat([cos_part, cos_part, sin_part, sin_part],
+                            dim=-1).unsqueeze(0).unsqueeze(2)
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # npu_apply_rotary_pos_emb seems to require BSND layout
+        query = query.reshape(positions.size(0), -1,
+                              self.head_size).unsqueeze(0)
+        key = key.reshape(positions.size(0), -1, self.head_size).unsqueeze(0)
+
+        query, key = torch_npu.npu_apply_rotary_pos_emb(
+            query, key, cos.contiguous(), sin.contiguous())
+
+        return query.view(query_shape), key.view(key_shape)
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
     if self.cos_sin_cache.dtype != query.dtype:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index fd40d13bc5..ff28929c32 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2016,10 +2016,15 @@ def load_model(self) -> None:
             from vllm.model_executor.layers.linear import (
                 MergedColumnParallelLinear, QKVParallelLinear,
                 RowParallelLinear)
+
+            from vllm_ascend.models.pangu_moe import (
+                CustomMergedColumnParallelLinear, CustomRowParallelLinear)
             for module in self.model.modules():
-                if isinstance(module,
-                              (MergedColumnParallelLinear,
-                               QKVParallelLinear, RowParallelLinear)):
+                if isinstance(
+                        module,
+                        (MergedColumnParallelLinear, QKVParallelLinear,
+                         RowParallelLinear, CustomMergedColumnParallelLinear,
+                         CustomRowParallelLinear)):
                     module.weight.data = torch_npu.npu_format_cast(
                         module.weight.data, ACL_FORMAT_FRACTAL_NZ)
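
Note on the new torchair branch in rope_forward_oot: the snippet below is a minimal CPU-only sketch that mirrors the cos/sin gathering and the BSND-style reshape added in the patch, so the tensor layout that would be handed to torch_npu.npu_apply_rotary_pos_emb can be sanity-checked without an NPU. The sizes (three tokens, head_size 4) follow the updated unit test; the npu_apply_rotary_pos_emb call itself is not reproduced, and the BSND layout requirement is an assumption inferred from the reshapes in the diff, not a documented contract of the op.

import torch

# Sketch of the cos/sin construction used by the torchair branch above.
# Sizes mirror the updated unit test: three tokens, head_size = 4.
num_tokens, head_size = 3, 4
positions = torch.tensor([0, 1, 2])
query = torch.randn(num_tokens, head_size, dtype=torch.float16)
key = torch.randn(num_tokens, head_size, dtype=torch.float16)
cos_sin_cache = torch.randn(num_tokens, head_size)

query_shape, key_shape = query.shape, key.shape

# cos_sin_cache stores [cos | sin] halves per position; gather the rows for
# the current positions and duplicate each half so the last dim is head_size
# again after the final chunk.
cos_cache, sin_cache = cos_sin_cache.chunk(2, dim=-1)
cos_part = torch.index_select(cos_cache, 0, positions)
sin_part = torch.index_select(sin_cache, 0, positions)
cos_sin = torch.cat([cos_part, cos_part, sin_part, sin_part],
                    dim=-1).unsqueeze(0).unsqueeze(2)
cos, sin = cos_sin.chunk(2, dim=-1)

# Reshape query/key to BSND: (batch=1, seq=num_tokens, num_heads, head_size).
query_bsnd = query.reshape(positions.size(0), -1, head_size).unsqueeze(0)
key_bsnd = key.reshape(positions.size(0), -1, head_size).unsqueeze(0)

# On an NPU this is where torch_npu.npu_apply_rotary_pos_emb(query_bsnd,
# key_bsnd, cos.contiguous(), sin.contiguous()) would run; here we only check
# that the layouts line up and that the original shapes can be restored.
assert cos.shape == sin.shape == (1, num_tokens, 1, head_size)
assert query_bsnd.shape == key_bsnd.shape == (1, num_tokens, 1, head_size)
print(query_bsnd.view(query_shape).shape, key_bsnd.view(key_shape).shape)

Running the sketch prints torch.Size([3, 4]) twice, confirming that the final view() calls recover the original flat query/key shapes, which is what the patched function returns.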