Skip to content

Commit 099ced3

Browse files
authored
[Feature] Rework to_one_hot and to_categorical to take a tensor as parameter (#816)
1 parent 605202f commit 099ced3

File tree

2 files changed

+123
-46
lines changed

2 files changed

+123
-46
lines changed

test/test_specs.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,9 @@ def test_mult_onehot(shape, ns):
240240
for _r, _n in zip(rsplit, ns):
241241
assert (_r.sum(-1) == 1).all()
242242
assert _r.shape[-1] == _n
243-
np_r = ts.to_numpy(r)
244-
assert not ts.is_in(torch.tensor(np_r))
245-
assert (ts.encode(np_r) == r).all()
243+
categorical = ts.to_categorical(r)
244+
assert not ts.is_in(categorical)
245+
assert (ts.encode(categorical) == r).all()
246246

247247

248248
@pytest.mark.parametrize(
@@ -327,8 +327,11 @@ def test_discrete_conversion(n, device, shape):
327327
one_hot = OneHotDiscreteTensorSpec(n, device=device, shape=shape_one_hot)
328328

329329
assert categorical != one_hot
330-
assert categorical.to_onehot() == one_hot
331-
assert one_hot.to_categorical() == categorical
330+
assert categorical.to_one_hot_spec() == one_hot
331+
assert one_hot.to_categorical_spec() == categorical
332+
333+
assert categorical.is_in(one_hot.to_categorical(one_hot.rand(shape)))
334+
assert one_hot.is_in(categorical.to_one_hot(categorical.rand(shape)))
332335

333336

334337
@pytest.mark.parametrize(
@@ -341,14 +344,24 @@ def test_discrete_conversion(n, device, shape):
341344
[4, 5, 1, 3],
342345
],
343346
)
347+
@pytest.mark.parametrize(
348+
"shape",
349+
[
350+
torch.Size([3]),
351+
torch.Size([4, 5]),
352+
],
353+
)
344354
@pytest.mark.parametrize("device", get_available_devices())
345-
def test_multi_discrete_conversion(ns, device):
355+
def test_multi_discrete_conversion(ns, shape, device):
346356
categorical = MultiDiscreteTensorSpec(ns, device=device)
347357
one_hot = MultiOneHotDiscreteTensorSpec(ns, device=device)
348358

349359
assert categorical != one_hot
350-
assert categorical.to_onehot() == one_hot
351-
assert one_hot.to_categorical() == categorical
360+
assert categorical.to_one_hot_spec() == one_hot
361+
assert one_hot.to_categorical_spec() == categorical
362+
363+
assert categorical.is_in(one_hot.to_categorical(one_hot.rand(shape)))
364+
assert one_hot.is_in(categorical.to_one_hot(categorical.rand(shape)))
352365

353366

354367
@pytest.mark.parametrize("is_complete", [True, False])
@@ -1019,21 +1032,22 @@ def test_mult_discrete_action_spec_reconstruct(self):
10191032
action_spec = MultiOneHotDiscreteTensorSpec((10, 5))
10201033

10211034
actions_tensors = [action_spec.rand() for _ in range(10)]
1022-
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
1023-
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
1035+
actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors]
1036+
actions_tensors_2 = [action_spec.encode(a) for a in actions_categorical]
10241037
assert all(
10251038
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
10261039
)
10271040

1028-
actions_numpy = [
1029-
np.concatenate(
1030-
[np.random.randint(0, 10, (1,)), np.random.randint(0, 5, (1,))], 0
1031-
)
1041+
actions_categorical = [
1042+
torch.cat((torch.randint(0, 10, (1,)), torch.randint(0, 5, (1,))), 0)
10321043
for a in actions_tensors
10331044
]
1034-
actions_tensors = [action_spec.encode(a) for a in actions_numpy]
1035-
actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
1036-
assert all((a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2))
1045+
actions_tensors = [action_spec.encode(a) for a in actions_categorical]
1046+
actions_categorical_2 = [action_spec.to_categorical(a) for a in actions_tensors]
1047+
assert all(
1048+
(a1 == a2).all()
1049+
for a1, a2 in zip(actions_categorical, actions_categorical_2)
1050+
)
10371051

10381052
def test_one_hot_discrete_action_spec_rand(self):
10391053
torch.manual_seed(0)
@@ -1070,14 +1084,14 @@ def test_mult_discrete_action_spec_rand(self):
10701084
action_spec = MultiOneHotDiscreteTensorSpec((10, 5))
10711085

10721086
actions_tensors = [action_spec.rand() for _ in range(10)]
1073-
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
1074-
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
1087+
actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors]
1088+
actions_tensors_2 = [action_spec.encode(a) for a in actions_categorical]
10751089
assert all(
10761090
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
10771091
)
10781092

1079-
sample = np.stack(
1080-
[action_spec.to_numpy(action_spec.rand()) for _ in range(N)], 0
1093+
sample = torch.stack(
1094+
[action_spec.to_categorical(action_spec.rand()) for _ in range(N)], 0
10811095
)
10821096
assert sample.shape[0] == N
10831097
assert sample.shape[1] == 2

torchrl/data/tensor_specs.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,28 @@ def __eq__(self, other):
618618
and self.use_register == other.use_register
619619
)
620620

