Skip to content

Commit 783e3ee

Browse files
committed
move benchmarks to dedicated workflow
1 parent eade378 commit 783e3ee

File tree

5 files changed

+152
-131
lines changed

5 files changed

+152
-131
lines changed

.github/workflows/benchmarks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ jobs:
8080
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
8181
python3 -m pip install "pybind11[global]"
8282
python3.10 -m pip install git+https://github.com/pytorch/tensordict
83+
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
8384
python3.10 setup.py develop
8485
8586
# test import

.github/workflows/benchmarks_pr.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ jobs:
8282
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
8383
python3.10 -m pip install "pybind11[global]"
8484
python3.10 -m pip install git+https://github.com/pytorch/tensordict
85+
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
8586
python3.10 setup.py develop
8687
# python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH
8788

benchmarks/requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
pytest-benchmark
22
tenacity
3+
safetensors
4+
tqdm
5+
pandas
6+
numpy
7+
matplotlib
benchmarks/test_compressed_storage_benchmark.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import io
6+
import pickle
7+
8+
import pytest
9+
import torch
10+
try:
11+
from safetensors.torch import save
12+
except ImportError:
13+
save = None
14+
15+
from torchrl.data import CompressedListStorage
16+
17+
18+
class TestCompressedStorageBenchmark:
    """Benchmark tests for CompressedListStorage.

    Holds two mock-data factories (an all-zero batch and a random batch) and a
    benchmark comparing several ways of turning a tensor into ``bytes``.
    """

    @staticmethod
    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
        """Easily compressible data for testing.

        Every tensor is filled with zeros.

        Args:
            num_experiences: size of the leading (batch) dimension of each tensor.
            device: device the tensors are created on; defaults to CPU.

        Returns:
            A dict with ``observations`` / ``next_observations`` (uint8,
            ``(num_experiences, 4, 84, 84)``), ``actions`` / ``rewards``
            (torch default dtype, ``(num_experiences,)``), ``terminations`` /
            ``truncations`` (bool, ``(num_experiences,)``) and ``batch_size``
            (``[num_experiences]``).
        """
        if device is None:
            device = torch.device("cpu")

        return {
            "observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "actions": torch.zeros((num_experiences,), device=device),
            "rewards": torch.zeros((num_experiences,), device=device),
            "next_observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "terminations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "truncations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "batch_size": [num_experiences],
        }

    @staticmethod
    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
        """Uncompressible data for testing.

        Every tensor is drawn from torch's random number generator.

        Args:
            num_experiences: size of the leading (batch) dimension of each tensor.
            device: device the tensors are created on; defaults to CPU.

        Returns:
            A dict with the same keys as ``make_compressible_mock_data``;
            observations and rewards are float32 normal samples, actions are
            integers in ``[0, 10)`` and the two flags are random booleans.
        """
        if device is None:
            device = torch.device("cpu")
        return {
            "observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "actions": torch.randint(0, 10, (num_experiences,), device=device),
            "rewards": torch.randn(
                (num_experiences,), dtype=torch.float32, device=device
            ),
            "next_observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "terminations": torch.rand((num_experiences,), device=device)
            < 0.2,  # ~20% True
            "truncations": torch.rand((num_experiences,), device=device)
            < 0.1,  # ~10% True
            "batch_size": [num_experiences],
        }

    @pytest.mark.benchmark(
        group="tensor_serialization_speed",
        min_time=0.1,
        max_time=0.5,
        min_rounds=5,
        disable_gc=True,
        warmup=False,
    )
    @pytest.mark.parametrize(
        "serialization_method",
        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
    )
    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
        """Benchmark the speed of different tensor serialization methods.

        The ``safetensors`` case is skipped when the optional ``safetensors``
        package could not be imported (module-level ``save`` is ``None``).

        TODO: we might need to also test which methods work on the gpu.
        pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'

        Reference output (recorded as-is; the names shown may not match the
        current test and group names):

        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
        Name (time in us) Mean (smaller is better) OPS (bigger is better)
        --------------------------------------------------------------------------------------------------
        test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0)
        test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16)
        test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12)
        test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07)
        test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00)
        --------------------------------------------------------------------------------------------------
        """

        def serialize_with_pickle(data: torch.Tensor) -> bytes:
            """Serialize tensor using pickle."""
            buffer = io.BytesIO()
            pickle.dump(data, buffer)
            return buffer.getvalue()

        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
            """Serialize tensor by converting its untyped storage to bytes."""
            return bytes(data.untyped_storage())

        def serialize_with_numpy(data: torch.Tensor) -> bytes:
            """Serialize tensor using numpy."""
            return data.numpy().tobytes()

        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
            """Serialize tensor using safetensors, stored under the key "0"."""
            return save({"0": data})

        def serialize_with_torch(data: torch.Tensor) -> bytes:
            """Serialize tensor using torch's built-in method."""
            buffer = io.BytesIO()
            torch.save(data, buffer)
            return buffer.getvalue()

        # Map each parametrized name to its serializer
        serializers = {
            "pickle": serialize_with_pickle,
            "torch.save": serialize_with_torch,
            "untyped_storage": serialize_with_untyped_storage,
            "numpy": serialize_with_numpy,
            "safetensors": serialize_with_safetensors,
        }
        try:
            serialize_fn = serializers[serialization_method]
        except KeyError:
            raise ValueError(
                f"Unknown serialization method: {serialization_method}"
            ) from None

        # The module-level import falls back to ``save = None`` when safetensors
        # is missing; skip instead of failing on a call to ``None``.
        if serialization_method == "safetensors" and save is None:
            pytest.skip("safetensors is not installed")

        data = self.make_compressible_mock_data(1).get("observations")

        # Run the actual benchmark
        benchmark(serialize_fn, data)

