Skip to content

Commit 5c74ab7

Browse files
[BugFix] Support list-based boolean masks for TensorDict (#299)
1 parent 8503811 commit 5c74ab7

File tree

4 files changed

+171
-17
lines changed

4 files changed

+171
-17
lines changed

test/test_tensordict.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,19 @@ def test_mask_td(device):
198198
"key2": torch.randn(4, 5, 10, device=device),
199199
}
200200
mask = torch.zeros(4, 5, dtype=torch.bool, device=device).bernoulli_()
201+
mask_list = mask.cpu().numpy().tolist()
201202
td = TensorDict(batch_size=(4, 5), source=d)
203+
202204
td_masked = torch.masked_select(td, mask)
205+
td_masked1 = td[mask_list]
203206
assert len(td_masked.get("key1")) == td_masked.shape[0]
207+
assert len(td_masked1.get("key1")) == td_masked1.shape[0]
208+
209+
mask_list = [False, True, False, True]
210+
211+
td_masked2 = td[mask_list, 0]
212+
torch.testing.assert_allclose(td.get("key1")[mask_list, 0], td_masked2.get("key1"))
213+
torch.testing.assert_allclose(td.get("key2")[mask_list, 0], td_masked2.get("key2"))
204214

205215

206216
@pytest.mark.parametrize("device", get_available_devices())
@@ -782,6 +792,49 @@ def test_masking(self, td_name, device):
782792
assert td_masked.batch_size[0] == mask.sum()
783793
assert td_masked.batch_dims == 1
784794

795+
mask_list = mask.cpu().numpy().tolist()
796+
td_masked3 = td[mask_list]
797+
assert_allclose_td(td_masked3, td_masked2)
798+
assert td_masked3.batch_size[0] == mask.sum()
799+
assert td_masked3.batch_dims == 1
800+
801+
@pytest.mark.parametrize("from_list", [True, False])
802+
def test_masking_set(self, td_name, device, from_list):
803+
def zeros_like(item, n, d):
804+
if isinstance(item, (MemmapTensor, torch.Tensor)):
805+
return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device)
806+
elif isinstance(item, _TensorDict):
807+
batch_size = item.batch_size
808+
batch_size = [n, *batch_size[d:]]
809+
out = TensorDict(
810+
{k: zeros_like(_item, n, d) for k, _item in item.items()},
811+
batch_size,
812+
device=device,
813+
)
814+
return out
815+
816+
torch.manual_seed(1)
817+
td = getattr(self, td_name)(device)
818+
mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_(
819+
0.8
820+
)
821+
n = mask.sum()
822+
d = td.ndimension()
823+
pseudo_td = TensorDict(
824+
{k: zeros_like(item, n, d) for k, item in td.items()}, [n], device=device
825+
)
826+
if from_list:
827+
td_mask = mask.cpu().numpy().tolist()
828+
else:
829+
td_mask = mask
830+
if td_name == "stacked_td":
831+
with pytest.raises(RuntimeError, match="is not supported"):
832+
td[td_mask] = pseudo_td
833+
else:
834+
td[td_mask] = pseudo_td
835+
for k, item in td.items():
836+
assert (item[mask] == 0).all()
837+
785838
@pytest.mark.skipif(
786839
torch.cuda.device_count() == 0, reason="No cuda device detected"
787840
)
@@ -1779,15 +1832,9 @@ def test_stack_keys():
17791832
td.get("e")
17801833

17811834

1782-
def test_getitem_batch_size():
1783-
shape = [
1784-
10,
1785-
7,
1786-
11,
1787-
5,
1788-
]
1789-
mocking_tensor = torch.zeros(*shape)
1790-
for idx in [
1835+
@pytest.mark.parametrize(
1836+
"idx",
1837+
[
17911838
(slice(None),),
17921839
slice(None),
17931840
(3, 4),
@@ -1800,10 +1847,22 @@ def test_getitem_batch_size():
18001847
torch.tensor([0, 10, 2]),
18011848
torch.tensor([2, 4, 1]),
18021849
),
1803-
]:
1804-
expected_shape = mocking_tensor[idx].shape
1805-
resulting_shape = _getitem_batch_size(shape, idx)
1806-
assert expected_shape == resulting_shape, idx
1850+
torch.zeros(10, 7, 11, 5, dtype=torch.bool).bernoulli_(),
1851+
torch.zeros(10, 7, 11, dtype=torch.bool).bernoulli_(),
1852+
(0, torch.zeros(7, dtype=torch.bool).bernoulli_()),
1853+
],
1854+
)
1855+
def test_getitem_batch_size(idx):
1856+
shape = [
1857+
10,
1858+
7,
1859+
11,
1860+
5,
1861+
]
1862+
mocking_tensor = torch.zeros(*shape)
1863+
expected_shape = mocking_tensor[idx].shape
1864+
resulting_shape = _getitem_batch_size(shape, idx)
1865+
assert expected_shape == resulting_shape, (idx, expected_shape, resulting_shape)
18071866

