Commit d7601c1
[Data] Poll memory usage per map task (#51324)
## Why are these changes needed?

### Context

Currently, we estimate `BlockExecStats.rss_bytes` by checking the process memory usage after the UDF yields an output. At that point, however, variables might've already gone out of scope and been garbage collected, so the estimate is often a large underestimate.

### Change

This PR makes each `_map_task` launch a thread that polls the process memory every 1 s, which should provide a more accurate memory estimate. You can disable this feature by executing the following statement:

```python
ray.data.DataContext.get_current().memory_poll_interval_s = None
```

### Notes

We considered several approaches to estimate the memory use per task:

* `ru_maxrss`: the high RSS watermark of a process. Inaccurate because Ray might reuse a worker process to execute two distinct UDFs, and RSS double-counts shared memory.
* `memory_full_info().uss` with polling: accurate, but slow. In our map release test, a call can take 0.1 s+ (for comparison, a PyArrow UDF on a 128 MiB block can take <0.01 s).
* `memory_info().rss` with polling: double-counts because it includes shared memory (like Ray objects).
* `memory_info().rss - memory_info().shared` with polling: an estimate of the USS, and the approach we went with. On the map and batch inference release tests, the % difference from the true USS is a few percent, except for the model prediction UDF (a 20% underestimate, likely due to counting shared memory from Torch). It takes ~0.0001 s per call.

Here's a chart of `true_uss - (rss - shared)` for the map batches release test; the x-axis is the value in MiB.

![Distribution Plot with Seaborn](https://github.com/user-attachments/assets/faca1a0c-1f41-4d29-9b4a-4f5bab12817c)

## Related issue number

## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing strategy
  - [ ] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
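To make the chosen trade-off concrete, here's a minimal standalone sketch of the `rss - shared` estimate described above (a sketch, assuming a Linux host, where psutil's `memory_info()` result includes a `shared` field; the `estimate_uss` helper is illustrative, not part of this PR):

```python
import os

import psutil


def estimate_uss(process: psutil.Process) -> int:
    """Estimate the unique set size (USS) as RSS minus shared memory."""
    memory_info = process.memory_info()
    # RSS counts every physical page mapped into the process, including pages
    # shared with other processes (e.g., Ray's object store). Subtracting the
    # shared pages approximates the memory that would be freed if the process
    # exited right now.
    return memory_info.rss - memory_info.shared


print(estimate_uss(psutil.Process(os.getpid())))
```

On macOS and Windows, `memory_info()` has no `shared` field; the accurate-but-slow `memory_full_info().uss` is the portable alternative the PR description benchmarks against.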
1 parent ba72103 commit d7601c1

File tree

8 files changed: +158 additions, −57 deletions

python/ray/data/_internal/execution/interfaces/op_runtime_metrics.py

Lines changed: 11 additions & 11 deletions

```diff
@@ -392,7 +392,7 @@ def __init__(self, op: "PhysicalOperator"):
         self._per_node_metrics: Dict[str, NodeMetrics] = defaultdict(NodeMetrics)
         self._per_node_metrics_enabled: bool = op.data_context.enable_per_node_metrics
 
-        self._cum_rss_bytes: Optional[int] = None
+        self._cum_max_uss_bytes: Optional[int] = None
 
     @property
     def extra_metrics(self) -> Dict[str, Any]:
@@ -542,17 +542,17 @@ def average_bytes_outputs_per_task(self) -> Optional[float]:
         return self.bytes_outputs_of_finished_tasks / self.num_tasks_finished
 
     @metric_property(
-        description="Average RSS usage of tasks.",
+        description="Average USS usage of tasks.",
         metrics_group=MetricsGroup.TASKS,
         map_only=True,
     )
-    def average_memory_usage_per_task(self) -> Optional[float]:
-        """Average RSS usage of tasks."""
-        if self._cum_rss_bytes is None:
+    def average_max_uss_per_task(self) -> Optional[float]:
+        """Average max USS usage of tasks."""
+        if self._cum_max_uss_bytes is None:
             return None
         else:
             assert self.num_task_outputs_generated > 0, self.num_task_outputs_generated
-            return self._cum_rss_bytes / self.num_task_outputs_generated
+            return self._cum_max_uss_bytes / self.num_task_outputs_generated
 
     def on_input_received(self, input: RefBundle):
         """Callback when the operator receives a new input."""
@@ -639,13 +639,13 @@ def on_task_output_generated(self, task_index: int, output: RefBundle):
         assert meta.num_rows is not None
         self.rows_task_outputs_generated += meta.num_rows
         trace_allocation(block_ref, "operator_output")
-        if meta.exec_stats.rss_bytes is not None:
-            if self._cum_rss_bytes is None:
-                self._cum_rss_bytes = meta.exec_stats.rss_bytes
+        if meta.exec_stats.max_uss_bytes is not None:
+            if self._cum_max_uss_bytes is None:
+                self._cum_max_uss_bytes = meta.exec_stats.max_uss_bytes
             else:
-                self._cum_rss_bytes += meta.exec_stats.rss_bytes
+                self._cum_max_uss_bytes += meta.exec_stats.max_uss_bytes
         else:
-            assert not self._is_map, "Map operators should collect RSS metrics"
+            assert not self._is_map, "Map operators should collect memory metrics"
 
         # Update per node metrics
         if self._per_node_metrics_enabled:
```

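The metric above is a plain running mean: each generated output contributes its task's max USS to a cumulative sum, which is divided by the number of outputs. A toy sketch of that accumulation (the class and method names here are illustrative, not Ray's):

```python
from typing import Optional


class UssAverager:
    """Toy model of the accumulation in on_task_output_generated."""

    def __init__(self) -> None:
        self._cum_max_uss_bytes: Optional[int] = None
        self._num_outputs = 0

    def on_output(self, max_uss_bytes: int) -> None:
        # First output initializes the sum; later outputs accumulate.
        if self._cum_max_uss_bytes is None:
            self._cum_max_uss_bytes = max_uss_bytes
        else:
            self._cum_max_uss_bytes += max_uss_bytes
        self._num_outputs += 1

    @property
    def average(self) -> Optional[float]:
        if self._cum_max_uss_bytes is None:
            return None
        return self._cum_max_uss_bytes / self._num_outputs


averager = UssAverager()
assert averager.average is None
averager.on_output(1)
averager.on_output(3)
assert averager.average == 2  # (1 + 3) / 2
```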
python/ray/data/_internal/execution/operators/map_operator.py

Lines changed: 11 additions & 11 deletions

```diff
@@ -49,9 +49,9 @@
 from ray.data.block import (
     Block,
     BlockAccessor,
-    BlockExecStats,
     BlockMetadata,
     BlockStats,
+    _BlockExecStatsBuilder,
     to_stats,
 )
 from ray.data.context import DataContext
@@ -530,17 +530,17 @@ def _map_task(
     """
     DataContext._set_current(data_context)
    ctx.kwargs.update(kwargs)
-    stats = BlockExecStats.builder()
     map_transformer.set_target_max_block_size(ctx.target_max_block_size)
-    for b_out in map_transformer.apply_transform(iter(blocks), ctx):
-        # TODO(Clark): Add input file propagation from input blocks.
-        m_out = BlockAccessor.for_block(b_out).get_metadata()
-        m_out.exec_stats = stats.build()
-        m_out.exec_stats.udf_time_s = map_transformer.udf_time()
-        m_out.exec_stats.task_idx = ctx.task_idx
-        yield b_out
-        yield m_out
-        stats = BlockExecStats.builder()
+    with _BlockExecStatsBuilder(data_context.memory_poll_interval_s) as stats:
+        for b_out in map_transformer.apply_transform(iter(blocks), ctx):
+            # TODO(Clark): Add input file propagation from input blocks.
+            m_out = BlockAccessor.for_block(b_out).get_metadata()
+            m_out.exec_stats = stats.build()
+            m_out.exec_stats.udf_time_s = map_transformer.udf_time()
+            m_out.exec_stats.task_idx = ctx.task_idx
+            yield b_out
+            yield m_out
+            stats.reset()
 
 
 class _BlockRefBundler:
```

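The structural change in `_map_task`: previously a fresh `BlockExecStats.builder()` was constructed per output block, whereas now a single context-managed builder (and its poll thread) spans the whole task, with `build()`/`reset()` bracketing each block. A hedged miniature of that pattern, using a stand-in builder that only tracks wall time:

```python
import time


class _ToyStatsBuilder:
    """Stand-in for _BlockExecStatsBuilder that only tracks wall time."""

    def __enter__(self):
        self.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The real builder stops its USS poll thread here.
        pass

    def reset(self):
        self._start = time.perf_counter()

    def build(self) -> float:
        return time.perf_counter() - self._start


def map_task(items):
    # One builder spans the whole task; build()/reset() bracket each output.
    with _ToyStatsBuilder() as stats:
        for item in items:
            out = item * 2  # stand-in for the UDF
            yield out, stats.build()
            stats.reset()


for out, wall_time_s in map_task(range(3)):
    print(out, f"{wall_time_s:.6f}s")
```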
python/ray/data/_internal/stats.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,4 +1,5 @@
 import collections
+import enum
 import logging
 import threading
 import time
@@ -7,7 +8,6 @@
 from dataclasses import dataclass, fields
 from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
 from uuid import uuid4
-import enum
 
 import numpy as np
 
@@ -1301,7 +1301,7 @@ def from_block_metadata(
         }
 
         memory_stats_mb = [
-            round(e.rss_bytes / (1024 * 1024), 2) for e in exec_stats
+            round(e.max_uss_bytes / (1024 * 1024), 2) for e in exec_stats
         ]
         memory_stats = {
             "min": min(memory_stats_mb),
```

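For context, the surrounding code turns per-block memory into summary statistics in MiB. A small sketch of that aggregation under the new field name (the diff only shows the `min` key; the `max` and `mean` keys below are assumptions):

```python
max_uss_bytes_per_block = [2 * 1024 * 1024, 1 * 1024 * 1024]

memory_stats_mb = [round(b / (1024 * 1024), 2) for b in max_uss_bytes_per_block]
memory_stats = {
    "min": min(memory_stats_mb),
    "max": max(memory_stats_mb),  # assumption: key not visible in this hunk
    "mean": round(sum(memory_stats_mb) / len(memory_stats_mb), 2),  # assumption
}
print(memory_stats)  # {'min': 1.0, 'max': 2.0, 'mean': 1.5}
```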
python/ray/data/block.py

Lines changed: 85 additions & 13 deletions

```diff
@@ -1,6 +1,7 @@
 import collections
 import logging
 import os
+import threading
 import time
 from dataclasses import asdict, dataclass, fields
 from enum import Enum
@@ -134,7 +135,9 @@ def __init__(self):
         self.udf_time_s: Optional[float] = 0
         self.cpu_time_s: Optional[float] = None
         self.node_id = ray.runtime_context.get_runtime_context().get_node_id()
-        self.rss_bytes: int = 0
+        # An estimate of the maximum amount of physical memory that the process was
+        # using while computing this block.
+        self.max_uss_bytes: int = 0
         self.task_idx: Optional[int] = None
 
     @staticmethod
@@ -153,30 +156,99 @@ def __repr__(self):
 
 
 class _BlockExecStatsBuilder:
-    """Helper class for building block stats.
+    """Helper context manager for building block stats.
 
     When this class is created, we record the start time. When build() is
     called, the time delta is saved as part of the stats.
     """
 
-    def __init__(self):
-        self.start_time = time.perf_counter()
-        self.start_cpu = time.process_time()
+    def __init__(self, poll_interval_s: Optional[float] = None):
+        """
+
+        Args:
+            poll_interval_s: The interval to poll the USS of the process. If `None`,
+                this class won't poll the USS.
+        """
+        self._poll_interval_s = poll_interval_s
+
+        # Record start times.
+        self._start_time = time.perf_counter()
+        self._start_cpu = time.process_time()
+
+        # Record initial USS.
+        self._process = psutil.Process(os.getpid())
+        self._max_uss = self._estimate_uss()
+        self._max_uss_lock = threading.Lock()
+
+        self._uss_poll_thread = None
+        self._stop_uss_poll_event = None
+
+    def __enter__(self):
+        if self._poll_interval_s is not None:
+            (
+                self._uss_poll_thread,
+                self._stop_uss_poll_event,
+            ) = self._start_uss_poll_thread()
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._uss_poll_thread is not None:
+            self._stop_uss_poll_thread()
 
     def build(self) -> "BlockExecStats":
-        self.end_time = time.perf_counter()
-        self.end_cpu = time.process_time()
+        # Record end times.
+        end_time = time.perf_counter()
+        end_cpu = time.process_time()
+
+        # Record max USS.
+        with self._max_uss_lock:
+            self._max_uss = max(self._max_uss, self._estimate_uss())
 
+        # Build the stats.
         stats = BlockExecStats()
-        stats.start_time_s = self.start_time
-        stats.end_time_s = self.end_time
-        stats.wall_time_s = self.end_time - self.start_time
-        stats.cpu_time_s = self.end_cpu - self.start_cpu
-        process = psutil.Process(os.getpid())
-        stats.rss_bytes = int(process.memory_info().rss)
+        stats.start_time_s = self._start_time
+        stats.end_time_s = end_time
+        stats.wall_time_s = end_time - self._start_time
+        stats.cpu_time_s = end_cpu - self._start_cpu
+        stats.max_uss_bytes = self._max_uss
 
         return stats
 
+    def reset(self):
+        self._start_time = time.perf_counter()
+        self._start_cpu = time.process_time()
+        with self._max_uss_lock:
+            self._max_uss = self._estimate_uss()
+
+    def _start_uss_poll_thread(self) -> Tuple[threading.Thread, threading.Event]:
+        assert self._poll_interval_s is not None
+
+        stop_event = threading.Event()
+
+        def poll_uss():
+            while not stop_event.is_set():
+                with self._max_uss_lock:
+                    self._max_uss = max(self._max_uss, self._estimate_uss())
+                stop_event.wait(self._poll_interval_s)
+
+        thread = threading.Thread(target=poll_uss, daemon=True)
+        thread.start()
+
+        return thread, stop_event
+
+    def _stop_uss_poll_thread(self):
+        if self._stop_uss_poll_event is not None:
+            self._stop_uss_poll_event.set()
+            self._uss_poll_thread.join()
+
+    def _estimate_uss(self) -> int:
+        memory_info = self._process.memory_info()
+        # Estimate the USS (the amount of memory that'd be free if we killed the
+        # process right now) as the difference between the RSS (total physical memory)
+        # and amount of shared physical memory.
+        return memory_info.rss - memory_info.shared
 
 
 @DeveloperAPI
 @dataclass
```

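A quick way to poke at the builder directly (a sketch, not an official API: `_BlockExecStatsBuilder` is private and may change, and the `rss - shared` estimate assumes a Linux host):

```python
import time

from ray.data.block import _BlockExecStatsBuilder  # private; may change

with _BlockExecStatsBuilder(poll_interval_s=0.1) as builder:
    payload = [0] * 10_000_000  # hold ~80 MB so the poll thread can see it
    time.sleep(0.5)  # let the 0.1 s poller take a few samples
    stats = builder.build()

print(f"wall: {stats.wall_time_s:.3f} s")
print(f"max USS: {stats.max_uss_bytes / 1024 ** 2:.1f} MiB")
```

Entering the context starts the daemon poll thread, `build()` snapshots the stats (taking one final USS sample), and exiting stops the thread.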
python/ray/data/context.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -323,6 +323,8 @@ class DataContext:
             transient errors when reading from remote storage systems.
         enable_per_node_metrics: Enable per node metrics reporting for Ray Data,
             disabled by default.
+        memory_poll_interval_s: The interval to poll the USS of map tasks. If `None`,
+            map tasks won't record memory stats.
     """
 
     target_max_block_size: int = DEFAULT_TARGET_MAX_BLOCK_SIZE
@@ -395,6 +397,7 @@ class DataContext:
     )
     enable_per_node_metrics: bool = DEFAULT_ENABLE_PER_NODE_METRICS
     override_object_store_memory_limit_fraction: float = None
+    memory_poll_interval_s: Optional[float] = 1
 
     def __post_init__(self):
         # The additonal ray remote args that should be added to
```

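Adjusting the new knob from user code (the 0.25 value is illustrative; the default is 1 s, and `None` disables polling, as noted in the PR description):

```python
import ray

ctx = ray.data.DataContext.get_current()

# Poll more often than the 1 s default:
ctx.memory_poll_interval_s = 0.25

# Or disable per-task memory polling entirely:
ctx.memory_poll_interval_s = None
```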
python/ray/data/tests/conftest.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -418,7 +418,7 @@ def op_two_block():
     block_params = {
         "num_rows": [10000, 5000],
         "size_bytes": [100, 50],
-        "rss_bytes": [1024 * 1024 * 2, 1024 * 1024 * 1],
+        "uss_bytes": [1024 * 1024 * 2, 1024 * 1024 * 1],
         "wall_time": [5, 10],
         "cpu_time": [1.2, 3.4],
         "udf_time": [1.1, 1.7],
@@ -439,7 +439,7 @@ def op_two_block():
         block_exec_stats.cpu_time_s = block_params["cpu_time"][i]
         block_exec_stats.udf_time_s = block_params["udf_time"][i]
         block_exec_stats.node_id = block_params["node_id"][i]
-        block_exec_stats.rss_bytes = block_params["rss_bytes"][i]
+        block_exec_stats.max_uss_bytes = block_params["uss_bytes"][i]
         block_exec_stats.task_idx = block_params["task_idx"][i]
         block_meta_list.append(
             BlockMetadata(
```

python/ray/data/tests/test_op_runtime_metrics.py

Lines changed: 10 additions & 10 deletions

```diff
@@ -9,15 +9,15 @@
 from ray.data.block import BlockExecStats, BlockMetadata
 
 
-def test_average_memory_usage_per_task():
+def test_average_max_uss_per_task():
     # No tasks submitted yet.
     metrics = OpRuntimeMetrics(MagicMock())
-    assert metrics.average_memory_usage_per_task is None
+    assert metrics.average_max_uss_per_task is None
 
-    def create_bundle(rss_bytes: int):
+    def create_bundle(uss_bytes: int):
         block = ray.put(pa.Table.from_pydict({}))
         stats = BlockExecStats()
-        stats.rss_bytes = rss_bytes
+        stats.max_uss_bytes = uss_bytes
         stats.wall_time_s = 0
         metadata = BlockMetadata(
             num_rows=0,
@@ -29,20 +29,20 @@ def create_bundle(rss_bytes: int):
         return RefBundle([(block, metadata)], owns_blocks=False)
 
     # Submit two tasks.
-    bundle = create_bundle(rss_bytes=0)
+    bundle = create_bundle(uss_bytes=0)
     metrics.on_task_submitted(0, bundle)
     metrics.on_task_submitted(1, bundle)
-    assert metrics.average_memory_usage_per_task is None
+    assert metrics.average_max_uss_per_task is None
 
     # Generate one output for the first task.
-    bundle = create_bundle(rss_bytes=1)
+    bundle = create_bundle(uss_bytes=1)
     metrics.on_task_output_generated(0, bundle)
-    assert metrics.average_memory_usage_per_task == 1
+    assert metrics.average_max_uss_per_task == 1
 
     # Generate one output for the second task.
-    bundle = create_bundle(rss_bytes=3)
+    bundle = create_bundle(uss_bytes=3)
     metrics.on_task_output_generated(0, bundle)
-    assert metrics.average_memory_usage_per_task == 2  # (1 + 3) / 2 = 2
+    assert metrics.average_max_uss_per_task == 2  # (1 + 3) / 2 = 2
 
 
 if __name__ == "__main__":
```
