Commit 31ce378

Some fixes

* register in setup entry points
* add UT for activation

Signed-off-by: MengqingCao <cmq0113@163.com>

1 parent 4500268 · commit 31ce378

File tree

4 files changed: +58 -4 lines changed

setup.py

Lines changed: 4 additions & 2 deletions
@@ -391,7 +391,9 @@ def _read_requirements(filename: str) -> List[str]:
     extras_require={},
     entry_points={
         "vllm.platform_plugins": ["ascend = vllm_ascend:register"],
-        "vllm.general_plugins":
-        ["ascend_enhanced_model = vllm_ascend:register_model"],
+        "vllm.general_plugins": [
+            "ascend_enhanced_model = vllm_ascend:register_model",
+            "dummy_custom_ops = vllm_ascend:register_ops"
+        ],
     },
 )
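For context: once the package is installed, entry points in the "vllm.general_plugins" group are discovered through Python's standard package metadata and the registered functions are called. The sketch below is illustrative only (plain importlib.metadata, Python 3.10+), not code from this commit or from vLLM itself.

    # Illustrative sketch: how a "vllm.general_plugins" entry point can be
    # discovered and invoked after installation. Not code from this commit.
    from importlib.metadata import entry_points

    for ep in entry_points(group="vllm.general_plugins"):
        register_fn = ep.load()  # e.g. resolves "dummy_custom_ops" to vllm_ascend:register_ops
        register_fn()            # importing vllm_ascend.ops runs the register_oot decorators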

tests/ut/ops/test_activation.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from unittest.mock import patch
+
+import pytest
+import torch
+from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+
+import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401
+
+
+@pytest.fixture
+def dummy_tensor():
+    return torch.randn(4, 8, dtype=torch.float16)
+
+
+@patch("torch_npu.npu_fast_gelu", side_effect=lambda x: x + 1)
+def test_QuickGELU_forward(mock_gelu, dummy_tensor):
+    layer = QuickGELU()
+    out = layer.forward(dummy_tensor)
+
+    expected_out = dummy_tensor + 1
+    assert torch.allclose(out, expected_out)
+
+    mock_gelu.assert_called_once()
+
+
+@pytest.mark.parametrize("is_310p_return", [True, False])
+@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
+def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
+
+    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
+        layer = SiluAndMul()
+        out = layer.forward(dummy_tensor)
+
+        if is_310p_return:
+            expected_arg = dummy_tensor.to(torch.float32)
+        else:
+            expected_arg = dummy_tensor
+
+        # assert mock_swiglu.call_count == 1
+        mock_swiglu.assert_called_once()
+
+        actual_arg = mock_swiglu.call_args[0][0]
+        assert torch.allclose(
+            actual_arg,
+            expected_arg), "npu_swiglu called with unexpected input"
+
+        expected_out = dummy_tensor + 1
+        assert torch.allclose(out, expected_out)
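Because the torch_npu kernels are replaced with mocks, these unit tests should run on a host without Ascend hardware. A typical invocation (assuming pytest is available in the environment) would be:

    pytest tests/ut/ops/test_activation.py -v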

vllm_ascend/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -29,3 +29,7 @@ def register_model():
 
     from .models import register_model
     register_model()
+
+
+def register_ops():
+    import vllm_ascend.ops  # noqa: F401
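register_ops only needs to import vllm_ascend.ops because decorators such as @QuickGELU.register_oot in vllm_ascend/ops/activation.py run as a side effect of that import. The toy sketch below shows the general decorator-registration pattern being relied on; it is illustrative only and not vLLM's actual register_oot implementation.

    # Toy illustration of registration via import side effects; NOT vLLM's
    # actual register_oot implementation.
    _REGISTRY: dict = {}

    def register_oot(cls):
        _REGISTRY[cls.__name__] = cls  # executes when the defining module is imported
        return cls

    @register_oot
    class AscendQuickGELU:             # registered the moment this module loads
        pass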

vllm_ascend/ops/activation.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,6 @@
 import torch
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
 
-from vllm_ascend.utils import is_310p
-
 
 @QuickGELU.register_oot
 class AscendQuickGELU(QuickGELU):
@@ -37,6 +35,8 @@ class AscendSiluAndMul(SiluAndMul):
     def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         import torch_npu
 
+        from vllm_ascend.utils import is_310p
+
         if is_310p():
             out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
         else: