import numpy as np
import pytest
import torch
+from _utils_internal import get_available_devices
from torchrl.data.tensordict.memmap import MemmapTensor


@@ -35,7 +36,18 @@ def test_grad():
        MemmapTensor(t + 1)


-@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.double, torch.bool])
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.half,
+        torch.float,
+        torch.double,
+        torch.int,
+        torch.uint8,
+        torch.long,
+        torch.bool,
+    ],
+)
@pytest.mark.parametrize(
    "shape",
    [
@@ -45,8 +57,9 @@ def test_grad():
        [1, 2],
    ],
)
-def test_memmap_metadata(dtype, shape):
-    t = torch.tensor([1, 0]).reshape(shape)
+def test_memmap_data_type(dtype, shape):
+    """Test that MemmapTensor can be created with a given data type and shape."""
+    t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
    m = MemmapTensor(t)
    assert m.dtype == t.dtype
    assert (m == t).all()
@@ -137,9 +150,49 @@ def test_memmap_clone():
    assert m2c == m1


-def test_memmap_tensor():
-    t = torch.tensor([[1, 2, 3], [4, 5, 6]])
-    assert (torch.tensor(t) == t).all()
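+# get_available_devices() lists the devices to test on: cpu, plus any cuda devices present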
+ @pytest .mark .parametrize ("device" , get_available_devices ())
154
+ def test_memmap_same_device_as_tensor (device ):
155
+ """
156
+ Created MemmapTensor should be on the same device as the input tensor.
157
+ Check if device is correct when .to(device) is called.
158
+ """
159
+ t = torch .tensor ([1 ], device = device )
160
+ m = MemmapTensor (t )
161
+ assert m .device == torch .device (device )
162
+ for other_device in get_available_devices ():
163
+ if other_device != device :
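+            # an op mixing m with a tensor on a different device must raise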
+            with pytest.raises(
+                RuntimeError,
+                match="Expected all tensors to be on the same device, "
+                + "but found at least two devices",
+            ):
+                assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
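+            # .to(device) should return a MemmapTensor on the target device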
+            m = m.to(other_device)
+            assert m.device == torch.device(other_device)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+def test_memmap_create_on_same_device(device):
+    """Test that the device argument of the MemmapTensor constructor is respected."""
+    m = MemmapTensor([3, 4], device=device)
+    assert m.device == torch.device(device)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize(
+    "value", [torch.zeros([3, 4]), MemmapTensor(torch.zeros([3, 4]))]
+)
+@pytest.mark.parametrize("shape", [[3, 4], [[3, 4]]])
+def test_memmap_zero_value(device, value, shape):
+    """
+    Test that all entries are zero when a MemmapTensor is created with a given size.
+    """
+    value = value.to(device)
+    expected_memmap_tensor = MemmapTensor(value)
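+    # shape is parametrized so that both MemmapTensor(3, 4) and MemmapTensor([3, 4]) are exercised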
+    m = MemmapTensor(*shape, device=device)
+    assert m.shape == (3, 4)
+    assert torch.all(m == expected_memmap_tensor)
+    assert torch.all(m + torch.ones([3, 4], device=device) == 1)


if __name__ == "__main__":