Commit 1de59d5

add docstrings
1 parent e627f0a commit 1de59d5

4 files changed: +43 additions, -25 deletions


requirements/test.txt

Lines changed: 3 additions & 21 deletions
@@ -31,10 +31,6 @@ argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0
     # via isoduration
-async-timeout==5.0.1
-    # via
-    #   aiohttp
-    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -145,11 +141,6 @@ eval-type-backport==0.2.2
     # via mteb
 evaluate==0.4.3
     # via lm-eval
-exceptiongroup==1.3.0
-    # via
-    #   anyio
-    #   hypothesis
-    #   pytest
 fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
@@ -699,6 +690,7 @@ setuptools==77.0.3
     # via
     #   mamba-ssm
     #   pytablewriter
+    #   torch
     #   triton
 shellingham==1.5.4
     # via typer
@@ -761,13 +753,8 @@ tokenizers==0.21.1
     # via
     #   -r requirements/test.in
     #   transformers
-toml==0.10.2
-    # via datamodel-code-generator
 tomli==2.2.1
-    # via
-    #   black
-    #   pytest
-    #   schemathesis
+    # via schemathesis
 tomli-w==1.2.0
     # via schemathesis
 torch==2.7.0+cu128
@@ -841,18 +828,13 @@ types-python-dateutil==2.9.0.20241206
     # via arrow
 typing-extensions==4.12.2
     # via
-    #   anyio
-    #   black
-    #   exceptiongroup
     #   huggingface-hub
     #   librosa
     #   mistral-common
     #   mteb
-    #   multidict
     #   pqdm
     #   pydantic
     #   pydantic-core
-    #   rich
     #   torch
     #   typer
     #   typing-inspection
@@ -892,4 +874,4 @@ yarl==1.17.1
     #   aiohttp
     #   schemathesis
 zstandard==0.23.0
-    # via lm-eval
+    # via lm-eval

vllm/v1/spec_decode/eagle.py

Lines changed: 28 additions & 2 deletions
@@ -77,6 +77,7 @@ def __init__(
                                          device=device,
                                          dtype=torch.int32)

+        # Used to store precomputed values from load_model() so they can be used in propose()
         self.last_token_indices = torch.zeros(self.max_num_tokens,
                                               dtype=torch.int32,
                                               device=device)
@@ -224,6 +225,18 @@ def advance_speculative_state(self, draft_token_ids: torch.Tensor,
                                   hidden_states: torch.Tensor,
                                   attn_metadata: FlashAttentionMetadata,
                                   batch_size: int):
+        """
+        Advances the speculative decoding state and metadata by one step
+
+        Parameters:
+        ----------
+        draft_token_ids (torch.Tensor): Token IDs generated by the draft model
+        positions (torch.Tensor): Position indices for the draft tokens
+        hidden_states (torch.Tensor): Corresponding hidden states for the tokens
+        attn_metadata (FlashAttentionMetadata): Metadata required for FlashAttention (e.g., sequence lengths, block table).
+        batch_size (int): Number of sequences to update.
+        """
+
         # Calculate number of thread blocks
         grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
         attn_metadata.slot_mapping = torch.empty_like(positions)
@@ -305,8 +318,21 @@ def load_inputs(self, target_token_ids: torch.Tensor,
                     target_hidden_states: torch.Tensor,
                     next_token_ids_gpu: torch.Tensor,
                     cu_num_tokens: torch.Tensor, num_scheduled_tokens: int):
-        # Loads token ids, positions, etc. into the eagle model
-        # Moved from EagleProposer.propose() to here
+        """
+        Loads token ids, positions, etc. into the eagle model
+
+        Logic moved from EagleProposer.propose() to here
+
+        Parameters:
+        ----------
+        target_token_ids (torch.Tensor): Draft-step token IDs
+        target_positions (torch.Tensor): Position indices for the tokens
+        target_hidden_states (torch.Tensor): Corresponding hidden states for the tokens
+        next_token_ids_gpu (torch.Tensor): Sampled next token IDs to overwrite final token
+        cu_num_tokens (torch.Tensor): Cumulative number of tokens from prepare_inputs()
+        num_scheduled_tokens (int): Total number of tokens scheduled
+        """
+
         self.last_token_indices = cu_num_tokens[1:] - 1

         # FA requires seq_len to have dtype int32.
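
For context on the "number of thread blocks" computed in advance_speculative_state: the grid lambda launches one Triton program per BLOCK_SIZE sequences, rounded up. A minimal, self-contained sketch of that launch pattern follows; the kernel here is a hypothetical stand-in, not the real advance_state_kernel.

import torch
import triton
import triton.language as tl

@triton.jit
def _toy_kernel(src_ptr, dst_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Each program handles BLOCK_SIZE consecutive sequences.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    vals = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)

def launch(draft_token_ids: torch.Tensor, out: torch.Tensor):
    batch_size = draft_token_ids.numel()
    # ceil(batch_size / BLOCK_SIZE) thread blocks, as in the diff above.
    grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
    _toy_kernel[grid](draft_token_ids, out, batch_size, BLOCK_SIZE=256)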

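The cu_num_tokens argument documented for load_inputs() is the cumulative per-request token count from prepare_inputs(), so cu_num_tokens[1:] - 1 yields the flattened index of each request's final token. A tiny worked example with made-up values:

import torch

# Hypothetical batch: three requests scheduled 3, 1 and 4 tokens respectively.
cu_num_tokens = torch.tensor([0, 3, 4, 8], dtype=torch.int32)

last_token_indices = cu_num_tokens[1:] - 1
print(last_token_indices)  # tensor([2, 3, 7], dtype=torch.int32)
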
vllm/v1/spec_decode/utils.py

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ def advance_state_kernel(
     BLOCK_SIZE: tl.constexpr,
     PADDING_SLOT_ID: tl.constexpr,
 ):
+    # Triton kernel to perform draft model state advancement.
+
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
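
The added comment is terse; as a rough illustration only (not the actual vLLM advance_state_kernel), a kernel with this signature shape could bump each sequence's position and route out-of-range sequences to PADDING_SLOT_ID, like so:

import triton
import triton.language as tl

@triton.jit
def toy_advance_state_kernel(
    positions_ptr,      # int32 current position per sequence
    seq_lens_ptr,       # int32 allowed length per sequence
    slot_out_ptr,       # int32 slot id for the next token
    n_seqs,
    BLOCK_SIZE: tl.constexpr,
    PADDING_SLOT_ID: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_seqs

    pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
    max_len = tl.load(seq_lens_ptr + offsets, mask=mask, other=0)

    new_pos = pos + 1
    # Sequences that have run out of room write to the padding slot instead of a real one.
    slot = tl.where(new_pos < max_len, new_pos, PADDING_SLOT_ID)

    tl.store(positions_ptr + offsets, new_pos, mask=mask)
    tl.store(slot_out_ptr + offsets, slot, mask=mask)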

vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 2 deletions
@@ -1697,8 +1697,16 @@ def execute_model(
     def get_valid_sampled_token_ids(
             self, max_gen_len: int, sampled_token_ids: torch.Tensor,
             discard_sampled_tokens_req_indices: np.ndarray) -> list[list[int]]:
-        # Returns valid sampled tokens in a list of lists based on
-        # max gen length and discard indices
+        """
+        Returns valid sampled tokens in a list of lists based on max gen length and discard indices
+
+        Parameters:
+        ----------
+        - max_gen_len: Maximum length of the generated tokens
+        - sampled_token_ids: Tensor of sampled token IDs
+        - discard_sampled_tokens_req_indices: Indices of requests that should not be sampled
+        """
+
         if max_gen_len == 1:
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
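
A rough sketch of the behavior the new docstring describes, not the exact gpu_model_runner implementation: trim each request's sampled tokens to the valid entries and drop tokens for discarded requests. The -1 sentinel below is an assumption made purely for illustration.

import numpy as np
import torch

def get_valid_sampled_token_ids_sketch(
        max_gen_len: int, sampled_token_ids: torch.Tensor,
        discard_sampled_tokens_req_indices: np.ndarray) -> list[list[int]]:
    if max_gen_len == 1:
        # No spec decode tokens: one sampled token per request.
        valid = sampled_token_ids.tolist()
    else:
        # With spec decode, rows are padded up to max_gen_len; here rejected or
        # padded positions are assumed to hold -1 and are stripped.
        valid = [[t for t in row if t != -1] for row in sampled_token_ids.tolist()]
    for idx in discard_sampled_tokens_req_indices:
        valid[int(idx)] = []  # requests whose samples should be discarded keep nothing
    return valid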
