
Commit 5cf9ff1

[Performance]: Custom AscendC Kernel of Multi-Step Prepare Input (#814)
### What this PR does / why we need it?

- Following #807, this PR adds a custom AscendC kernel for multi-step prepare-input.
- It also fixes a bug we found in multi_step_runner.py when using multi-step on the V0 Engine.

### Does this PR introduce _any_ user-facing change?

No user-facing change.

### How was this patch tested?

We added a unit test file and an offline inference file to exercise the custom AscendC kernel; see test/ops/test_multi_step.py and examples/offline_multi_step.py.

---------

Signed-off-by: wan_danfeng <wonderful199082@126.com>
1 parent 00e0243 commit 5cf9ff1

File tree

11 files changed: +629 -35 lines

.github/workflows/codespell.yml

Lines changed: 1 addition & 1 deletion
@@ -42,6 +42,6 @@ jobs:
       - name: Run codespell check
         run: |
           CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
-          CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue')
+          CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn')

           codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ endif()

 include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
 file(GLOB KERNEL_FILES
-    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/pos_encoding_kernels.cpp)
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)

 ascendc_library(vllm_ascend_kernels SHARED
     ${KERNEL_FILES}

csrc/kernels/advance_step.cpp

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
/*
 * Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel_operator.h"

constexpr int32_t BUFFER_NUM = 1;

class KernelAdvanceStep {
public:
    __aicore__ inline KernelAdvanceStep() {}
    __aicore__ inline void Init(int32_t tasks_per_core,
                                int32_t num_queries,
                                __gm__ int64_t* input_tokens_ptr,
                                __gm__ int64_t* sampled_token_ids_ptr,
                                __gm__ int64_t* input_positions_ptr,
                                __gm__ int32_t* seq_lens_ptr,
                                __gm__ int32_t* slot_mapping_ptr)
    {
        this->tasks_per_core = tasks_per_core;

        this->start_id = this->tasks_per_core * AscendC::GetBlockIdx();
        this->end_id = this->tasks_per_core * (AscendC::GetBlockIdx() + 1) - 1;

        // actual number of tasks on this core
        this->actual_task_per_core = tasks_per_core;
        if (this->end_id >= num_queries) {
            this->actual_task_per_core = num_queries - this->start_id;
            this->end_id = num_queries - 1;
        }

        int32_t offset_this_core = this->tasks_per_core * AscendC::GetBlockIdx();

        // init outQues
        pipe.InitBuffer(outQueInputTokens, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t));
        pipe.InitBuffer(outQueInputPos, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t));
        pipe.InitBuffer(outQueSeqLen, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t));
        pipe.InitBuffer(outQueSlotMapping, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t));

        // init inQues
        pipe.InitBuffer(inQueSeqLen, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t));
        pipe.InitBuffer(inQueSampledTokenIds, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t));

        // init GlobalMemory
        inputTokensGm.SetGlobalBuffer((__gm__ int64_t *)input_tokens_ptr + offset_this_core, this->actual_task_per_core);
        sampledTokenIdsGm.SetGlobalBuffer((__gm__ int64_t *)sampled_token_ids_ptr + offset_this_core, this->actual_task_per_core);
        inputPositionsGm.SetGlobalBuffer((__gm__ int64_t *)input_positions_ptr + offset_this_core, this->actual_task_per_core);
        seqLensGm.SetGlobalBuffer((__gm__ int32_t *)seq_lens_ptr + offset_this_core, this->actual_task_per_core);
        slotMappingGm.SetGlobalBuffer((__gm__ int32_t *)slot_mapping_ptr + offset_this_core, this->actual_task_per_core);
    }
    __aicore__ inline void Process(int64_t block_size, __gm__ int32_t* block_tables_ptr, int64_t block_tables_stride)
    {
        // no need for tiling or pipeline parallelism within each core, as the amount of data processed is very small
        CopyIn();
        Update(block_size, block_tables_ptr, block_tables_stride);
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<int32_t> seqLenLocalIn = inQueSeqLen.AllocTensor<int32_t>();
        AscendC::LocalTensor<int64_t> sampledTokenIdsLocal = inQueSampledTokenIds.AllocTensor<int64_t>();

        AscendC::DataCopyExtParams copyParams32{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int32_t)), 0, 0, 0}; // blockLen in bytes (int32 is 4 bytes)
        AscendC::DataCopyExtParams copyParams64{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int64_t)), 0, 0, 0}; // blockLen in bytes (int64 is 8 bytes)

        // calculate how many elements need padding so that the total length
        // becomes a multiple of 32 bytes, which the DataCopy function requires.
        uint8_t remainNum32 = this->actual_task_per_core * sizeof(int32_t) % 32;
        uint8_t needPadElements32 = remainNum32 == 0 ? remainNum32 : (32 - remainNum32) / sizeof(int32_t);

        AscendC::DataCopyPadExtParams<int32_t> padParams32{true, 0, needPadElements32, 0};

        // same padding calculation for the 64-bit tensors
        uint8_t remainNum64 = this->actual_task_per_core * sizeof(int64_t) % 32;
        uint8_t needPadElements64 = remainNum64 == 0 ? remainNum64 : (32 - remainNum64) / sizeof(int64_t);
        AscendC::DataCopyPadExtParams<int64_t> padParams64{true, 0, needPadElements64, 0};

        AscendC::DataCopyPad(seqLenLocalIn, seqLensGm, copyParams32, padParams32);
        AscendC::DataCopyPad(sampledTokenIdsLocal, sampledTokenIdsGm, copyParams64, padParams64);

        inQueSeqLen.EnQue(seqLenLocalIn);
        inQueSampledTokenIds.EnQue(sampledTokenIdsLocal);
    }
    __aicore__ inline void Update(int64_t block_size, __gm__ int32_t* block_tables_ptr, int64_t block_tables_stride)
    {
        // input
        AscendC::LocalTensor<int32_t> seqLenLocalIn = inQueSeqLen.DeQue<int32_t>();
        AscendC::LocalTensor<int64_t> sampledTokenIdsLocal = inQueSampledTokenIds.DeQue<int64_t>();

        // output
        AscendC::LocalTensor<int64_t> inputTokensLocal = outQueInputTokens.AllocTensor<int64_t>();
        AscendC::LocalTensor<int64_t> inputPosLocal = outQueInputPos.AllocTensor<int64_t>();
        AscendC::LocalTensor<int32_t> seqLenLocalOut = outQueSeqLen.AllocTensor<int32_t>();
        AscendC::LocalTensor<int32_t> slotMappingLocal = outQueSlotMapping.AllocTensor<int32_t>();

        auto unary_params = AscendC::UnaryRepeatParams(1, 1, 8, 8);

        // Use "for" instead of AscendC::Adds because AscendC::Adds does not work
        // when the source local memory has a different datatype from the destination.
        for (int i = 0; i < this->actual_task_per_core; i++) {
            inputTokensLocal.SetValue(i, sampledTokenIdsLocal.GetValue(i));
            inputPosLocal.SetValue(i, seqLenLocalIn.GetValue(i));
        }

        AscendC::Adds<int32_t, false>(seqLenLocalOut, seqLenLocalIn, 1, (uint64_t)0, 1, unary_params);

        // Gather block_tables with dim=1 at block_index. No AscendC gather function
        // is available here, so use a "for" loop instead.
        for (int cur_query_id = this->start_id, i = 0; i < this->actual_task_per_core; cur_query_id++, i++) {
            __gm__ int32_t const* seq_block_tables_ptr = block_tables_ptr + block_tables_stride * cur_query_id;

            int block_index = inputPosLocal.GetValue(i) / block_size;
            int block_offset = inputPosLocal.GetValue(i) % block_size;

            int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
            // Update slot_mapping
            slotMappingLocal.SetValue(i, slot_num);
        }

        outQueInputTokens.EnQue(inputTokensLocal);
        outQueInputPos.EnQue(inputPosLocal);
        outQueSeqLen.EnQue(seqLenLocalOut);
        outQueSlotMapping.EnQue(slotMappingLocal);

        inQueSampledTokenIds.FreeTensor(sampledTokenIdsLocal);
        inQueSeqLen.FreeTensor(seqLenLocalIn);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::DataCopyExtParams copyParams32{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int32_t)), 0, 0, 0};
        AscendC::DataCopyExtParams copyParams64{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int64_t)), 0, 0, 0};

        AscendC::LocalTensor<int64_t> inputTokensLocal = outQueInputTokens.DeQue<int64_t>();
        AscendC::DataCopyPad(inputTokensGm, inputTokensLocal, copyParams64);
        outQueInputTokens.FreeTensor(inputTokensLocal);

        AscendC::LocalTensor<int64_t> inputPosLocal = outQueInputPos.DeQue<int64_t>();
        AscendC::DataCopyPad(inputPositionsGm, inputPosLocal, copyParams64);
        outQueInputPos.FreeTensor(inputPosLocal);

        AscendC::LocalTensor<int32_t> seqLenLocalOut = outQueSeqLen.DeQue<int32_t>();
        AscendC::DataCopyPad(seqLensGm, seqLenLocalOut, copyParams32);
        outQueSeqLen.FreeTensor(seqLenLocalOut);

        AscendC::LocalTensor<int32_t> slotMappingLocal = outQueSlotMapping.DeQue<int32_t>();
        AscendC::DataCopyPad(slotMappingGm, slotMappingLocal, copyParams32);
        outQueSlotMapping.FreeTensor(slotMappingLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueInputTokens, outQueInputPos,
                                                            outQueSeqLen, outQueSlotMapping;
    AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueSeqLen,
                                                           inQueSampledTokenIds,
                                                           inQueBlockTables;

    AscendC::GlobalTensor<int64_t> inputTokensGm, sampledTokenIdsGm, inputPositionsGm;

    AscendC::GlobalTensor<int32_t> seqLensGm, slotMappingGm, blockTablesGm;

    int32_t tasks_per_core, start_id, end_id, actual_task_per_core;
};

extern "C" __global__ __aicore__ void AdvanceStepFlashAttnKernel(
    int64_t num_seqs,
    int64_t num_queries,
    int64_t block_size,
    __gm__ int64_t* input_tokens_ptr,
    __gm__ int64_t* sampled_token_ids_ptr,
    __gm__ int64_t* input_positions_ptr,
    __gm__ int32_t* seq_lens_ptr,
    __gm__ int32_t* slot_mapping_ptr,
    __gm__ int32_t* block_tables_ptr,
    int64_t block_tables_stride,
    int32_t tasks_per_core
)
{
    int start_id = tasks_per_core * AscendC::GetBlockIdx();
    // no task for this core.
    if (start_id >= num_queries) {
        return;
    }
    KernelAdvanceStep advanceStep;
    advanceStep.Init(tasks_per_core, num_queries, input_tokens_ptr, sampled_token_ids_ptr, input_positions_ptr, seq_lens_ptr, slot_mapping_ptr);
    advanceStep.Process(block_size, block_tables_ptr, block_tables_stride);
}

namespace vllm_ascend
{

extern void launch_advance_step_flashattn(
    void* stream,
    int64_t num_seqs,
    int64_t num_queries,
    int64_t block_size,
    int64_t* input_tokens_ptr,
    int64_t* sampled_token_ids_ptr,
    int64_t* input_positions_ptr,
    int32_t* seq_lens_ptr,
    int32_t* slot_mapping_ptr,
    int32_t* block_tables_ptr,
    int64_t block_tables_stride)
{
    int32_t num_cores = 20;

    if (num_cores > num_queries) {
        num_cores = num_queries;
    }

    // number of tasks processed by each core
    int32_t tasks_per_core = (num_queries + num_cores - 1) / num_cores;

    AdvanceStepFlashAttnKernel<<<num_cores, nullptr, stream>>>(
        num_seqs,
        num_queries,
        block_size,
        input_tokens_ptr,
        sampled_token_ids_ptr,
        input_positions_ptr,
        seq_lens_ptr,
        slot_mapping_ptr,
        block_tables_ptr,
        block_tables_stride,
        tasks_per_core);
}

}
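
For reference, the per-query update this kernel performs is small enough to restate in a few lines of Python. The sketch below is not part of the PR; it is an illustrative restatement of the Update() logic above, and the function and argument names are hypothetical (arrays are assumed to be NumPy arrays with the shapes the binding below checks):

import numpy as np

def advance_step_reference(input_tokens, sampled_token_ids, input_positions,
                           seq_lens, slot_mapping, block_tables, block_size):
    # Per query i: the token sampled at this step becomes the next input
    # token, the new token's position is the old sequence length, and its
    # KV-cache slot is looked up through the per-query block table.
    num_queries = sampled_token_ids.shape[0]
    for i in range(num_queries):
        pos = int(seq_lens[i])                       # position of the new token
        input_tokens[i] = sampled_token_ids[i, 0]    # feed the sampled token back in
        input_positions[i] = pos
        seq_lens[i] = pos + 1                        # sequence grew by one token
        block_index, block_offset = divmod(pos, block_size)
        slot_mapping[i] = block_tables[i, block_index] * block_size + block_offset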

csrc/ops.h

Lines changed: 12 additions & 0 deletions
@@ -46,4 +46,16 @@ namespace vllm_ascend {
     auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
     return new_tensor;
 }
+extern void launch_advance_step_flashattn(
+    void* stream,
+    int64_t num_seqs,
+    int64_t num_queries,
+    int64_t block_size,
+    int64_t* input_tokens_ptr,
+    int64_t* sampled_token_ids_ptr,
+    int64_t* input_positions_ptr,
+    int32_t* seq_lens_ptr,
+    int32_t* slot_mapping_ptr,
+    int32_t* block_tables_ptr,
+    int64_t block_tables_stride);
 }

csrc/torch_binding.cpp

Lines changed: 86 additions & 0 deletions
@@ -98,6 +98,87 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     cmd.Run();
     return {query_dst, key_dst};
 }
+
+void verify_tensor(std::string const& name, at::Tensor const& t,
+                   int64_t const size_0, int64_t const size_1,
+                   c10::ScalarType const type) {
+    bool size_0_cond = true;
+    if (size_0 != -1) {
+        size_0_cond = t.size(0) == size_0;
+    }
+
+    bool size_1_cond = true;
+    if (size_1 != -1) {
+        size_1_cond = t.size(1) == size_1;
+    }
+
+    bool is_contiguous = t.is_contiguous();
+    bool same_type = t.dtype() == type;
+
+    bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
+    if (!pass) {
+        TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
+                    " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
+                    " is not as expected: shape = [", size_0, ", ", size_1,
+                    "], type = ", type);
+    }
+}
+
+void advance_step_flashattn_ascendc(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    at::Tensor& input_tokens,
+    at::Tensor& sampled_token_ids,
+    at::Tensor& input_positions,
+    at::Tensor& seq_lens,
+    at::Tensor& slot_mapping,
+    at::Tensor& block_tables
+){
+    // Verify all tensors
+    verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+    verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, at::kLong);
+    verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+    verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+    verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kInt);
+    verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+    int64_t* input_tokens_ptr = input_tokens.data_ptr<int64_t>();
+    int64_t* sampled_token_ids_ptr = sampled_token_ids.data_ptr<int64_t>();
+    int64_t* input_positions_ptr = input_positions.data_ptr<int64_t>();
+    int32_t* seq_lens_ptr = seq_lens.data_ptr<int32_t>();
+    int32_t* slot_mapping_ptr = slot_mapping.data_ptr<int32_t>();
+    int32_t* block_tables_ptr = block_tables.data_ptr<int32_t>();
+
+    int32_t device_id;
+    aclrtGetDevice(&device_id);
+    auto npu_stream = c10_npu::getCurrentNPUStream(device_id);
+    aclrtStream stream = npu_stream.stream();
+
+    // aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+    at_npu::native::OpCommand cmd;
+    cmd.Name("advance_step_flashattn_ascendc");
+    cmd.SetCustomHandler([stream, num_seqs, num_queries,
+                          block_size, input_tokens_ptr, sampled_token_ids_ptr,
+                          input_positions_ptr, seq_lens_ptr, slot_mapping_ptr,
+                          block_tables_ptr, block_tables]() -> int {
+        launch_advance_step_flashattn(stream,
+                                      num_seqs,
+                                      num_queries,
+                                      block_size,
+                                      input_tokens_ptr,
+                                      sampled_token_ids_ptr,
+                                      input_positions_ptr,
+                                      seq_lens_ptr,
+                                      slot_mapping_ptr,
+                                      block_tables_ptr,
+                                      block_tables.stride(0));
+        return 0;
+    });
+    cmd.Run();
+    return;
+}
 } // namespace vllm_ascend

 TORCH_LIBRARY_EXPAND(_C, ops)
@@ -113,6 +194,11 @@ TORCH_LIBRARY_EXPAND(_C, ops)
         " Tensor! key, int head_size,"
         " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
     ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+    ops.def(
+        "advance_step_flashattn_ascendc(int num_seqs, int num_queries, int block_size,"
+        " Tensor! input_tokens, Tensor! sampled_token_ids, Tensor! input_positions,"
+        " Tensor! seq_lens, Tensor! slot_mapping, Tensor! block_tables) -> ()");
+    ops.impl("advance_step_flashattn_ascendc", torch::kPrivateUse1, &vllm_ascend::advance_step_flashattn_ascendc);
 }

 REGISTER_EXTENSION(_C)
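
Once the extension is loaded, the op is reachable through the PyTorch op registry. A minimal invocation sketch, assuming the library registers under torch.ops._C (as TORCH_LIBRARY_EXPAND(_C, ops) suggests) and an NPU device is available; the shapes and dtypes follow the verify_tensor checks above, and all concrete values are illustrative:

import torch
import torch_npu  # noqa: F401 -- assumed: importing the NPU backend / vllm_ascend loads the _C extension

num_seqs = num_queries = 4
block_size, max_blocks = 16, 8
dev = "npu"

input_tokens      = torch.zeros(num_seqs, dtype=torch.int64, device=dev)
sampled_token_ids = torch.randint(0, 1000, (num_queries, 1), dtype=torch.int64, device=dev)
input_positions   = torch.zeros(num_seqs, dtype=torch.int64, device=dev)
seq_lens          = torch.randint(1, 100, (num_seqs,), dtype=torch.int32, device=dev)
slot_mapping      = torch.zeros(num_seqs, dtype=torch.int32, device=dev)
block_tables      = torch.randint(0, 64, (num_seqs, max_blocks), dtype=torch.int32, device=dev)

# updates input_tokens, input_positions, seq_lens and slot_mapping in place
torch.ops._C.advance_step_flashattn_ascendc(
    num_seqs, num_queries, block_size,
    input_tokens, sampled_token_ids, input_positions,
    seq_lens, slot_mapping, block_tables)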
