from unittest.mock import patch

import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401


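# Small fp16 activation tensor shared by all parametrized cases.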
@pytest.fixture
def dummy_tensor():
    return torch.randn(4, 8, dtype=torch.float16)


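# Stand-in for torch_npu.npu_rms_norm: returns an (output, extra) tuple like
# the real kernel; the "+ 1" makes the call observable in the output.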
def mock_rms_norm(x, weight, eps):
    return x + 1, None


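# Stand-in for torch_npu.npu_add_rms_norm: returns an (output, extra,
# residual) triple; doubling both tensors makes the fused path observable.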
def mock_add_rms_norm(x, residual, weight, eps):
    return 2 * x, None, 2 * residual


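# Covers both SoC branches (310P vs. others), each with and without a
# residual input; the NPU kernels themselves are mocked out above.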
@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
                         [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
                         residual, dummy_tensor):
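    # Pin is_310p to the desired branch, run forward, and verify that exactly
    # one mocked NPU kernel was dispatched with the expected result.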
    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = RMSNorm(hidden_size=8, eps=1e-05)
        if residual is not None:
            out_x, out_residual = layer.forward(dummy_tensor, residual)

            if is_310p_return:
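                # On 310P the residual is folded into x (after a dtype cast)
                # before the plain npu_rms_norm kernel runs.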
                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
                expected_out_x = expected_arg_x + 1
                expected_out_residual = expected_arg_x.to(residual.dtype)

                mock_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
            else:
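                # Other SoCs use the fused npu_add_rms_norm kernel, whose
                # mock doubles both x and the residual.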
                expected_out_x = 2 * dummy_tensor
                expected_out_residual = 2 * residual
                mock_add_rmsnorm.assert_called_once()
                assert torch.allclose(out_x, expected_out_x)
                assert torch.allclose(out_residual, expected_out_residual)
        else:
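            # Without a residual, forward returns a single tensor and always
            # dispatches to npu_rms_norm, regardless of the SoC version.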
            out_x = layer.forward(dummy_tensor, residual)
            expected_out_x = dummy_tensor + 1

            mock_rmsnorm.assert_called_once()
            assert torch.allclose(out_x, expected_out_x)