
Commit 5418176

[Misc] Add Ray Prometheus logger to V1 (#17925)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
1 parent 67da572 commit 5418176

4 files changed: +223 −35 lines changed


tests/v1/metrics/test_ray_metrics.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import ray

from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch):
    """
    The change relies on V1 APIs, so set VLLM_USE_V1=1.
    """
    monkeypatch.setenv('VLLM_USE_V1', '1')


MODELS = [
    "distilbert/distilgpt2",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Simple smoke test, verifying this can be used without exceptions.
    Need to start a Ray cluster in order to verify outputs."""

    @ray.remote(num_gpus=1)
    class EngineTestActor:

        async def run(self):
            engine_args = AsyncEngineArgs(
                model=model,
                dtype=dtype,
                disable_log_stats=False,
            )

            engine = AsyncLLM.from_engine_args(
                engine_args, stat_loggers=[RayPrometheusStatLogger])

            for i, prompt in enumerate(example_prompts):
                engine.generate(
                    request_id=f"request-id-{i}",
                    prompt=prompt,
                    sampling_params=SamplingParams(max_tokens=max_tokens),
                )

    # Create the actor and call the async method
    actor = EngineTestActor.remote()  # type: ignore[attr-defined]
    ray.get(actor.run.remote())
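
Not part of the commit: a minimal driver sketch showing the same wiring outside pytest. It assumes a local Ray cluster with one GPU, and that AsyncLLM.generate() is an async generator of outputs, as elsewhere in the vLLM V1 async API; the actor name and prompt are illustrative.

# Usage sketch only (assumptions noted above); mirrors EngineTestActor.
import ray

from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger


@ray.remote(num_gpus=1)
class MetricsDemoActor:  # hypothetical name

    async def run(self, prompt: str) -> None:
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(model="distilbert/distilgpt2",
                            disable_log_stats=False),
            stat_loggers=[RayPrometheusStatLogger])

        # Assumption: generate() yields outputs asynchronously; draining the
        # stream is what drives the Ray-backed stat logger.
        async for _ in engine.generate(
                request_id="demo-0",
                prompt=prompt,
                sampling_params=SamplingParams(max_tokens=16)):
            pass


if __name__ == "__main__":
    ray.init()
    actor = MetricsDemoActor.remote()
    ray.get(actor.run.remote("Hello, my name is"))
    ray.shutdown()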

vllm/v1/metrics/loggers.py

Lines changed: 29 additions & 25 deletions
@@ -138,6 +138,10 @@ def log_engine_initialized(self):
 
 
 class PrometheusStatLogger(StatLoggerBase):
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram
+    _spec_decoding_cls = SpecDecodingProm
 
     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self._unregister_vllm_metrics()

@@ -156,37 +160,37 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
 
         max_model_len = vllm_config.model_config.max_model_len
 
-        self.spec_decoding_prom = SpecDecodingProm(
+        self.spec_decoding_prom = self._spec_decoding_cls(
             vllm_config.speculative_config, labelnames, labelvalues)
 
         #
         # Scheduler state
         #
-        self.gauge_scheduler_running = prometheus_client.Gauge(
+        self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests in model execution batches.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.gauge_scheduler_waiting = prometheus_client.Gauge(
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         #
         # GPU cache
         #
-        self.gauge_gpu_cache_usage = prometheus_client.Gauge(
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
+        self.counter_gpu_prefix_cache_queries = self._counter_cls(
             name="vllm:gpu_prefix_cache_queries",
             documentation=
             "GPU prefix cache queries, in terms of number of queried tokens.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
+        self.counter_gpu_prefix_cache_hits = self._counter_cls(
             name="vllm:gpu_prefix_cache_hits",
             documentation=
             "GPU prefix cache hits, in terms of number of cached tokens.",

@@ -195,24 +199,24 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         #
         # Counters
         #
-        self.counter_num_preempted_reqs = prometheus_client.Counter(
+        self.counter_num_preempted_reqs = self._counter_cls(
             name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_prompt_tokens = prometheus_client.Counter(
+        self.counter_prompt_tokens = self._counter_cls(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_generation_tokens = prometheus_client.Counter(
+        self.counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens_total",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_request_success: dict[FinishReason,
                                            prometheus_client.Counter] = {}
-        counter_request_success_base = prometheus_client.Counter(
+        counter_request_success_base = self._counter_cls(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + ["finished_reason"])

@@ -225,21 +229,21 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         # Histograms of counts
         #
         self.histogram_num_prompt_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_prompt_tokens",
                 documentation="Number of prefill tokens processed.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_num_generation_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_generation_tokens",
                 documentation="Number of generation tokens processed.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_iteration_tokens = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:iteration_tokens_total",
                 documentation="Histogram of number of tokens per engine_step.",
                 buckets=[

@@ -249,22 +253,22 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_max_num_generation_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_max_num_generation_tokens",
                 documentation=
                 "Histogram of maximum number of requested generation tokens.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_n_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_params_n",
                 documentation="Histogram of the n request parameter.",
                 buckets=[1, 2, 5, 10, 20],
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_max_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_params_max_tokens",
                 documentation="Histogram of the max_tokens request parameter.",
                 buckets=build_1_2_5_buckets(max_model_len),

@@ -274,7 +278,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         # Histogram of timing intervals
         #
         self.histogram_time_to_first_token = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:time_to_first_token_seconds",
                 documentation="Histogram of time to first token in seconds.",
                 buckets=[

@@ -285,7 +289,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_time_per_output_token = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:time_per_output_token_seconds",
                 documentation="Histogram of time per output token in seconds.",
                 buckets=[

@@ -299,34 +303,34 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
             40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
         ]
         self.histogram_e2e_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:e2e_request_latency_seconds",
                 documentation="Histogram of e2e request latency in seconds.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_queue_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_queue_time_seconds",
                 documentation=
                 "Histogram of time spent in WAITING phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_inference_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_inference_time_seconds",
                 documentation=
                 "Histogram of time spent in RUNNING phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_prefill_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_prefill_time_seconds",
                 documentation=
                 "Histogram of time spent in PREFILL phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_decode_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_decode_time_seconds",
                 documentation=
                 "Histogram of time spent in DECODE phase for request.",

@@ -343,7 +347,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
             self.labelname_running_lora_adapters = "running_lora_adapters"
             self.max_lora = vllm_config.lora_config.max_loras
             self.gauge_lora_info = \
-                prometheus_client.Gauge(
+                self._gauge_cls(
                     name="vllm:lora_requests_info",
                     documentation="Running stats on lora requests.",
                     labelnames=[

@@ -365,7 +369,7 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
         # Info type metrics are syntactic sugar for a gauge permanently set to 1
         # Since prometheus multiprocessing mode does not support Info, emulate
         # info here with a gauge.
-        info_gauge = prometheus_client.Gauge(
+        info_gauge = self._gauge_cls(
             name=name,
             documentation=documentation,
             labelnames=metrics_info.keys()).labels(**metrics_info)
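
The point of this diff is that PrometheusStatLogger now creates every metric through the overridable class attributes _gauge_cls, _counter_cls, _histogram_cls, and _spec_decoding_cls instead of naming prometheus_client directly, so a subclass can swap the metrics backend without copying __init__. A toy sketch of that pattern follows; the class names below are hypothetical, not vLLM code.

# Toy illustration of the class-attribute indirection introduced above.
class PrometheusGauge:

    def __init__(self, name: str):
        self.name = name

    def set(self, value: float) -> None:
        print(f"[prometheus] {self.name} = {value}")


class PrintGauge(PrometheusGauge):

    def set(self, value: float) -> None:
        print(f"[ray-style backend] {self.name} = {value}")


class BaseLogger:
    # Subclasses override this attribute to pick a different metric backend.
    _gauge_cls = PrometheusGauge

    def __init__(self):
        # The constructor never names a concrete backend class.
        self.gauge_running = self._gauge_cls("num_requests_running")


class RayStyleLogger(BaseLogger):
    _gauge_cls = PrintGauge


BaseLogger().gauge_running.set(3)      # [prometheus] num_requests_running = 3
RayStyleLogger().gauge_running.set(3)  # [ray-style backend] num_requests_running = 3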

vllm/v1/metrics/ray_wrappers.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Optional, Union

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm

try:
    from ray.util import metrics as ray_metrics
    from ray.util.metrics import Metric
except ImportError:
    ray_metrics = None


class RayPrometheusMetric:

    def __init__(self):
        if ray_metrics is None:
            raise ImportError(
                "RayPrometheusMetric requires Ray to be installed.")

        self.metric: Metric = None

    def labels(self, *labels, **labelskwargs):
        if labelskwargs:
            for k, v in labelskwargs.items():
                if not isinstance(v, str):
                    labelskwargs[k] = str(v)

            self.metric.set_default_tags(labelskwargs)

        return self


class RayGaugeWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Gauge to provide same API as
    prometheus_client.Gauge"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Gauge(name=name,
                                        description=documentation,
                                        tag_keys=labelnames_tuple)

    def set(self, value: Union[int, float]):
        return self.metric.set(value)

    def set_to_current_time(self):
        # ray metrics doesn't have set_to_current_time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
        return self.metric.set(time.time())


class RayCounterWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Counter to provide same API as
    prometheus_client.Counter"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Counter(name=name,
                                          description=documentation,
                                          tag_keys=labelnames_tuple)

    def inc(self, value: Union[int, float] = 1.0):
        if value == 0:
            return
        return self.metric.inc(value)


class RayHistogramWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Histogram to provide same API as
    prometheus_client.Histogram"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None,
                 buckets: Optional[list[float]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        boundaries = buckets if buckets else []
        self.metric = ray_metrics.Histogram(name=name,
                                            description=documentation,
                                            tag_keys=labelnames_tuple,
                                            boundaries=boundaries)

    def observe(self, value: Union[int, float]):
        return self.metric.observe(value)


class RaySpecDecodingProm(SpecDecodingProm):
    """
    RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
    Provides the same metrics as SpecDecodingProm but uses Ray's
    util.metrics library.
    """

    _counter_cls = RayCounterWrapper


class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""

    _gauge_cls = RayGaugeWrapper
    _counter_cls = RayCounterWrapper
    _histogram_cls = RayHistogramWrapper
    _spec_decoding_cls = RaySpecDecodingProm

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)

    @staticmethod
    def _unregister_vllm_metrics():
        # No-op on purpose
        pass
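
Because these wrappers mirror the prometheus_client call surface (labels(...) followed by set(), inc(), or observe()), the stat logger call sites above work unchanged; labels() is translated into Ray's set_default_tags(). Below is a small usage sketch, not part of the commit, assuming Ray is installed and ray.init() has been called so ray.util.metrics can export the values; the metric names are made up for illustration.

# Usage sketch for the wrappers above (assumes a running Ray context).
import ray

from vllm.v1.metrics.ray_wrappers import RayGaugeWrapper, RayHistogramWrapper

ray.init()

gauge = RayGaugeWrapper(name="demo_num_requests_running",
                        documentation="Requests currently running (demo).",
                        labelnames=["model_name", "engine"])
# Same call shape as prometheus_client: labels(...) then set(...).
gauge.labels(model_name="distilbert/distilgpt2", engine=0).set(2)

hist = RayHistogramWrapper(name="demo_request_latency_seconds",
                           documentation="End-to-end latency (demo).",
                           labelnames=["model_name", "engine"],
                           buckets=[0.1, 0.5, 1.0, 5.0])
hist.labels(model_name="distilbert/distilgpt2", engine=0).observe(0.42)

ray.shutdown()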
