
Commit 5f20f3c

feat: support compile torchair graph while warming up
Signed-off-by: boying <897013703@qq.com>
1 parent 3442fbd commit 5f20f3c

File tree

7 files changed: +220 -225 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 4 deletions
@@ -108,8 +108,7 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            # AscendScheduler doesn't work, fix it later
-            # pytest -sv tests/singlecard/tets_schedule.py
+            pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
@@ -124,8 +123,7 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            # AscendScheduler doesn't work, fix it later
-            # pytest -sv tests/singlecard/tets_schedule.py
+            pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py

tests/singlecard/test_scheduler.py

Lines changed: 9 additions & 26 deletions
@@ -83,11 +83,10 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-    )
+    vllm_config = VllmConfig(scheduler_config=scheduler_config,
+                             model_config=model_config,
+                             cache_config=cache_config)
+
     kv_cache_config = KVCacheConfig(
         num_blocks=10000,  # A large number of blocks to hold all requests
         tensors={},
@@ -98,10 +97,7 @@ def create_scheduler(
     )
     cache_config.num_gpu_blocks = 10000
     return AscendScheduler(
-        scheduler_config,
-        model_config,
-        cache_config,
-        lora_config=None,
+        vllm_config,
         kv_cache_config=kv_cache_config,
         log_stats=True,
         structured_output_manager=StructuredOutputManager(vllm_config),
@@ -128,14 +124,12 @@ def create_requests(num_requests: int,
         mm_inputs = None
         request = Request(
             request_id=f"{i}",
-            prompt=None,
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
             multi_modal_inputs=mm_inputs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
-            arrival_time=0,
         )
         requests.append(request)
     return requests
@@ -225,12 +219,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 1,
             requests[1].request_id: 2
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [],
-            requests[1].request_id: [10]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
@@ -275,12 +266,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 3,
             requests[1].request_id: 2
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=5,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 42],
-            requests[1].request_id: [13]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
@@ -323,12 +311,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 3,
             requests[1].request_id: 1
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=4,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 11],
-            requests[1].request_id: []
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
@@ -369,11 +354,9 @@ def test_stop_via_update_from_output():
         scheduled_new_reqs=[],
         scheduled_cached_reqs=[],
         num_scheduled_tokens={requests[0].request_id: 3},
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [EOS_TOKEN_ID, 10]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
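
Note: the fixture now bundles the individual configs into a single VllmConfig and passes that object to AscendScheduler, matching the updated upstream v1 scheduler constructor. A minimal usage sketch, assuming the fixtures above and the upstream add_request()/schedule() scheduler API; the test body below is illustrative and not part of this commit:

def test_schedule_smoke():
    # Build the scheduler through the updated fixture (configs wrapped in VllmConfig).
    scheduler = create_scheduler()
    # Queue two dummy requests; num_tokens is assumed to be a fixture parameter.
    for request in create_requests(num_requests=2, num_tokens=4):
        scheduler.add_request(request)
    # One scheduling step should pick up the queued prompt tokens.
    output = scheduler.schedule()
    assert output.total_num_scheduled_tokens > 0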

vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 2 deletions
@@ -241,7 +241,44 @@ def _get_graph_runner_block_tables(
                               max_blocks] = block_tables[:num_seqs, :
                                                          max_blocks]
 
-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def build_dummy(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )
 
     def build(self,
               num_reqs: int,
@@ -324,7 +361,7 @@ def build(self,
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)
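
build_dummy() constructs decode-only attention metadata for a padded fake batch (zeroed block tables, PAD_SLOT_ID slot mappings, sequence length 1), and _get_graph_runner_block_tables() now returns a slice sized to its input so build() can pass num_seqs + graph_pad_size when padding for graph capture. The warm-up call site lives in one of the commit's files not shown in this excerpt; the sketch below only illustrates how such dummy metadata could drive torchair graph compilation during warm-up. warm_up_torchair_graphs and execute_dummy_batch are hypothetical names, not part of this commit:

def warm_up_torchair_graphs(runner, metadata_builder, batch_sizes):
    # Hypothetical driver: trace the decode graph for each padded batch size
    # before serving real traffic.
    for num_reqs in batch_sizes:
        # Decode-only metadata for num_reqs fake requests, one token each.
        attn_metadata = metadata_builder.build_dummy(
            num_reqs=num_reqs, num_actual_tokens=num_reqs)
        # Hypothetical model-runner hook that runs a single forward pass with
        # the dummy metadata so the compiled graph gets captured and cached.
        runner.execute_dummy_batch(num_reqs, attn_metadata)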