test/test_rb.py

Lines changed: 0 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -4343,137 +4343,6 @@ def test_compressed_storage_memory_efficiency(self):
43434343
compression_ratio > 1.5
43444344
), f"Compression ratio {compression_ratio} is too low"
43454345

4346-
@staticmethod
4347-
def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
4348-
"""Easily compressible data for testing."""
4349-
if device is None:
4350-
device = torch.device("cpu")
4351-
4352-
return {
4353-
"observations": torch.zeros(
4354-
(num_experiences, 4, 84, 84),
4355-
dtype=torch.uint8,
4356-
device=device,
4357-
),
4358-
"actions": torch.zeros((num_experiences,), device=device),
4359-
"rewards": torch.zeros((num_experiences,), device=device),
4360-
"next_observations": torch.zeros(
4361-
(num_experiences, 4, 84, 84),
4362-
dtype=torch.uint8,
4363-
device=device,
4364-
),
4365-
"terminations": torch.zeros(
4366-
(num_experiences,), dtype=torch.bool, device=device
4367-
),
4368-
"truncations": torch.zeros(
4369-
(num_experiences,), dtype=torch.bool, device=device
4370-
),
4371-
"batch_size": [num_experiences],
4372-
}
4373-
4374-
@staticmethod
4375-
def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
4376-
"""Uncompressible data for testing."""
4377-
if device is None:
4378-
device = torch.device("cpu")
4379-
return {
4380-
"observations": torch.randn(
4381-
(num_experiences, 4, 84, 84),
4382-
dtype=torch.float32,
4383-
device=device,
4384-
),
4385-
"actions": torch.randint(0, 10, (num_experiences,), device=device),
4386-
"rewards": torch.randn(
4387-
(num_experiences,), dtype=torch.float32, device=device
4388-
),
4389-
"next_observations": torch.randn(
4390-
(num_experiences, 4, 84, 84),
4391-
dtype=torch.float32,
4392-
device=device,
4393-
),
4394-
"terminations": torch.rand((num_experiences,), device=device)
4395-
< 0.2, # ~20% True
4396-
"truncations": torch.rand((num_experiences,), device=device)
4397-
< 0.1, # ~10% True
4398-
"batch_size": [num_experiences],
4399-
}
4400-
4401-
@pytest.mark.benchmark(
    group="tensor_serialization_speed",
    min_time=0.1,
    max_time=0.5,
    min_rounds=5,
    disable_gc=True,
    warmup=False,
)
@pytest.mark.parametrize(
    "serialization_method",
    ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
)
def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
    """Benchmark the speed of different tensor serialization methods.

    The ``safetensors`` case is skipped when the optional ``safetensors``
    package is not installed; the other cases do not import it.

    TODO: we might need to also test which methods work on the gpu.
    pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'

    Reference output (recorded as-is; the names shown may not match the
    current test and group names):

    ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
    Name (time in us) Mean (smaller is better) OPS (bigger is better)
    --------------------------------------------------------------------------------------------------
    test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0)
    test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16)
    test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12)
    test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07)
    test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00)
    --------------------------------------------------------------------------------------------------
    """
    import io
    import pickle

    import torch

    # safetensors is optional: resolve it only for the case that needs it, so
    # a missing package skips that one case instead of failing all of them.
    safetensors_save = None
    if serialization_method == "safetensors":
        safetensors_save = pytest.importorskip("safetensors.torch").save

    def serialize_with_pickle(data: torch.Tensor) -> bytes:
        """Serialize tensor using pickle."""
        buffer = io.BytesIO()
        pickle.dump(data, buffer)
        return buffer.getvalue()

    def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
        """Serialize tensor by converting its untyped storage to bytes."""
        return bytes(data.untyped_storage())

    def serialize_with_numpy(data: torch.Tensor) -> bytes:
        """Serialize tensor using numpy."""
        return data.numpy().tobytes()

    def serialize_with_safetensors(data: torch.Tensor) -> bytes:
        """Serialize tensor using safetensors, stored under the key "0"."""
        return safetensors_save({"0": data})

    def serialize_with_torch(data: torch.Tensor) -> bytes:
        """Serialize tensor using torch's built-in method."""
        buffer = io.BytesIO()
        torch.save(data, buffer)
        return buffer.getvalue()

    # Map each parametrized name to its serializer
    serializers = {
        "pickle": serialize_with_pickle,
        "torch.save": serialize_with_torch,
        "untyped_storage": serialize_with_untyped_storage,
        "numpy": serialize_with_numpy,
        "safetensors": serialize_with_safetensors,
    }
    try:
        serialize_fn = serializers[serialization_method]
    except KeyError:
        raise ValueError(
            f"Unknown serialization method: {serialization_method}"
        ) from None

    data = self.make_compressible_mock_data(1).get("observations")

    # Run the actual benchmark
    benchmark(serialize_fn, data)
4476-
44774346

44784347
if __name__ == "__main__":
44794348
args, unknown = argparse.ArgumentParser().parse_known_args()

0 commit comments

Comments
 (0)