Skip to content

Commit f51448f

Browse files
author
Adam Lugowski
committed
[Bugfix]: Fix Prometheus spec decode counter sum-of-sums
The Prometheus spec decode counters (draft/accepted/emitted token counts) are incremented by the values in spec_decode_metrics. However, those values are aggregates since startup. Therefore, the Prometheus counters are effectively a sum-of-sums instead of just a sum. If a high-traffic vLLM instance is left running for a few hours, those counters start to report absurdly high values, such as a TPS in the tens of millions. Signed-off-by: Adam Lugowski <adam.lugowski@parasail.io>
1 parent 3eb08ed commit f51448f

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

vllm/engine/metrics.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
from copy import copy
34
import time
45
from typing import TYPE_CHECKING
56
from typing import Counter as CollectionsCounter
@@ -669,20 +670,33 @@ def log(self, stats: Stats):
669670
if local_interval_elapsed(stats.now, self.last_local_log,
670671
self.local_interval):
671672
if self.spec_decode_metrics is not None:
673+
# The counters in self.spec_decode_metrics are aggregates.
674+
# The Prometheus Counters must be incremented with deltas.
675+
# Keep track of the previously seen value so we can compute deltas.
676+
if self.last_spec_decode_metrics is None:
677+
self.last_spec_decode_metrics = copy(self.spec_decode_metrics)
678+
self.last_spec_decode_metrics.accepted_tokens = 0
679+
self.last_spec_decode_metrics.draft_tokens = 0
680+
self.last_spec_decode_metrics.emitted_tokens = 0
681+
682+
snapshot = copy(self.spec_decode_metrics)
683+
672684
self._log_gauge(
673685
self.metrics.gauge_spec_decode_draft_acceptance_rate,
674686
self.spec_decode_metrics.draft_acceptance_rate)
675687
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
676688
self.spec_decode_metrics.system_efficiency)
677689
self._log_counter(
678690
self.metrics.counter_spec_decode_num_accepted_tokens,
679-
self.spec_decode_metrics.accepted_tokens)
691+
snapshot.accepted_tokens - self.last_spec_decode_metrics.accepted_tokens)
680692
self._log_counter(
681693
self.metrics.counter_spec_decode_num_draft_tokens,
682-
self.spec_decode_metrics.draft_tokens)
694+
snapshot.draft_tokens - self.last_spec_decode_metrics.draft_tokens)
683695
self._log_counter(
684696
self.metrics.counter_spec_decode_num_emitted_tokens,
685-
self.spec_decode_metrics.emitted_tokens)
697+
snapshot.emitted_tokens - self.last_spec_decode_metrics.emitted_tokens)
698+
699+
self.last_spec_decode_metrics = snapshot
686700

687701
# Reset tracked stats for next interval.
688702
self.num_prompt_tokens = []

vllm/engine/metrics_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
8080
self.last_local_log = time.time()
8181
self.local_interval = local_interval
8282
self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
83+
self.last_spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
8384

8485
@abstractmethod
8586
def log(self, stats: Stats) -> None:

0 commit comments

Comments
 (0)