Skip to content

Commit dc1584d

Browse files
authored
[BugFix] Fix MultOneHotDiscreteTensorSpec.is_in (#818)
1 parent 6655b2d commit dc1584d

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

test/test_tensor_spec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def test_mult_onehot(shape, ns):
231231
assert (_r.sum(-1) == 1).all()
232232
assert _r.shape[-1] == _n
233233
np_r = ts.to_numpy(r)
234+
assert not ts.is_in(torch.tensor(np_r))
234235
assert (ts.encode(np_r) == r).all()
235236

236237

torchrl/data/tensor_specs.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -878,9 +878,11 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
878878
x.append(super(MultOneHotDiscreteTensorSpec, self).encode(v, space))
879879
return torch.cat(x, -1)
880880

881-
def _split(self, val: torch.Tensor) -> torch.Tensor:
882-
vals = val.split([space.n for space in self.space], dim=-1)
883-
return vals
881+
def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]:
882+
split_sizes = [space.n for space in self.space]
883+
if val.ndim < 1 or val.shape[-1] != sum(split_sizes):
884+
return None
885+
return val.split(split_sizes, dim=-1)
884886

885887
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
886888
if safe:
@@ -907,8 +909,10 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
907909

908910
def is_in(self, val: torch.Tensor) -> bool:
909911
vals = self._split(val)
912+
if vals is None:
913+
return False
910914
return all(
911-
[super(MultOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals]
915+
super(MultOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals
912916
)
913917

914918
def _project(self, val: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)