Skip to content

Commit 2e3e3c8

Browse files
Export NaNs in logits to scheduler_stats if output is corrupted (#18777)
Signed-off-by: Vlad Mihailescu <vtmihailescu@gmail.com>
1 parent 7e8977f commit 2e3e3c8

File tree

7 files changed

+104
-2
lines changed

7 files changed

+104
-2
lines changed

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import random
55

66
import pytest
7+
import torch
78

89
from vllm.attention import Attention
910
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
277278
assert _is_req_state_block_table_match(model_runner, req_id)
278279

279280

281+
def test_get_nans_in_logits(model_runner):
    """Check the per-request NaN counts returned by `_get_nans_in_logits`.

    Two requests are scheduled, then the method is called with logits that
    have no NaNs, NaNs in both rows, NaNs in one row only, `None` in place
    of logits, fewer rows than requests, and more rows than requests.
    """
    nan = float('nan')
    req_ids = ("req_0", "req_1")

    model_runner._update_states(_schedule_new_request(*req_ids))

    def count_nans(rows):
        # Build the logits tensor on the test device and run the method.
        return model_runner._get_nans_in_logits(
            torch.tensor(rows, device=DEVICE))

    # No NaNs anywhere.
    clean_rows = [
        [1.0, 2.0, 3.0],
        [3.0, 2.0, 1.0],
    ]
    assert count_nans(clean_rows) == {"req_0": 0, "req_1": 0}

    # NaNs in both rows.
    both_corrupted = [
        [1.0, nan, 3.0],
        [4.0, nan, nan],
    ]
    assert count_nans(both_corrupted) == {"req_0": 1, "req_1": 2}

    # NaNs only in the second row.
    second_corrupted = [
        [1.0, 2.0, 3.0],
        [4.0, nan, nan],
    ]
    assert count_nans(second_corrupted) == {"req_0": 0, "req_1": 2}

    # No logits at all: every request reports zero.
    assert model_runner._get_nans_in_logits(logits=None) == {
        "req_0": 0,
        "req_1": 0,
    }

    # Fewer rows than requests: the request without a row reports zero.
    single_row = [
        [1.0, nan, 3.0],
    ]
    assert count_nans(single_row) == {'req_0': 1, 'req_1': 0}

    # More rows than requests: the extra row is ignored.
    extra_row = [
        [nan, nan, 2.0],
        [1.0, 2.0, 3.0],
        [nan, 2.0, 3.0],
    ]
    assert count_nans(extra_row) == {'req_0': 2, 'req_1': 0}
327+
328+
280329
def test_update_states_no_changes(model_runner):
281330
req_id = "req_0"
282331

vllm/envs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
VLLM_SLEEP_WHEN_IDLE: bool = False
131131
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
132132
VLLM_KV_CACHE_LAYOUT: Optional[str] = None
133+
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
133134

134135

135136
def get_default_cache_root():
@@ -897,7 +898,13 @@ def get_vllm_port() -> Optional[int]:
897898
# leave the layout choice to the backend. Mind that backends may only
898899
# implement and support a subset of all possible layouts.
899900
"VLLM_KV_CACHE_LAYOUT":
900-
lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
901+
lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None),
902+
903+
# Enable checking whether the generated logits contain NaNs,
904+
# indicating corrupted output. Useful for debugging low level bugs
905+
# or bad hardware but it may add compute overhead.
906+
"VLLM_COMPUTE_NANS_IN_LOGITS":
907+
lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
901908
}
902909

903910
# --8<-- [end:env-vars-definition]

vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,7 @@ def update_from_output(
717717
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
718718
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
719719
pooler_outputs = model_runner_output.pooler_output
720+
num_nans_in_logits = model_runner_output.num_nans_in_logits
720721

721722
new_running: list[Request] = []
722723
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
@@ -810,6 +811,10 @@ def update_from_output(
810811
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
811812
req_id, new_token_ids)
812813

814+
# num_nans_in_logits comes from the model runner output
815+
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
816+
request.num_nans_in_logits = num_nans_in_logits[req_id]
817+
813818
# Add newly generated spec token ids to the request.
814819
if spec_token_ids is not None:
815820
if self.structured_output_manager.should_advance(request):
@@ -972,6 +977,8 @@ def make_stats(
972977
kv_cache_usage=self.kv_cache_manager.usage,
973978
prefix_cache_stats=prefix_cache_stats,
974979
spec_decoding_stats=spec_decoding_stats,
980+
num_corrupted_reqs=sum(req.is_output_corrupted
981+
for req in self.running),
975982
)
976983

977984
def make_spec_decoding_stats(

vllm/v1/metrics/stats.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class SchedulerStats:
4040

4141
spec_decoding_stats: Optional[SpecDecodingStats] = None
4242

43+
num_corrupted_reqs: int = 0
44+
4345

4446
@dataclass
4547
class LoRAStats:

vllm/v1/outputs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class ModelRunnerOutput:
108108
finished_sending: Optional[set[str]] = None
109109
finished_recving: Optional[set[str]] = None
110110

111+
# req_id -> num_nans_in_logits
112+
num_nans_in_logits: Optional[dict[str, int]] = None
113+
111114

112115
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
113116
req_id_to_index={},
@@ -117,4 +120,5 @@ class ModelRunnerOutput:
117120
prompt_logprobs_dict={},
118121
pooler_output=[],
119122
finished_sending=None,
120-
finished_recving=None)
123+
finished_recving=None,
124+
num_nans_in_logits=None)

vllm/v1/request.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def __init__(
9797
# The number of tokens with prefix cache hits.
9898
self.num_cached_tokens = -1
9999

100+
# The number of NaNs in logits. A value greater than 0
101+
# indicates that the output is corrupted
102+
self.num_nans_in_logits = 0
103+
100104
@classmethod
101105
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
102106
if request.mm_inputs is not None:
@@ -132,6 +136,10 @@ def append_output_token_ids(
132136
self._output_token_ids.extend(token_ids)
133137
self._all_token_ids.extend(token_ids)
134138

139+
@property
def is_output_corrupted(self) -> bool:
    """Return True when at least one NaN was recorded in the logits.

    The check is simply `num_nans_in_logits > 0`.
    """
    nan_count = self.num_nans_in_logits
    return nan_count > 0
142+
135143
@property
136144
def num_tokens(self) -> int:
137145
return len(self._all_token_ids)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,10 @@ def execute_model(
14311431
)
14321432
sampler_output.sampled_token_ids = output_token_ids
14331433

1434+
num_nans_in_logits = {}
1435+
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
1436+
num_nans_in_logits = self._get_nans_in_logits(logits)
1437+
14341438
# TODO(woosuk): The following loop can be slow since it iterates over
14351439
# the requests one by one. Optimize.
14361440
discard_sampled_tokens_req_indices = []
@@ -1601,6 +1605,7 @@ def execute_model(
16011605
pooler_output=[],
16021606
finished_sending=finished_sending,
16031607
finished_recving=finished_recving,
1608+
num_nans_in_logits=num_nans_in_logits,
16041609
)
16051610

16061611
def kv_connector_no_forward(
@@ -1826,6 +1831,26 @@ def _get_prompt_logprobs_dict(
18261831

18271832
return prompt_logprobs_dict
18281833

1834+
def _get_nans_in_logits(
    self,
    logits: Optional[torch.Tensor],
) -> dict[str, int]:
    """Count the NaN values in each request's row of `logits`.

    Args:
        logits: Tensor whose last dimension is reduced to a per-row NaN
            count, or None when no logits are available.

    Returns:
        A dict mapping every request id in `self.input_batch.req_ids` to
        its NaN count. All counts are 0 when `logits` is None, and a
        request whose batch index has no matching row in `logits` is
        reported as 0. An empty dict is returned if indexing raises
        `IndexError` (for example when `logits` is not at least 2-D).
    """
    if logits is None:
        return {req_id: 0 for req_id in self.input_batch.req_ids}

    try:
        # One reduction and one device-to-host copy for the whole batch.
        nan_counts = logits.isnan().sum(dim=-1).cpu().numpy()
        num_rows = logits.shape[0]
        req_id_to_index = self.input_batch.req_id_to_index

        # NOTE(review): assumes row i of `logits` belongs to the request
        # at batch index i -- TODO confirm this holds on every code path.
        num_nans_in_logits = {}
        for req_id in self.input_batch.req_ids:
            req_index = req_id_to_index[req_id]
            num_nans_in_logits[req_id] = (int(nan_counts[req_index])
                                          if req_index < num_rows else 0)
        return num_nans_in_logits
    except IndexError:
        # Best-effort diagnostic: report nothing rather than fail the step.
        return {}
1853+
18291854
@contextmanager
18301855
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
18311856
"""

0 commit comments

Comments
 (0)