
Commit cf3796e

fast build

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

1 parent 2b21e0d
File tree: 10 files changed, +50 −25 lines


vllm/v1/attention/backends/cpu_attn.py
Lines changed: 4 additions & 2 deletions

@@ -119,8 +119,10 @@ def reorder_batch(self, input_batch: InputBatch,
 
         return True
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> TorchSDPAMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

vllm/v1/attention/backends/flash_attn.py
Lines changed: 10 additions & 7 deletions

@@ -196,11 +196,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
 
-    def build(
-        self,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-    ) -> FlashAttentionMetadata:
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> FlashAttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -212,13 +211,16 @@ def build(
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
+        # the overhead of the aot schedule is not worth it for spec-decode
+        aot_schedule = self.aot_schedule and not fast_build
+
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
             # For the AOT scheduler we need the sliding window value to be
             # constant for all layers. We have to populate this on the first
             # build() call, once the layers are constructed (it cannot be
             # populated in __init__).
-            if self.aot_schedule:
+            if aot_schedule:
                 sliding_window_configs = _get_sliding_window_configs(
                     self.vllm_config)
                 if len(sliding_window_configs) == 1:
@@ -227,10 +229,11 @@ def build(
                     self.aot_sliding_window = sliding_window_config
                 elif len(sliding_window_configs) > 1:
                     self.aot_schedule = False
+                    aot_schedule = False
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
-            if self.aot_schedule:
+            if aot_schedule:
                 return get_scheduler_metadata(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
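
The pattern above is worth noting: fast_build is folded into a local aot_schedule flag rather than mutating self.aot_schedule, so one fast build does not disable ahead-of-time scheduling for later regular builds. A minimal runnable sketch of that pattern follows; Plan, SketchBuilder, and _expensive_plan are illustrative stand-ins, not vLLM code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Plan:
    batch_size: int


class SketchBuilder:
    def __init__(self, aot_schedule_supported: bool):
        self.aot_schedule = aot_schedule_supported

    def _expensive_plan(self, batch_size: int) -> Plan:
        # Stand-in for an AOT scheduler plan: costly to compute up front,
        # but makes the attention kernel faster at execution time.
        return Plan(batch_size=batch_size)

    def build(self, batch_size: int, fast_build: bool = False) -> Optional[Plan]:
        # Local flag, as in the diff: skip the plan when the metadata is
        # short-lived, without permanently turning AOT scheduling off.
        aot_schedule = self.aot_schedule and not fast_build
        return self._expensive_plan(batch_size) if aot_schedule else None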

vllm/v1/attention/backends/flashinfer.py
Lines changed: 4 additions & 2 deletions

@@ -365,8 +365,10 @@ def _plan(self, num_prefills: int, num_decodes: int,
             kv_data_type=attn_metadata.data_type,
         )
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> FlashInferMetadata:
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata)

vllm/v1/attention/backends/flex_attention.py
Lines changed: 4 additions & 2 deletions

@@ -274,8 +274,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.kv_cache_spec = kv_cache_spec
         self.device = device
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> FlexAttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

vllm/v1/attention/backends/mamba_attn.py
Lines changed: 4 additions & 2 deletions

@@ -128,8 +128,10 @@ def reorder_batch(self, input_batch: "InputBatch",
 
         return modified_batch
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> Mamba2AttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens

vllm/v1/attention/backends/mla/common.py
Lines changed: 4 additions & 2 deletions

@@ -414,8 +414,10 @@ def build_for_cudagraph_capture(
 
         return self.build(0, m)
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata) -> M:
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> M:
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

vllm/v1/attention/backends/rocm_aiter_fa.py
Lines changed: 4 additions & 2 deletions

@@ -193,8 +193,10 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
 
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens

vllm/v1/attention/backends/triton_attn.py
Lines changed: 4 additions & 4 deletions

@@ -92,10 +92,10 @@ def build_for_cudagraph_capture(
         attn_metadata.seq_lens.fill_(1)
         return attn_metadata
 
-    def build(
-        self, common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> TritonAttentionMetadata:
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> TritonAttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

vllm/v1/attention/backends/utils.py
Lines changed: 11 additions & 2 deletions

@@ -75,11 +75,20 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.kv_cache_spec = kv_cache_spec
 
     @abstractmethod
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata) -> M:
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> M:
         """
         Central method that builds attention metadata.
         Some builders (MLA) require reorder_batch to be called prior to build.
+
+        Args:
+            common_prefix_len: The length of the common prefix of the batch.
+            common_attn_metadata: The common attention metadata.
+            fast_build: The metadata will prioritize speed of building over
+                speed at execution. Can be used for spec-decode, where the
+                result of a build call may only be used for a few layers/iters.
         """
         raise NotImplementedError
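
To make the new contract concrete, here is a minimal runnable sketch of a builder honoring the extended abstract signature; ToyCommonMetadata, ToyBuilder, and the "schedule" entry are assumed toy names, not vLLM code. The idea is the same as in the backends above: skip execution-time-optimizing precomputation when fast_build=True, since the result may only be consumed for a few layers or iterations.

from dataclasses import dataclass


@dataclass
class ToyCommonMetadata:
    num_reqs: int
    max_query_len: int


class ToyBuilder:
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: ToyCommonMetadata,
              fast_build: bool = False) -> dict:
        meta = {
            "common_prefix_len": common_prefix_len,
            "num_reqs": common_attn_metadata.num_reqs,
        }
        if not fast_build:
            # Expensive precomputation that only pays off when the metadata
            # is reused across many layers and forward passes.
            meta["schedule"] = list(range(common_attn_metadata.max_query_len))
        return meta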

vllm/v1/spec_decode/eagle.py
Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@ def propose(
         attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
+            fast_build=True,
        )
 
        # At this moment, we assume all eagle layers belong to the same KV
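
Because the parameter defaults to False, existing call sites keep the execution-optimized behavior unchanged; only the EAGLE drafter, which rebuilds metadata on every proposal step, opts into the fast path. A hedged sketch of that split (builder stands for any object implementing the build() interface above):

def build_for_drafting(builder, common_attn_metadata):
    # Speculative drafting: metadata is short-lived, favor build speed.
    return builder.build(common_prefix_len=0,
                         common_attn_metadata=common_attn_metadata,
                         fast_build=True)


def build_for_decode(builder, common_prefix_len, common_attn_metadata):
    # Main model: metadata is reused by every layer, favor execution speed.
    return builder.build(common_prefix_len=common_prefix_len,
                         common_attn_metadata=common_attn_metadata)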