621-
def to_categorical(self) -> DiscreteTensorSpec:
621+
def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
622+
"""Converts a given one-hot tensor into categorical format.
623+
624+
Args:
625+
val (torch.Tensor): One-hot tensor to convert into categorical format.
626+
safe (bool): boolean value indicating whether a check should be
627+
performed on the value against the domain of the spec.
628+
629+
Returns:
630+
The categorical tensor.
631+
"""
632+
if safe:
633+
self.assert_is_in(val)
634+
return val.argmax(-1)
635+
636+
def to_categorical_spec(self) -> DiscreteTensorSpec:
637+
"""Converts the spec to the equivalent categorical spec."""
622638
return DiscreteTensorSpec(
623-
self.space.n, device=self.device, dtype=self.dtype, shape=self.shape[:-1]
639+
self.space.n,
640+
device=self.device,
641+
dtype=self.dtype,
642+
shape=self.shape[:-1],
624643
)
625644

626645

@@ -1184,13 +1203,6 @@ def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]:
11841203
return None
11851204
return val.split(split_sizes, dim=-1)
11861205

1187-
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
1188-
if safe:
1189-
self.assert_is_in(val)
1190-
vals = self._split(val)
1191-
out = torch.stack([val.argmax(-1) for val in vals], -1).numpy()
1192-
return out
1193-
11941206
def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
11951207
if not isinstance(index, torch.Tensor):
11961208
raise ValueError(
@@ -1219,8 +1231,24 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
12191231
vals = self._split(val)
12201232
return torch.cat([super()._project(_val) for _val in vals], -1)
12211233

1222-
def to_categorical(self) -> MultiDiscreteTensorSpec:
1234+
def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
1235+
"""Converts a given one-hot tensor into categorical format.
1236+
1237+
Args:
1238+
val (torch.Tensor): One-hot tensor to convert into categorical format.
1239+
safe (bool): boolean value indicating whether a check should be
1240+
performed on the value against the domain of the spec.
12231241
1242+
Returns:
1243+
The categorical tensor.
1244+
"""
1245+
if safe:
1246+
self.assert_is_in(val)
1247+
vals = self._split(val)
1248+
return torch.stack([val.argmax(-1) for val in vals], -1)
1249+
1250+
def to_categorical_spec(self) -> MultiDiscreteTensorSpec:
1251+
"""Converts the spec to the equivalent categorical spec."""
12241252
return MultiDiscreteTensorSpec(
12251253
[_space.n for _space in self.space],
12261254
device=self.device,
@@ -1321,12 +1349,23 @@ def __eq__(self, other):
13211349
def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
13221350
return super().to_numpy(val, safe)
13231351

1324-
def to_onehot(self) -> OneHotDiscreteTensorSpec:
1325-
# if len(self.shape) > 1:
1326-
# raise RuntimeError(
1327-
# f"DiscreteTensorSpec with shape that has several dimensions can't be converted to "
1328-
# f"OneHotDiscreteTensorSpec. Got shape={self.shape}."
1329-
# )
1352+
def to_one_hot(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
1353+
"""Encodes a discrete tensor from the spec domain into its one-hot counterpart.
1354+
1355+
Args:
1356+
val (torch.Tensor): Tensor to one-hot encode.
1357+
safe (bool): boolean value indicating whether a check should be
1358+
performed on the value against the domain of the spec.
1359+
1360+
Returns:
1361+
The one-hot encoded tensor.
1362+
"""
1363+
if safe:
1364+
self.assert_is_in(val)
1365+
return torch.nn.functional.one_hot(val, self.space.n)
1366+
1367+
def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec:
1368+
"""Converts the spec to the equivalent one-hot spec."""
13301369
shape = [*self.shape, self.space.n]
13311370
return OneHotDiscreteTensorSpec(
13321371
n=self.space.n, shape=shape, device=self.device, dtype=self.dtype
@@ -1488,17 +1527,41 @@ def is_in(self, val: torch.Tensor) -> bool:
14881527
)
14891528
if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim:
14901529
return False
1530+
val_device = val.device
1531+
return (
1532+
(
1533+
(val >= torch.zeros(self.nvec.size(), device=val_device))
1534+
& (val < self.nvec.to(val_device))
1535+
)
1536+
.all()
1537+
.item()
1538+
)
14911539

1492-
return ((val >= torch.zeros(self.nvec.size())) & (val < self.nvec)).all().item()
1540+
def to_one_hot(
1541+
self, val: torch.Tensor, safe: bool = True
1542+
) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]:
1543+
"""Encodes a discrete tensor from the spec domain into its one-hot correspondent.
14931544
1494-
def to_onehot(self) -> MultiOneHotDiscreteTensorSpec:
1495-
if len(self.shape) > 1:
1496-
raise RuntimeError(
1497-
f"DiscreteTensorSpec with shape that has several dimensions can't be converted to"
1498-
f"OneHotDiscreteTensorSpec. Got shape={self.shape}. This could be accomplished via padding or "
1499-
f"nestedtensors but it is not implemented yet. If you would like to see that feature, please submit "
1500-
f"an issue of torchrl's github repo. "
1501-
)
1545+
Args:
1546+
val (torch.Tensor): Tensor to one-hot encode.
1547+
safe (bool): boolean value indicating whether a check should be
1548+
performed on the value against the domain of the spec.
1549+
1550+
Returns:
1551+
The one-hot encoded tensor.
1552+
"""
1553+
if safe:
1554+
self.assert_is_in(val)
1555+
return torch.cat(
1556+
[
1557+
torch.nn.functional.one_hot(val[..., i], n)
1558+
for i, n in enumerate(self.nvec)
1559+
],
1560+
-1,
1561+
).to(self.device)
1562+
1563+
def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
1564+
"""Converts the spec to the equivalent one-hot spec."""
15021565
nvec = [_space.n for _space in self.space]
15031566
return MultiOneHotDiscreteTensorSpec(
15041567
nvec,

0 commit comments

Comments
 (0)