Commit f688f36

add layernorm ut
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 834babe commit f688f36

File tree: 6 files changed, +146 -10 lines changed


tests/ut/ops/test_common_fused_moe.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
from unittest.mock import patch

import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401


@pytest.fixture
def dummy_tensor():
    return torch.randn(4, 8, dtype=torch.float16)


def mock_rms_norm(x, weight, eps):
    return x + 1, None


def mock_add_rms_norm(x, residual, weight, eps):
    return 2 * x, None, 2 * residual


@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
                         [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
                         residual, dummy_tensor):

    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = RMSNorm(hidden_size=32, eps=1e-05)
        if residual is not None:
            out_x, out_residual = layer.forward(dummy_tensor, residual)

            if is_310p_return:
                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
                expected_out_x = expected_arg_x + 1
                expected_out_residual = expected_arg_x.to(residual.dtype)

                mock_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
            else:
                expected_out_x = 2 * dummy_tensor
                expected_out_residual = 2 * residual
                mock_add_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
        else:
            out_x = layer.forward(dummy_tensor, residual)
            expected_out_x = dummy_tensor + 1

            mock_rmsnorm.assert_called_once()
            assert torch.allclose(out_x, expected_out_x)

tests/ut/ops/test_layernorm.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
from unittest.mock import patch

import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401


@pytest.fixture
def dummy_tensor():
    return torch.randn(4, 8, dtype=torch.float16)


def mock_rms_norm(x, weight, eps):
    return x + 1, None


def mock_add_rms_norm(x, residual, weight, eps):
    return 2 * x, None, 2 * residual


@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
                         [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
                         residual, dummy_tensor):

    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = RMSNorm(hidden_size=32, eps=1e-05)
        if residual is not None:
            out_x, out_residual = layer.forward(dummy_tensor, residual)

            if is_310p_return:
                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
                expected_out_x = expected_arg_x + 1
                expected_out_residual = expected_arg_x.to(residual.dtype)

                mock_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
            else:
                expected_out_x = 2 * dummy_tensor
                expected_out_residual = 2 * residual
                mock_add_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
        else:
            out_x = layer.forward(dummy_tensor, residual)
            expected_out_x = dummy_tensor + 1

            mock_rmsnorm.assert_called_once()
            assert torch.allclose(out_x, expected_out_x)
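
Note: a minimal way to run just these new unit tests locally, assuming pytest and the test-time dependencies mocked above are importable (a sketch, not part of the commit):

import pytest

# Run only the new layernorm unit tests with verbose output.
raise SystemExit(pytest.main(["tests/ut/ops/test_layernorm.py", "-v"]))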

tests/ut/ops/test_rotary_embedding.py

Lines changed: 31 additions & 6 deletions
@@ -221,8 +221,13 @@ def _create_layer(self):
 
     @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
            new_callable=PropertyMock)
-    def test_native_rope_deepseek_forward_base(self, mock_current_platform):
+    @patch("vllm_ascend.ops.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_native_rope_deepseek_forward_base(self,
+                                               mock_current_platform_ascend,
+                                               mock_current_platform):
         mock_current_platform.device_type = torch.device("cpu")
+        mock_current_platform_ascend.device_type = torch.device("cpu")
         self.layer = self._create_layer()
         with patch("vllm_ascend.ops.rotary_embedding.rope_forward_oot",
                    return_value=(self.query,
@@ -236,9 +241,13 @@ def test_native_rope_deepseek_forward_base(self, mock_current_platform):
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
     @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
            new_callable=PropertyMock)
+    @patch("vllm_ascend.ops.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_cache_handling(
-            self, mock_current_platform, mock_rope_forward_oot):
+            self, mock_current_platform_ascend, mock_current_platform,
+            mock_rope_forward_oot):
         mock_current_platform.device_type = torch.device("cpu")
+        mock_current_platform_ascend.device_type = torch.device("cpu")
         self.layer = self._create_layer()
         self.layer.max_seq_len = 1024
         # Test cache situation is true
@@ -256,9 +265,13 @@ def test_native_rope_deepseek_forward_cache_handling(
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
     @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
            new_callable=PropertyMock)
+    @patch("vllm_ascend.ops.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_key_reshaping(
-            self, mock_current_platform, mock_rope_forward_oot):
+            self, mock_current_platform_ascend, mock_current_platform,
+            mock_rope_forward_oot):
         mock_current_platform.device_type = torch.device("cpu")
+        mock_current_platform_ascend.device_type = torch.device("cpu")
         self.layer = self._create_layer()
 
         key = torch.randn(1, 32)
@@ -273,9 +286,13 @@ def test_native_rope_deepseek_forward_key_reshaping(
     @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
     @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
            new_callable=PropertyMock)
+    @patch("vllm_ascend.ops.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
     def test_native_rope_deepseek_forward_non_neox_style(
-            self, mock_current_platform, mock_rope_forward_oot):
+            self, mock_current_platform_ascend, mock_current_platform,
+            mock_rope_forward_oot):
         mock_current_platform.device_type = torch.device("cpu")
+        mock_current_platform_ascend.device_type = torch.device("cpu")
         self.layer = self._create_layer()
 
         mock_rope_forward_oot.return_value = (self.query, self.key)
@@ -288,9 +305,13 @@ def test_native_rope_deepseek_forward_non_neox_style(
 
     @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
            new_callable=PropertyMock)
-    def test_basic_case(self, mock_current_platform):
+    @patch("vllm_ascend.ops.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_basic_case(self, mock_current_platform_ascend,
+                        mock_current_platform):
         # Test with standard values
         mock_current_platform.device_type = torch.device("cpu")
+        mock_current_platform_ascend.device_type = torch.device("cpu")
         self.layer = self._create_layer()
         num_rotations = 100
         dim = 512
@@ -310,8 +331,12 @@ def test_basic_case(self, mock_current_platform):
 
     @patch("vllm.model_executor.layers.rotary_embedding.current_platform",
            new_callable=PropertyMock)
-    def test_yarn_get_mscale(self, mock_current_platform):
+    @patch("vllm_ascend.ops.rotary_embedding.current_platform",
+           new_callable=PropertyMock)
+    def test_yarn_get_mscale(self, mock_current_platform_ascend,
+                             mock_current_platform):
         mock_current_platform.device_type = torch.device("cpu")
+        mock_current_platform_ascend.device_type = torch.device("cpu")
         self.layer = self._create_layer()
 
         # test_scale_less_than_or_equal_1
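
Note on the extra decorator: because the module under test does `from vllm.platforms import current_platform`, it holds its own module-level reference, so the tests patch current_platform both where vLLM defines it and where vllm_ascend.ops.rotary_embedding imported it. A standalone sketch of that rule, using hypothetical demo_* module names (not vLLM's real modules):

import sys
import types
from unittest.mock import patch

# Stand-in for the defining module and for a module that did `from ... import current_platform`.
platforms = types.ModuleType("demo_platforms")
platforms.current_platform = types.SimpleNamespace(device_type="npu")
sys.modules["demo_platforms"] = platforms

ops = types.ModuleType("demo_ops")
ops.current_platform = platforms.current_platform  # binding captured at import time
sys.modules["demo_ops"] = ops

with patch("demo_platforms.current_platform", types.SimpleNamespace(device_type="cpu")):
    # Only the defining module was patched; the importer still sees the original object.
    assert sys.modules["demo_platforms"].current_platform.device_type == "cpu"
    assert sys.modules["demo_ops"].current_platform.device_type == "npu"

# Patching both names, as the tests above do, keeps the two modules consistent.
with patch("demo_platforms.current_platform", types.SimpleNamespace(device_type="cpu")), \
        patch("demo_ops.current_platform", types.SimpleNamespace(device_type="cpu")):
    assert sys.modules["demo_ops"].current_platform.device_type == "cpu"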

tests/ut/quantization/test_quant_config.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def test_get_quant_method_for_fused_moe(self):
 
         # Test skipped layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
-                patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
+                patch('vllm_ascend.quantization.quant_config.AscendDSUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
             method = self.ascend_config.get_quant_method(
                 fused_moe_layer, "moe_layer")
             self.assertIs(method, mock_ascend_moe.return_value)

vllm_ascend/ops/layernorm.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,6 @@
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-from vllm_ascend.utils import is_310p
-
 
 @RMSNorm.register_oot
 class AscendRMSNorm(RMSNorm):
@@ -33,6 +31,8 @@ def forward_oot(
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
+        from vllm_ascend.utils import is_310p
+
         if residual is not None:
             if is_310p():
                 orig_dtype = residual.dtype
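
A plausible reason for moving the is_310p import inside forward_oot (the commit message does not say) is that the new unit test patches vllm_ascend.utils.is_310p: a deferred import resolves the name at call time, so the patch is actually observed, whereas a module-level `from vllm_ascend.utils import is_310p` would keep using the binding captured before the patch. A standalone sketch with a hypothetical demo_utils module:

import sys
import types
from unittest.mock import patch

demo_utils = types.ModuleType("demo_utils")
demo_utils.is_310p = lambda: False
sys.modules["demo_utils"] = demo_utils

from demo_utils import is_310p as is_310p_bound_early  # module-level style binding


def forward_oot_deferred():
    from demo_utils import is_310p  # looked up when the function runs
    return is_310p()


with patch("demo_utils.is_310p", return_value=True):
    assert forward_oot_deferred() is True   # deferred import sees the patch
    assert is_310p_bound_early() is False   # early binding does not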

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 import torch
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.platforms import current_platform
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import enable_custom_op, is_310p
@@ -141,7 +142,7 @@ def __init__(
         self.max_seq_len = max_position_embeddings
         self._set_cos_sin_cache(max_position_embeddings,
                                 dtype=dtype,
-                                device="npu")
+                                device=current_platform.device_type)
 
     def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
         if scale <= 1:
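
Replacing the hard-coded "npu" with current_platform.device_type builds the cos/sin cache on whatever device the resolved platform reports, which is what lets the unit tests above construct the layer on CPU by patching current_platform. A minimal sketch of the pattern with a hypothetical stand-in platform object (not vLLM's actual class):

import torch


class _FakePlatform:
    device_type = "cpu"  # the tests above patch this to a CPU device


current_platform = _FakePlatform()


def build_rope_cache(max_len: int, dim: int) -> torch.Tensor:
    # Allocate the cache on whichever device the platform abstraction reports.
    device = current_platform.device_type
    t = torch.arange(max_len, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (10000**(torch.arange(
        0, dim, 2, dtype=torch.float32, device=device) / dim))
    freqs = torch.outer(t, inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)


cache = build_rope_cache(16, 8)
assert cache.device.type == "cpu" and cache.shape == (16, 8)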