18081867

18091868
@pytest.mark.parametrize("device", get_available_devices())

torchrl/data/tensordict/memmap.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,15 @@ def _load_item(
262262
if idx is not None:
263263
if isinstance(idx, torch.Tensor):
264264
idx = idx.cpu()
265+
elif isinstance(idx, tuple) and any(
266+
isinstance(sub_index, torch.Tensor) for sub_index in idx
267+
):
268+
idx = tuple(
269+
sub_index.cpu()
270+
if isinstance(sub_index, torch.Tensor)
271+
else sub_index
272+
for sub_index in idx
273+
)
265274
memmap_array = memmap_array[idx]
266275
out = self._np_to_tensor(memmap_array, from_numpy=from_numpy)
267276
if (
@@ -465,6 +474,15 @@ def __setitem__(self, idx: INDEX_TYPING, value: torch.Tensor):
465474
if self.device == torch.device("cpu"):
466475
self._load_item()[idx] = value
467476
else:
477+
if isinstance(idx, torch.Tensor):
478+
idx = idx.cpu()
479+
elif isinstance(idx, tuple) and any(
480+
isinstance(_idx, torch.Tensor) for _idx in idx
481+
):
482+
idx = tuple(
483+
_idx.cpu() if isinstance(_idx, torch.Tensor) else _idx
484+
for _idx in idx
485+
)
468486
self.memmap_array[idx] = to_numpy(value)
469487

470488
def __setstate__(self, state: dict) -> None:

torchrl/data/tensordict/tensordict.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,10 @@ def masked_select(self, mask: torch.Tensor) -> _TensorDict:
10651065
"""
10661066
d = dict()
10671067
for key, value in self.items():
1068-
mask_expand = mask.squeeze(-1)
1068+
while mask.ndimension() > self.batch_dims:
1069+
mask_expand = mask.squeeze(-1)
1070+
else:
1071+
mask_expand = mask
10691072
value_select = value[mask_expand]
10701073
d[key] = value_select
10711074
dim = int(mask.sum().item())
@@ -1471,6 +1474,17 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
14711474
>>> print(td.get("a")) # values have not changed
14721475
14731476
"""
1477+
if isinstance(idx, list):
1478+
idx = torch.tensor(idx, device=self.device)
1479+
if isinstance(idx, tuple) and any(
1480+
isinstance(sub_index, list) for sub_index in idx
1481+
):
1482+
idx = tuple(
1483+
torch.tensor(sub_index, device=self.device)
1484+
if isinstance(sub_index, list)
1485+
else sub_index
1486+
for sub_index in idx
1487+
)
14741488
if isinstance(idx, str):
14751489
return self.get(idx)
14761490
if isinstance(idx, tuple) and sum(
@@ -1487,8 +1501,8 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
14871501
return out[idx[1:]]
14881502
else:
14891503
return out
1490-
elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:
1491-
return self.masked_select(idx)
1504+
# elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:
1505+
# return self.masked_select(idx)
14921506

14931507
contiguous_input = (int, slice)
14941508
return_simple_view = isinstance(idx, contiguous_input) or (
@@ -1521,6 +1535,17 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
15211535
def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None:
15221536
if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index):
15231537
index = convert_ellipsis_to_idx(index, self.batch_size)
1538+
if isinstance(index, list):
1539+
index = torch.tensor(index, device=self.device)
1540+
if isinstance(index, tuple) and any(
1541+
isinstance(sub_index, list) for sub_index in index
1542+
):
1543+
index = tuple(
1544+
torch.tensor(sub_index, device=self.device)
1545+
if isinstance(sub_index, list)
1546+
else sub_index
1547+
for sub_index in index
1548+
)
15241549
if isinstance(index, tuple) and sum(
15251550
isinstance(_index, str) for _index in index
15261551
) not in [len(index), 0]:
@@ -3291,13 +3316,49 @@ def select(self, *keys: str, inplace: bool = False) -> _TensorDict:
32913316
stack_dim=self.stack_dim,
32923317
)
32933318

3319+
def __setitem__(self, item: INDEX_TYPING, value: _TensorDict) -> _TensorDict:
3320+
if isinstance(item, list):
3321+
item = torch.tensor(item, device=self.device)
3322+
if isinstance(item, tuple) and any(
3323+
isinstance(sub_index, list) for sub_index in item
3324+
):
3325+
item = tuple(
3326+
torch.tensor(sub_index, device=self.device)
3327+
if isinstance(sub_index, list)
3328+
else sub_index
3329+
for sub_index in item
3330+
)
3331+
if (isinstance(item, torch.Tensor) and item.dtype is torch.bool) or (
3332+
isinstance(item, tuple)
3333+
and any(
3334+
isinstance(_item, torch.Tensor) and _item.dtype is torch.bool
3335+
for _item in item
3336+
)
3337+
):
3338+
raise RuntimeError(
3339+
"setting values to a LazyStackTensorDict using boolean values is not supported yet."
3340+
"If this feature is needed, feel free to raise an issue on github."
3341+
)
3342+
return super().__setitem__(item, value)
3343+
32943344
def __getitem__(self, item: INDEX_TYPING) -> _TensorDict:
32953345
if item is Ellipsis or (isinstance(item, tuple) and Ellipsis in item):
32963346
item = convert_ellipsis_to_idx(item, self.batch_size)
32973347
if isinstance(item, tuple) and sum(
32983348
isinstance(_item, str) for _item in item
32993349
) not in [len(item), 0]:
33003350
raise IndexError(_STR_MIXED_INDEX_ERROR)
3351+
if isinstance(item, list):
3352+
item = torch.tensor(item, device=self.device)
3353+
if isinstance(item, tuple) and any(
3354+
isinstance(sub_index, list) for sub_index in item
3355+
):
3356+
item = tuple(
3357+
torch.tensor(sub_index, device=self.device)
3358+
if isinstance(sub_index, list)
3359+
else sub_index
3360+
for sub_index in item
3361+
)
33013362
if isinstance(item, str):
33023363
return self.get(item)
33033364
elif isinstance(item, tuple) and all(
@@ -3761,6 +3822,17 @@ def __reduce__(self, *args, **kwargs):
37613822
return super().__reduce__(*args, **kwargs)
37623823

37633824
def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
3825+
if isinstance(idx, list):
3826+
idx = torch.tensor(idx, device=self.device)
3827+
if isinstance(idx, tuple) and any(
3828+
isinstance(sub_index, list) for sub_index in idx
3829+
):
3830+
idx = tuple(
3831+
torch.tensor(sub_index, device=self.device)
3832+
if isinstance(sub_index, list)
3833+
else sub_index
3834+
for sub_index in idx
3835+
)
37643836
if idx is Ellipsis or (isinstance(idx, tuple) and Ellipsis in idx):
37653837
idx = convert_ellipsis_to_idx(idx, self.batch_size)
37663838

torchrl/data/tensordict/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def _getitem_batch_size(
4848
items = items[0]
4949
if isinstance(items, int):
5050
return shape[1:]
51+
if isinstance(items, torch.Tensor) and items.dtype is torch.bool:
52+
return torch.Size([items.sum(), *shape[items.ndimension() :]])
5153
if (
5254
isinstance(items, (torch.Tensor, np.ndarray)) and len(items.shape) <= 1
5355
) or isinstance(items, list):
@@ -78,7 +80,10 @@ def _getitem_batch_size(
7880
v = len(range(*_item.indices(batch)))
7981
elif isinstance(_item, (list, torch.Tensor, np.ndarray)):
8082
batch = next(iter_bs)
81-
v = len(_item)
83+
if isinstance(_item, torch.Tensor) and _item.dtype is torch.bool:
84+
v = _item.sum()
85+
else:
86+
v = len(_item)
8287
elif _item is None:
8388
v = 1
8489
elif isinstance(_item, Number):

0 commit comments

Comments
 (0)