Skip to content

Commit 2bd0642

Browse files
AdrianOrensteinvmoensAdrian Orenstein
authored
[Feature] Compressed storage gpu (#3062)
Co-authored-by: vmoens <vincentmoens@gmail.com> Co-authored-by: Adrian Orenstein <adrianorenstein@gmail.com>
1 parent 0627e85 commit 2bd0642

File tree

14 files changed

+1468
-65
lines changed

14 files changed

+1468
-65
lines changed

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ dependencies:
3535
- transformers
3636
- ninja
3737
- timm
38+
- safetensors

.github/workflows/benchmarks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ jobs:
8080
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
8181
python3 -m pip install "pybind11[global]"
8282
python3.10 -m pip install git+https://github.com/pytorch/tensordict
83+
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
8384
python3.10 setup.py develop
8485
8586
# test import

.github/workflows/benchmarks_pr.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ jobs:
8282
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
8383
python3.10 -m pip install "pybind11[global]"
8484
python3.10 -m pip install git+https://github.com/pytorch/tensordict
85+
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
8586
python3.10 setup.py develop
8687
# python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH
8788

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
- libcst == 0.4.7
2121

2222
- repo: https://github.com/pycqa/flake8
23-
rev: 4.0.1
23+
rev: 6.0.0
2424
hooks:
2525
- id: flake8
2626
args: [--config=setup.cfg]

benchmarks/requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
pytest-benchmark
22
tenacity
3+
safetensors
4+
tqdm
5+
pandas
6+
numpy
7+
matplotlib
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import io
6+
import pickle
7+
8+
import pytest
9+
import torch
10+
try:
11+
from safetensors.torch import save
12+
except ImportError:
13+
save = None
14+
15+
from torchrl.data import CompressedListStorage
16+
17+
18+
class TestCompressedStorageBenchmark:
19+
"""Benchmark tests for CompressedListStorage."""
20+
21+
@staticmethod
22+
def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
23+
"""Easily compressible data for testing."""
24+
if device is None:
25+
device = torch.device("cpu")
26+
27+
return {
28+
"observations": torch.zeros(
29+
(num_experiences, 4, 84, 84),
30+
dtype=torch.uint8,
31+
device=device,
32+
),
33+
"actions": torch.zeros((num_experiences,), device=device),
34+
"rewards": torch.zeros((num_experiences,), device=device),
35+
"next_observations": torch.zeros(
36+
(num_experiences, 4, 84, 84),
37+
dtype=torch.uint8,
38+
device=device,
39+
),
40+
"terminations": torch.zeros(
41+
(num_experiences,), dtype=torch.bool, device=device
42+
),
43+
"truncations": torch.zeros(
44+
(num_experiences,), dtype=torch.bool, device=device
45+
),
46+
"batch_size": [num_experiences],
47+
}
48+
49+
@staticmethod
50+
def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
51+
"""Uncompressible data for testing."""
52+
if device is None:
53+
device = torch.device("cpu")
54+
return {
55+
"observations": torch.randn(
56+
(num_experiences, 4, 84, 84),
57+
dtype=torch.float32,
58+
device=device,
59+
),
60+
"actions": torch.randint(0, 10, (num_experiences,), device=device),
61+
"rewards": torch.randn(
62+
(num_experiences,), dtype=torch.float32, device=device
63+
),
64+
"next_observations": torch.randn(
65+
(num_experiences, 4, 84, 84),
66+
dtype=torch.float32,
67+
device=device,
68+
),
69+
"terminations": torch.rand((num_experiences,), device=device)
70+
< 0.2, # ~20% True
71+
"truncations": torch.rand((num_experiences,), device=device)
72+
< 0.1, # ~10% True
73+
"batch_size": [num_experiences],
74+
}
75+
76+
@pytest.mark.benchmark(
77+
group="tensor_serialization_speed",
78+
min_time=0.1,
79+
max_time=0.5,
80+
min_rounds=5,
81+
disable_gc=True,
82+
warmup=False,
83+
)
84+
@pytest.mark.parametrize(
85+
"serialization_method",
86+
["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
87+
)
88+
def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
89+
"""Benchmark the speed of different tensor serialization methods.
90+
91+
TODO: we might need to also test which methods work on the gpu.
92+
pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'
93+
94+
------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
95+
Name (time in us) Mean (smaller is better) OPS (bigger is better)
96+
--------------------------------------------------------------------------------------------------
97+
test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0)
98+
test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16)
99+
test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12)
100+
test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07)
101+
test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00)
102+
--------------------------------------------------------------------------------------------------
103+
"""
104+
105+
def serialize_with_pickle(data: torch.Tensor) -> bytes:
106+
"""Serialize tensor using pickle."""
107+
buffer = io.BytesIO()
108+
pickle.dump(data, buffer)
109+
return buffer.getvalue()
110+
111+
def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
112+
"""Serialize tensor by converting its untyped storage to bytes."""
113+
return bytes(data.untyped_storage())
114+
115+
def serialize_with_numpy(data: torch.Tensor) -> bytes:
116+
"""Serialize tensor using numpy."""
117+
return data.numpy().tobytes()
118+
119+
def serialize_with_safetensors(data: torch.Tensor) -> bytes:
120+
return save({"0": data})
121+
122+
def serialize_with_torch(data: torch.Tensor) -> bytes:
123+
"""Serialize tensor using torch's built-in method."""
124+
buffer = io.BytesIO()
125+
torch.save(data, buffer)
126+
return buffer.getvalue()
127+
128+
# Benchmark each serialization method
129+
if serialization_method == "pickle":
130+
serialize_fn = serialize_with_pickle
131+
elif serialization_method == "torch.save":
132+
serialize_fn = serialize_with_torch
133+
elif serialization_method == "untyped_storage":
134+
serialize_fn = serialize_with_untyped_storage
135+
elif serialization_method == "numpy":
136+
serialize_fn = serialize_with_numpy
137+
elif serialization_method == "safetensors":
138+
serialize_fn = serialize_with_safetensors
139+
else:
140+
raise ValueError(f"Unknown serialization method: {serialization_method}")
141+
142+
data = self.make_compressible_mock_data(1).get("observations")
143+
144+
# Run the actual benchmark
145+
benchmark(serialize_fn, data)

docs/source/reference/data.rst

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ using the following components:
144144
:template: rl_template.rst
145145

146146

147+
CompressedListStorage
148+
CompressedListStorageCheckpointer
147149
FlatStorageCheckpointer
148150
H5StorageCheckpointer
149151
ImmutableDatasetWriter
@@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be
191193
| :class:`LazyMemmapStorage` | 3.44x |
192194
+-------------------------------+-----------+
193195

196+
Compressed Storage for Memory Efficiency
197+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
198+
199+
For applications where memory usage or memory bandwidth is a primary concern, especially when storing or transferring
200+
large sensory observations like images, audio, or text, the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage`
201+
provides significant memory savings through compression.
202+
203+
The ``CompressedListStorage`` compresses data when storing and decompresses when retrieving,
204+
achieving compression ratios of 2-10x for image data while maintaining full data fidelity.
205+
It uses zstd compression by default but supports custom compression algorithms.
206+
207+
Key features:

208+
- **Memory Efficiency**: Achieves significant memory savings through compression
209+
- **Data Integrity**: Maintains full data fidelity through lossless compression
210+
- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default
211+
- **TensorDict Support**: Seamlessly works with TensorDict structures
212+
- **Checkpointing**: Full support for saving and loading compressed data
213+
214+
Example usage:
215+
216+
>>> import torch
217+
>>> from torchrl.data import ReplayBuffer, CompressedListStorage
218+
>>> from tensordict import TensorDict
219+
>>>
220+
>>> # Create a compressed storage for image data
221+
>>> storage = CompressedListStorage(max_size=1000, compression_level=3)
222+
>>> rb = ReplayBuffer(storage=storage, batch_size=32)
223+
>>>
224+
>>> # Add image data
225+
>>> images = torch.randn(100, 3, 84, 84) # Atari-like frames
226+
>>> data = TensorDict({"obs": images}, batch_size=[100])
227+
>>> rb.extend(data)
228+
>>>
229+
>>> # Sample data (automatically decompressed)
230+
>>> sample = rb.sample(16)
231+
>>> print(sample["obs"].shape) # torch.Size([16, 3, 84, 84])
232+
233+
The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression),
234+
with level 3 being a good default for most use cases.
235+
236+
For custom compression algorithms:
237+
238+
>>> def my_compress(tensor):
239+
... return tensor.to(torch.uint8) # Simple example
240+
>>>
241+
>>> def my_decompress(compressed_tensor, metadata):
242+
... return compressed_tensor.to(metadata["dtype"])
243+
>>>
244+
>>> storage = CompressedListStorage(
245+
... max_size=1000,
246+
... compression_fn=my_compress,
247+
... decompression_fn=my_decompress
248+
... )
249+
250+
.. note:: The CompressedListStorage requires the ``zstandard`` library for default compression.
251+
Install with: ``pip install zstandard``
252+
253+
.. note:: An example of how to use the CompressedListStorage is available in the
254+
`examples/replay-buffers/compressed_replay_buffer_example.py <https://github.com/pytorch/rl/blob/main/examples/replay-buffers/compressed_replay_buffer_example.py>`_ file.
255+
194256
Sharing replay buffers across processes
195257
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
196258

0 commit comments

Comments
 (0)