
Commit 75c05ea

remove syncs and move back necessary sync
Signed-off-by: Leo Tian <leo.tian@centml.ai>
1 parent ca2f6b9
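
The functional change in this commit is in vllm/v1/spec_decode/eagle.py: two .max().item() reductions are hoisted out of the draft-proposal hot path and their results are passed in as plain ints. Calling .item() on a CUDA tensor copies a scalar to the host and blocks the CPU until every queued kernel has finished. A minimal sketch of that stall, independent of vLLM (the matrix size and loop count are arbitrary; requires a CUDA device):

import time
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    for _ in range(50):
        # Kernel launches are asynchronous; the CPU races ahead here.
        x = x @ x.softmax(dim=-1)
    t0 = time.perf_counter()
    # .item() must wait for every queued kernel before it can copy the
    # scalar back, so the host stalls for the full GPU backlog.
    _ = x.max().item()
    print(f".item() blocked the host for {time.perf_counter() - t0:.3f}s")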

File tree

3 files changed: +194, -79 lines


requirements/test.txt

20 additions, 2 deletions
@@ -27,6 +27,10 @@ argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0
     # via isoduration
+async-timeout==5.0.1
+    # via
+    #   aiohttp
+    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -129,6 +133,11 @@ eval-type-backport==0.2.2
     # via mteb
 evaluate==0.4.3
     # via lm-eval
+exceptiongroup==1.3.0
+    # via
+    #   anyio
+    #   hypothesis
+    #   pytest
 fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
@@ -641,7 +650,6 @@ setuptools==77.0.3
     # via
     #   mamba-ssm
     #   pytablewriter
-    #   torch
     #   triton
 shellingham==1.5.4
     # via typer
@@ -701,8 +709,13 @@ tokenizers==0.21.1
     # via
     #   -r requirements/test.in
     #   transformers
+toml==0.10.2
+    # via datamodel-code-generator
 tomli==2.2.1
-    # via schemathesis
+    # via
+    #   black
+    #   pytest
+    #   schemathesis
 tomli-w==1.2.0
     # via schemathesis
 torch==2.7.0+cu128
@@ -776,13 +789,18 @@ types-python-dateutil==2.9.0.20241206
     # via arrow
 typing-extensions==4.12.2
     # via
+    #   anyio
+    #   black
+    #   exceptiongroup
     #   huggingface-hub
     #   librosa
     #   mistral-common
     #   mteb
+    #   multidict
     #   pqdm
     #   pydantic
     #   pydantic-core
+    #   rich
     #   torch
     #   typer
     #   typing-inspection

vllm/v1/spec_decode/eagle.py

4 additions, 9 deletions
@@ -90,6 +90,8 @@ def propose(
         cu_num_tokens: torch.Tensor,
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
+        max_seq_len: int,
+        max_num_tokens: int,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
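
The two new int parameters move those maxima to the caller. A hypothetical caller-side sketch (the variable names, and the assumption that the runner already tracks per-request lengths on the CPU, are illustrative and not part of this diff):

import torch

# Assumed host-side metadata the runner would already track.
seq_lens_cpu = torch.tensor([17, 33, 25])
num_tokens_per_req_cpu = torch.tensor([4, 8, 6])

# Plain host-side reductions: no CUDA tensor involved, so no sync.
max_seq_len = int(seq_lens_cpu.max())
max_num_tokens = int(num_tokens_per_req_cpu.max())

# proposer.propose(..., max_seq_len=max_seq_len,
#                  max_num_tokens=max_num_tokens, ...)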
@@ -113,10 +115,6 @@ def propose(
         seq_lens = (target_positions[last_token_indices] + 1).int()
 
         if self.method in ["eagle", "eagle3"]:
-            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-            max_seq_len = seq_lens.max().item()
-            max_num_tokens = (cu_num_tokens[1:] -
-                              cu_num_tokens[:-1]).max().item()
             attn_metadata = FlashAttentionMetadata(
                 num_actual_tokens=num_tokens,
                 max_query_len=max_num_tokens,
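
The deleted lines are the two ops the FIXME flagged: each .item() call forced a device-to-host sync on every propose() step. If the passed-in value needs verifying against what these lines used to compute, a debug-only check keeps the sync out of production runs; this is a hypothetical helper, not part of this commit:

import torch

def check_max_num_tokens(cu_num_tokens: torch.Tensor,
                         max_num_tokens: int) -> None:
    # This .max().item() is exactly the synchronizing reduction the
    # commit removes from the hot path; run it only under __debug__.
    if __debug__:
        computed = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
        assert computed == max_num_tokens, (computed, max_num_tokens)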
@@ -133,9 +131,6 @@ def propose(
                 suffix_kv_lens=None,
             )
         elif self.method == "deepseek_mtp":
-            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-            max_query_len = query_lens.max().item()
-
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=cu_num_tokens, seq_lens=seq_lens)
 
@@ -145,7 +140,7 @@ def propose(
             attn_metadata = self.runner.attn_metadata_builder.build(
                 num_reqs=batch_size,
                 num_actual_tokens=num_tokens,
-                max_query_len=max_query_len,
+                max_query_len=max_num_tokens,
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
@@ -298,7 +293,7 @@ def prepare_inputs(
     # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
     cu_num_tokens = torch.zeros_like(cu_target_query_lens)
     torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-    token_indices = torch.empty(
+    token_indices = torch.zeros(
         num_tokens,
         dtype=torch.int32,
         device=cu_target_query_lens.device,
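
torch.empty returns uninitialized memory, while torch.zeros guarantees every element is defined. Switching token_indices to zeros presumably makes any slot that a subsequent write does not touch read back as 0 instead of stale garbage, which matters once the surrounding syncs are gone. A minimal illustration of the difference, not the vLLM logic:

import torch

buf = torch.empty(8, dtype=torch.int32)   # arbitrary stale bytes
safe = torch.zeros(8, dtype=torch.int32)  # every slot defined as 0

# A partial write leaves buf[5:] as garbage but safe[5:] as zeros.
buf[:5] = torch.arange(5, dtype=torch.int32)
safe[:5] = torch.arange(5, dtype=torch.int32)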
