
Commit 507ae62

feat: support compile torchair graph while warming up (#839)
### What this PR does / why we need it?

feat: support compile torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
1 parent d9fb027 commit 507ae62

7 files changed: +242 -234 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 4 deletions
@@ -108,8 +108,7 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
           VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-          # AscendScheduler doesn't work, fix it later
-          # pytest -sv tests/singlecard/tets_schedule.py
+          pytest -sv tests/singlecard/test_scheduler.py
           # guided decoding doesn't work, fix it later
           # pytest -sv tests/singlecard/test_guided_decoding.py.py
           pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
@@ -124,8 +123,7 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
          VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-          # AscendScheduler doesn't work, fix it later
-          # pytest -sv tests/singlecard/tets_schedule.py
+          pytest -sv tests/singlecard/test_scheduler.py
           # guided decoding doesn't work, fix it later
           # pytest -sv tests/singlecard/test_guided_decoding.py.py
           pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py

tests/singlecard/test_scheduler.py

Lines changed: 31 additions & 35 deletions
@@ -31,6 +31,7 @@
 from vllm.v1.structured_output import StructuredOutputManager
 
 from vllm_ascend.core.scheduler import AscendScheduler
+from vllm_ascend.utils import vllm_version_is
 
 EOS_TOKEN_ID = 50256
 
@@ -83,11 +84,10 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-    )
+    vllm_config = VllmConfig(scheduler_config=scheduler_config,
+                             model_config=model_config,
+                             cache_config=cache_config)
+
     kv_cache_config = KVCacheConfig(
         num_blocks=10000,  # A large number of blocks to hold all requests
         tensors={},
@@ -98,10 +98,7 @@ def create_scheduler(
     )
     cache_config.num_gpu_blocks = 10000
     return AscendScheduler(
-        scheduler_config,
-        model_config,
-        cache_config,
-        lora_config=None,
+        vllm_config,
         kv_cache_config=kv_cache_config,
         log_stats=True,
         structured_output_manager=StructuredOutputManager(vllm_config),
@@ -126,17 +123,27 @@ def create_requests(num_requests: int,
         else:
             mm_position = None
             mm_inputs = None
-        request = Request(
-            request_id=f"{i}",
-            prompt=None,
-            prompt_token_ids=[i] * num_tokens,
-            sampling_params=sampling_params,
-            multi_modal_inputs=mm_inputs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
-            eos_token_id=EOS_TOKEN_ID,
-            arrival_time=0,
-        )
+        if vllm_version_is("0.9.0"):
+            request = Request(
+                request_id=f"{i}",
+                prompt_token_ids=[i] * num_tokens,
+                sampling_params=sampling_params,
+                multi_modal_inputs=mm_inputs,
+                multi_modal_placeholders=mm_position,
+                multi_modal_hashes=None,
+                arrival_time=0,
+                eos_token_id=EOS_TOKEN_ID,
+            )
+        else:
+            request = Request(
+                request_id=f"{i}",
+                prompt_token_ids=[i] * num_tokens,
+                sampling_params=sampling_params,
+                multi_modal_inputs=mm_inputs,
+                multi_modal_placeholders=mm_position,
+                multi_modal_hashes=None,
+                eos_token_id=EOS_TOKEN_ID,
+            )
         requests.append(request)
     return requests
 
@@ -225,12 +232,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 1,
             requests[1].request_id: 2
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [],
-            requests[1].request_id: [10]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
@@ -275,12 +279,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 3,
             requests[1].request_id: 2
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=5,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 42],
-            requests[1].request_id: [13]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
@@ -323,12 +324,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 3,
             requests[1].request_id: 1
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=4,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 11],
-            requests[1].request_id: []
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
@@ -369,11 +367,9 @@ def test_stop_via_update_from_output():
         scheduled_new_reqs=[],
         scheduled_cached_reqs=[],
         num_scheduled_tokens={requests[0].request_id: 3},
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [EOS_TOKEN_ID, 10]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
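
Note: the two Request branches above differ only in whether arrival_time=0 is passed, an argument that newer vLLM releases dropped, so the test gates on the installed version. As a rough illustration, a helper like vllm_version_is could be written along the following lines; the real implementation in vllm_ascend.utils is not part of this diff and may differ.

# Illustrative sketch only: the actual vllm_version_is lives in
# vllm_ascend.utils and is not shown in this commit.
from vllm import __version__ as _vllm_version


def vllm_version_is(target: str) -> bool:
    # True when the installed vLLM release matches the given version string
    # (e.g. "0.9.0"), letting tests branch on constructor signature changes.
    return _vllm_version == target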

vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 2 deletions
@@ -241,7 +241,44 @@ def _get_graph_runner_block_tables(
                            max_blocks] = block_tables[:num_seqs, :
                                                       max_blocks]
 
-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def build_dummy(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )
 
     def build(self,
               num_reqs: int,
@@ -324,7 +361,7 @@ def build(self,
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)
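
Note: the new build_dummy method produces decode-only metadata with a padded block table, unit sequence lengths, and PAD_SLOT_ID slot mappings, which is what allows the torchair graph to be compiled during warm-up before any real request exists. A hypothetical sketch of how a model runner could drive it follows; the actual call site is not in this diff, and the names runner, attn_metadata_builder, and _dummy_run are assumptions.

# Hypothetical warm-up loop; the real caller lives elsewhere in vllm-ascend
# and runner.attn_metadata_builder / runner._dummy_run are assumed names.
def warm_up_torchair_graph(runner, graph_batch_sizes):
    for num_reqs in graph_batch_sizes:
        # Decode-only dummy metadata with fixed shapes, so the compiled
        # graph captures the padded batch size rather than a live request.
        attn_metadata = runner.attn_metadata_builder.build_dummy(
            num_reqs=num_reqs, num_actual_tokens=num_reqs)
        # One dummy forward pass compiles/warms the graph for this size.
        runner._dummy_run(num_reqs, attn_metadata=attn_metadata)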
