
Commit 75dc907

add layernorm ut
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent c10b3f8 commit 75dc907

File tree

2 files changed: +57, -2 lines


tests/ut/ops/test_layernorm.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
from unittest.mock import patch

import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401


@pytest.fixture
def dummy_tensor():
    return torch.randn(4, 8, dtype=torch.float16)


def mock_rms_norm(x, weight, eps):
    return x + 1, None


def mock_add_rms_norm(x, residual, weight, eps):
    return 2 * x, None, 2 * residual


@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
                         [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
                         residual, dummy_tensor):

    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = RMSNorm(hidden_size=32, eps=1e-05)
        if residual is not None:
            out_x, out_residual = layer.forward(dummy_tensor, residual)

            if is_310p_return:
                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
                expected_out_x = expected_arg_x + 1
                expected_out_residual = expected_arg_x.to(residual.dtype)

                mock_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
            else:
                expected_out_x = 2 * dummy_tensor
                expected_out_residual = 2 * residual
                mock_add_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
        else:
            out_x = layer.forward(dummy_tensor, residual)
            expected_out_x = dummy_tensor + 1

            mock_rmsnorm.assert_called_once()
            assert torch.allclose(out_x, expected_out_x)
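
The expected values above mirror the two code paths of AscendRMSNorm.forward_oot: on 310P the residual add is done in plain torch before a call to npu_rms_norm, while other SoCs use the fused npu_add_rms_norm kernel. A minimal sketch of that control flow, reconstructed from the mocked expectations and the hunk in vllm_ascend/ops/layernorm.py below (the real method body may differ in details), is:

# Sketch only -- reconstructed from the test expectations and the diff below;
# the actual forward_oot in vllm_ascend/ops/layernorm.py may differ.
def forward_oot(self, x, residual=None):
    import torch_npu

    from vllm_ascend.utils import is_310p
    if residual is not None:
        if is_310p():
            # 310P: add the residual in torch, keep a copy in the original
            # dtype, then normalize with the plain RMSNorm kernel.
            orig_dtype = residual.dtype
            x = x + residual.to(x.dtype)
            residual = x.to(orig_dtype)
            x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
        else:
            # Other SoCs: fused residual-add + RMSNorm in one NPU kernel.
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon)
        return x, residual
    x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
    return x

Because torch_npu and is_310p are patched out, the test should run on any host, e.g. pytest tests/ut/ops/test_layernorm.py.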

vllm_ascend/ops/layernorm.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,6 @@
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-from vllm_ascend.utils import is_310p
-
 
 @RMSNorm.register_oot
 class AscendRMSNorm(RMSNorm):

@@ -33,6 +31,8 @@ def forward_oot(
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
+        from vllm_ascend.utils import is_310p
+
         if residual is not None:
             if is_310p():
                 orig_dtype = residual.dtype
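
The only functional change here is moving the is_310p import from module scope into forward_oot. A plausible motivation (inferred, not stated in the commit message) is testability: a module-level from-import binds the function object into vllm_ascend.ops.layernorm at import time, so a later patch("vllm_ascend.utils.is_310p", ...) is not seen through that bound name, whereas a call-time import re-reads the attribute and picks up the mock. A tiny illustration with hypothetical modules util and consumer:

# util.py (hypothetical)
def is_310p():
    return False

# consumer.py (hypothetical)
def late_bound():
    # Importing inside the function re-reads util.is_310p on every call,
    # so unittest.mock.patch("util.is_310p", ...) is observed here; a
    # module-level "from util import is_310p" would freeze the original.
    from util import is_310p
    return is_310p()

# test (hypothetical)
from unittest.mock import patch
import consumer

with patch("util.is_310p", return_value=True):
    assert consumer.late_bound() is True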
