1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import io
6
+ import pickle
7
+
8
+ import pytest
9
+ import torch
10
+ try :
11
+ from safetensors .torch import save
12
+ except ImportError :
13
+ save = None
14
+
15
+ from torchrl .data import CompressedListStorage
16
+
17
+
18
class TestCompressedStorageBenchmark:
    """Benchmark tests for CompressedListStorage."""

    @staticmethod
    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
        """Easily compressible data for testing.

        Args:
            num_experiences: number of transitions to generate.
            device: torch device for the tensors; defaults to CPU.

        Returns:
            A dict of all-zero tensors (maximally compressible) shaped like an
            Atari-style replay batch (4x84x84 uint8 frames), plus a
            ``batch_size`` entry.
        """
        if device is None:
            device = torch.device("cpu")

        return {
            "observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "actions": torch.zeros((num_experiences,), device=device),
            "rewards": torch.zeros((num_experiences,), device=device),
            "next_observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "terminations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "truncations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "batch_size": [num_experiences],
        }

    @staticmethod
    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
        """Uncompressible data for testing.

        Args:
            num_experiences: number of transitions to generate.
            device: torch device for the tensors; defaults to CPU.

        Returns:
            A dict of random tensors (high-entropy, hence poorly compressible)
            with the same layout as :meth:`make_compressible_mock_data`.
        """
        if device is None:
            device = torch.device("cpu")
        return {
            "observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "actions": torch.randint(0, 10, (num_experiences,), device=device),
            "rewards": torch.randn(
                (num_experiences,), dtype=torch.float32, device=device
            ),
            "next_observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "terminations": torch.rand((num_experiences,), device=device)
            < 0.2,  # ~20% True
            "truncations": torch.rand((num_experiences,), device=device)
            < 0.1,  # ~10% True
            "batch_size": [num_experiences],
        }

    @pytest.mark.benchmark(
        group="tensor_serialization_speed",
        min_time=0.1,
        max_time=0.5,
        min_rounds=5,
        disable_gc=True,
        warmup=False,
    )
    @pytest.mark.parametrize(
        "serialization_method",
        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
    )
    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
        """Benchmark the speed of different tensor serialization methods.

        TODO: we might need to also test which methods work on the gpu.
        pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'

        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
        Name (time in us) Mean (smaller is better) OPS (bigger is better)
        --------------------------------------------------------------------------------------------------
        test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0)
        test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16)
        test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12)
        test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07)
        test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00)
        --------------------------------------------------------------------------------------------------
        """
        # The module-level optional import sets `save = None` when safetensors
        # is unavailable; skip cleanly rather than crash with a TypeError.
        if serialization_method == "safetensors" and save is None:
            pytest.skip("safetensors is not installed")

        def serialize_with_pickle(data: torch.Tensor) -> bytes:
            """Serialize tensor using pickle."""
            buffer = io.BytesIO()
            pickle.dump(data, buffer)
            return buffer.getvalue()

        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
            """Serialize tensor via its raw untyped storage bytes."""
            return bytes(data.untyped_storage())

        def serialize_with_numpy(data: torch.Tensor) -> bytes:
            """Serialize tensor using numpy."""
            return data.numpy().tobytes()

        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
            """Serialize tensor using safetensors (single-entry dict)."""
            return save({"0": data})

        def serialize_with_torch(data: torch.Tensor) -> bytes:
            """Serialize tensor using torch's built-in method."""
            buffer = io.BytesIO()
            torch.save(data, buffer)
            return buffer.getvalue()

        # Dispatch table instead of an if/elif chain; a missing key means the
        # parametrize list and this mapping have drifted apart.
        serializers = {
            "pickle": serialize_with_pickle,
            "torch.save": serialize_with_torch,
            "untyped_storage": serialize_with_untyped_storage,
            "numpy": serialize_with_numpy,
            "safetensors": serialize_with_safetensors,
        }
        try:
            serialize_fn = serializers[serialization_method]
        except KeyError:
            raise ValueError(
                f"Unknown serialization method: {serialization_method}"
            ) from None

        # Index directly so a missing key fails loudly instead of passing None.
        data = self.make_compressible_mock_data(1)["observations"]

        # Run the actual benchmark
        benchmark(serialize_fn, data)