Commit 2498d29

add custom ascendc kernel vocabparallelembedding (#796)
This PR adds custom AscendC kernel vocabparallelembedding support to vllm-ascend; the related CMakeLists and setuptools changes are also included in this PR.

To run the benchmark and the tests:

pytest -s benchmarks/ops/ben_vocabparallelembedding.py
pytest -s tests/ops/test_vocabparallelembedding.py

---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
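Once the kernel library is built and installed, the custom op is exposed through torch.ops._C and can be invoked directly. A minimal sketch, mirroring the call in the test below (the tensor shape and vocab ranges are the test's illustrative values, not requirements):

import torch
import torch_npu  # noqa: F401
import vllm_ascend.platform  # noqa: F401  (imported for its side effects, as in the test)

token_ids = torch.randint(0, 1000, (3, 4, 5), dtype=torch.int32, device="npu:0")
# Arguments: input, org_vocab_start, org_vocab_end, num_org_vocab_padding,
# added_vocab_start, added_vocab_end — matching the reference signature below.
masked_input, mask = torch.ops._C.get_masked_input_and_mask(
    token_ids, 100, 200, 0, 300, 400)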
1 parent 3393d53 · commit 2498d29

File tree

6 files changed: +710 −2 lines


CMakeLists.txt

Lines changed: 0 additions & 2 deletions

@@ -96,5 +96,3 @@ target_link_libraries(
 target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")

 install(TARGETS vllm_ascend_C vllm_ascend_kernels DESTINATION ${VLLM_ASCEND_INSTALL_PATH})
-
-

(The two deleted lines are trailing blank lines after the install command.)
Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@ (new file)

from typing import Tuple

import numpy as np
import pytest
import torch
import torch_npu  # noqa: F401
import vllm  # noqa: F401

import vllm_ascend.platform  # noqa: F401


def benchmark_npu(fn, num_iterations=100, num_warmup_iterations=50):
    """
    Benchmark function for NPU operations

    Args:
        fn: Function to benchmark
        num_iterations: Number of timing iterations
        num_warmup_iterations: Number of warmup iterations

    Returns:
        float: Minimum elapsed time in seconds
    """
    start = torch.npu.Event(enable_timing=True)
    end = torch.npu.Event(enable_timing=True)
    times = np.zeros(num_iterations + num_warmup_iterations)

    # Run iterations
    for i in range(num_warmup_iterations + num_iterations):
        with torch.no_grad():
            start.record()
            fn()  # Execute the function
            end.record()
        torch.npu.synchronize()
        times[i] = start.elapsed_time(end)

    # Remove warmup iterations and convert to seconds
    times = times[num_warmup_iterations:]
    elapsed_time = np.amin(times) / 1000
    return elapsed_time


def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for verification"""
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index * org_vocab_mask) + (
        added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    masked_input = vocab_mask * (input_ - valid_offset)
    return masked_input, ~vocab_mask


DTYPES = [torch.int32]
SHAPES = [(3, 4, 5)]
DEVICES = [f"npu:{0}"]
SEEDS = [0]


@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    # Set random seed and device
    torch.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random input tensor
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)

    # Test parameters
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }

    # Define reference function
    def ref_fn():
        return get_masked_input_and_mask_ref(input_tensor,
                                             test_case["org_start"],
                                             test_case["org_end"],
                                             test_case["padding"],
                                             test_case["added_start"],
                                             test_case["added_end"])

    # Define custom function
    def custom_fn():
        return torch.ops._C.get_masked_input_and_mask(input_tensor,
                                                      test_case["org_start"],
                                                      test_case["org_end"],
                                                      test_case["padding"],
                                                      test_case["added_start"],
                                                      test_case["added_end"])

    # Get results for correctness testing
    ref_masked_input, ref_mask = ref_fn()
    custom_masked_input, custom_mask = custom_fn()

    # Benchmark both implementations
    ref_time = benchmark_npu(ref_fn)
    custom_time = benchmark_npu(custom_fn)

    # Print performance results
    print("\nPerformance Results:")
    print(f"Reference implementation: {ref_time*1000:.3f} ms")
    print(f"Custom implementation: {custom_time*1000:.3f} ms")
    print(f"Speedup: {ref_time/custom_time:.2f}x")

    # Compare results for correctness
    ref_masked_input = ref_masked_input.to(dtype)
    print("\nResults comparison:")
    print("custom_masked_input:", custom_masked_input)
    print("ref_masked_input:", ref_masked_input)
    print("custom_mask:", custom_mask)
    print("ref_mask:", ref_mask)
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Masked input mismatch for case: {test_case}")
    torch.testing.assert_close(custom_mask,
                               ref_mask,
                               rtol=1e-5,
                               atol=1e-5,
                               msg=f"Mask mismatch for case: {test_case}")
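For intuition, here is a small CPU-only sketch of what the reference implementation computes with the test's parameters (org vocab rows [100, 200), added vocab rows [300, 400), no padding); the token ids below are illustrative:

import torch

# One id in the org range, one in the added range, one outside both.
token_ids = torch.tensor([150, 350, 50], dtype=torch.int32)

# added_offset = 300 - (200 - 100) - 0 = 200, so:
#   150 -> 150 - 100 = 50   (org vocab, in-shard)
#   350 -> 350 - 200 = 150  (added vocab, in-shard)
#   50  -> 0                (outside both ranges, masked)
masked_input, mask = get_masked_input_and_mask_ref(
    token_ids, 100, 200, 0, 300, 400)
print(masked_input)  # tensor([ 50, 150,   0])
print(mask)          # tensor([False, False,  True])

This remapping is what lets each tensor-parallel rank look up only its local shard of the embedding table: in-shard ids become local row indices, while ids owned by other ranks are mapped to row 0 and zeroed out after the lookup via the returned mask.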
