feat: support compile torchair graph while warming up #839

Merged · 1 commit · merged May 30, 2025
Changes from all commits
6 changes: 2 additions & 4 deletions .github/workflows/vllm_ascend_test.yaml
@@ -108,8 +108,7 @@ jobs:
        run: |
          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
            VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            # AscendScheduler doesn't work, fix it later
-            # pytest -sv tests/singlecard/tets_schedule.py
+            pytest -sv tests/singlecard/test_scheduler.py
            # guided decoding doesn't work, fix it later
            # pytest -sv tests/singlecard/test_guided_decoding.py.py
            pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
@@ -124,8 +123,7 @@ jobs:
        run: |
          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
            VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            # AscendScheduler doesn't work, fix it later
-            # pytest -sv tests/singlecard/tets_schedule.py
+            pytest -sv tests/singlecard/test_scheduler.py
            # guided decoding doesn't work, fix it later
            # pytest -sv tests/singlecard/test_guided_decoding.py.py
            pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
66 changes: 31 additions & 35 deletions tests/singlecard/test_scheduler.py
@@ -31,6 +31,7 @@
from vllm.v1.structured_output import StructuredOutputManager

from vllm_ascend.core.scheduler import AscendScheduler
+from vllm_ascend.utils import vllm_version_is

EOS_TOKEN_ID = 50256

@@ -83,11 +84,10 @@ def create_scheduler(
        cache_dtype="auto",
        **kwargs_cache,
    )
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-    )
+    vllm_config = VllmConfig(scheduler_config=scheduler_config,
+                             model_config=model_config,
+                             cache_config=cache_config)
+
    kv_cache_config = KVCacheConfig(
        num_blocks=10000,  # A large number of blocks to hold all requests
        tensors={},
@@ -98,10 +98,7 @@
    )
    cache_config.num_gpu_blocks = 10000
    return AscendScheduler(
-        scheduler_config,
-        model_config,
-        cache_config,
-        lora_config=None,
+        vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
@@ -126,17 +123,27 @@ def create_requests(num_requests: int,
        else:
            mm_position = None
            mm_inputs = None
-        request = Request(
-            request_id=f"{i}",
-            prompt=None,
-            prompt_token_ids=[i] * num_tokens,
-            sampling_params=sampling_params,
-            multi_modal_inputs=mm_inputs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
-            eos_token_id=EOS_TOKEN_ID,
-            arrival_time=0,
-        )
+        if vllm_version_is("0.9.0"):
+            request = Request(
+                request_id=f"{i}",
+                prompt_token_ids=[i] * num_tokens,
+                sampling_params=sampling_params,
+                multi_modal_inputs=mm_inputs,
+                multi_modal_placeholders=mm_position,
+                multi_modal_hashes=None,
+                arrival_time=0,
+                eos_token_id=EOS_TOKEN_ID,
+            )
+        else:
+            request = Request(
+                request_id=f"{i}",
+                prompt_token_ids=[i] * num_tokens,
+                sampling_params=sampling_params,
+                multi_modal_inputs=mm_inputs,
+                multi_modal_placeholders=mm_position,
+                multi_modal_hashes=None,
+                eos_token_id=EOS_TOKEN_ID,
+            )
        requests.append(request)
    return requests
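The newly imported vllm_version_is helper gates the Request keyword order on the installed vLLM release; its implementation is not part of this diff. A minimal sketch, assuming it is a plain equality check against the installed vllm distribution version, could look like:

# Sketch only: vllm_ascend.utils.vllm_version_is is not shown in this diff,
# so this assumes a simple equality check on the installed vLLM version.
from importlib.metadata import version

def vllm_version_is(target: str) -> bool:
    # True when the installed vLLM release matches the target version string.
    return version("vllm") == target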

@@ -225,12 +232,9 @@ def test_stop_via_update_from_output():
            requests[0].request_id: 1,
            requests[1].request_id: 2
        },
+        scheduled_spec_decode_tokens={},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [],
-            requests[1].request_id: [10]
-        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
@@ -275,12 +279,9 @@ def test_stop_via_update_from_output():
            requests[0].request_id: 3,
            requests[1].request_id: 2
        },
+        scheduled_spec_decode_tokens={},
        total_num_scheduled_tokens=5,
        scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 42],
-            requests[1].request_id: [13]
-        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
@@ -323,12 +324,9 @@ def test_stop_via_update_from_output():
            requests[0].request_id: 3,
            requests[1].request_id: 1
        },
+        scheduled_spec_decode_tokens={},
        total_num_scheduled_tokens=4,
        scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 11],
-            requests[1].request_id: []
-        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
@@ -369,11 +367,9 @@ def test_stop_via_update_from_output():
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={requests[0].request_id: 3},
+        scheduled_spec_decode_tokens={},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [EOS_TOKEN_ID, 10]
-        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
41 changes: 39 additions & 2 deletions vllm_ascend/attention/mla_v1.py
@@ -241,7 +241,44 @@ def _get_graph_runner_block_tables(
                               max_blocks] = block_tables[:num_seqs, :
                                                          max_blocks]

-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def build_dummy(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )

    def build(self,
              num_reqs: int,
@@ -324,7 +361,7 @@ def build(self,
            block_table = torch.cat([block_table, block_table_padding],
                                    dim=0)
            block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
            padding_0 = torch.zeros(graph_pad_size,
                                    dtype=input_positions.dtype,
                                    device=input_positions.device)
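The new build_dummy method fills AscendMLAMetadata with padding values (zeroed block table, PAD_SLOT_ID slot mapping, sequence lengths of 1), so decode-only metadata can be produced without any live request. The runner-side call that uses it while warming up the torchair graph is not shown in this diff; a minimal sketch, assuming the runner exposes the metadata builder, its device, and the model, could look like:

# Sketch only: the warm-up call site is not part of this diff, so the runner
# attribute names used here are assumptions for illustration.
import torch

def warm_up_torchair_graph(runner, padded_num_reqs: int) -> None:
    # Build placeholder decode-only metadata shaped for the padded batch size.
    attn_metadata = runner.attn_metadata_builder.build_dummy(
        num_reqs=padded_num_reqs, num_actual_tokens=padded_num_reqs)
    # Run one dummy decode step so torchair can trace and compile the graph
    # for this batch size before any real request arrives.
    input_ids = torch.zeros(padded_num_reqs, dtype=torch.int32,
                            device=runner.device)
    positions = torch.zeros(padded_num_reqs, dtype=torch.int64,
                            device=runner.device)
    with torch.no_grad():
        runner.model(input_ids=input_ids,
                     positions=positions,
                     attn_metadata=attn_metadata)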