Skip to content

Commit 5833b2f

Browse files
authored
Implement unittests for MemmapTensor. (#231)
1 parent df55871 commit 5833b2f

File tree

1 file changed

+59
-6
lines changed

1 file changed

+59
-6
lines changed

test/test_memmap.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import pytest
1212
import torch
13+
from _utils_internal import get_available_devices
1314
from torchrl.data.tensordict.memmap import MemmapTensor
1415

1516

@@ -35,7 +36,18 @@ def test_grad():
3536
MemmapTensor(t + 1)
3637

3738

38-
@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.double, torch.bool])
39+
@pytest.mark.parametrize(
40+
"dtype",
41+
[
42+
torch.half,
43+
torch.float,
44+
torch.double,
45+
torch.int,
46+
torch.uint8,
47+
torch.long,
48+
torch.bool,
49+
],
50+
)
3951
@pytest.mark.parametrize(
4052
"shape",
4153
[
@@ -45,8 +57,9 @@ def test_grad():
4557
[1, 2],
4658
],
4759
)
48-
def test_memmap_metadata(dtype, shape):
49-
t = torch.tensor([1, 0]).reshape(shape)
60+
def test_memmap_data_type(dtype, shape):
61+
"""Test that MemmapTensor can be created with a given data type and shape."""
62+
t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
5063
m = MemmapTensor(t)
5164
assert m.dtype == t.dtype
5265
assert (m == t).all()
@@ -137,9 +150,49 @@ def test_memmap_clone():
137150
assert m2c == m1
138151

139152

140-
def test_memmap_tensor():
141-
t = torch.tensor([[1, 2, 3], [4, 5, 6]])
142-
assert (torch.tensor(t) == t).all()
153+
@pytest.mark.parametrize("device", get_available_devices())
154+
def test_memmap_same_device_as_tensor(device):
155+
"""
156+
Created MemmapTensor should be on the same device as the input tensor.
157+
Check if device is correct when .to(device) is called.
158+
"""
159+
t = torch.tensor([1], device=device)
160+
m = MemmapTensor(t)
161+
assert m.device == torch.device(device)
162+
for other_device in get_available_devices():
163+
if other_device != device:
164+
with pytest.raises(
165+
RuntimeError,
166+
match="Expected all tensors to be on the same device, "
167+
+ "but found at least two devices",
168+
):
169+
assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
170+
m = m.to(other_device)
171+
assert m.device == torch.device(other_device)
172+
173+
174+
@pytest.mark.parametrize("device", get_available_devices())
175+
def test_memmap_create_on_same_device(device):
176+
"""Test if the device arg for MemmapTensor init is respected."""
177+
m = MemmapTensor([3, 4], device=device)
178+
assert m.device == torch.device(device)
179+
180+
181+
@pytest.mark.parametrize("device", get_available_devices())
182+
@pytest.mark.parametrize(
183+
"value", [torch.zeros([3, 4]), MemmapTensor(torch.zeros([3, 4]))]
184+
)
185+
@pytest.mark.parametrize("shape", [[3, 4], [[3, 4]]])
186+
def test_memmap_zero_value(device, value, shape):
187+
"""
188+
Test if all entries are zeros when MemmapTensor is created with size.
189+
"""
190+
value = value.to(device)
191+
expected_memmap_tensor = MemmapTensor(value)
192+
m = MemmapTensor(*shape, device=device)
193+
assert m.shape == (3, 4)
194+
assert torch.all(m == expected_memmap_tensor)
195+
assert torch.all(m + torch.ones([3, 4], device=device) == 1)
143196

144197

145198
if __name__ == "__main__":

0 commit comments

Comments
 (0)