Commit 14b63e4

Author: Vincent Moens (committed)

[Feature] TensorSpec.enumerate()

ghstack-source-id: 9db2f5e
Pull Request resolved: #2354

1 parent 8a8b4c3  commit 14b63e4
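As a quick illustration of the new API, the sketch below exercises enumerate() the same way the tests added in this commit do; the class names and the expected output are taken from the test diff below (the import path is assumed), not from an independently verified run.

import torch
from torchrl.data import DiscreteTensorSpec

# A categorical spec with 5 possible values and batch shape (3,).
spec = DiscreteTensorSpec(n=5, shape=(3,))

# enumerate() stacks every admissible value along a new leading dimension:
# one entry per category, broadcast over the spec's shape.
samples = spec.enumerate()
print(samples)
# tensor([[0, 0, 0],
#         [1, 1, 1],
#         [2, 2, 2],
#         [3, 3, 3],
#         [4, 4, 4]])

# Every enumerated sample is a valid member of the spec.
assert spec.is_in(samples)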

File tree: 2 files changed (+163, -2 lines)


test/test_specs.py

Lines changed: 52 additions & 0 deletions
@@ -3816,6 +3816,58 @@ def test_non_tensor(self):
         assert not isinstance(non_tensor, MultiOneHot)
 
 
+class TestSpecEnumerate:
+    def test_discrete(self):
+        spec = DiscreteTensorSpec(n=5, shape=(3,))
+        assert (
+            spec.enumerate()
+            == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
+        ).all()
+        assert spec.is_in(spec.enumerate())
+
+    def test_one_hot(self):
+        spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
+        assert (
+            spec.enumerate()
+            == torch.tensor(
+                [
+                    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]],
+                    [[0, 1, 0, 0, 0], [0, 1, 0, 0, 0]],
+                    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
+                    [[0, 0, 0, 1, 0], [0, 0, 0, 1, 0]],
+                    [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1]],
+                ],
+                dtype=torch.bool,
+            )
+        ).all()
+        assert spec.is_in(spec.enumerate())
+
+    def test_multi_discrete(self):
+        spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
+        enum = spec.enumerate()
+        assert spec.is_in(enum)
+        assert enum.shape == torch.Size([60, 2, 3])
+
+    def test_multi_onehot(self):
+        spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
+        enum = spec.enumerate()
+        assert spec.is_in(enum)
+        assert enum.shape == torch.Size([60, 2, 12])
+
+    def test_composite(self):
+        c = CompositeSpec(
+            {
+                "a": OneHotDiscreteTensorSpec(n=5, shape=(3, 5)),
+                ("b", "c"): DiscreteTensorSpec(n=4, shape=(3,)),
+            },
+            shape=[3],
+        )
+        c_enum = c.enumerate()
+        assert c.is_in(c_enum)
+        assert c_enum.shape == torch.Size((20, 3))
+        assert c_enum["b"].shape == torch.Size((20, 3))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
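For reference, the leading dimensions asserted above are simply the sizes of the cartesian products of the sub-spaces; a small arithmetic check (plain Python, nothing torchrl-specific):

import math

# MultiDiscreteTensorSpec([3, 4, 5]) / MultiOneHotDiscreteTensorSpec([3, 4, 5]):
# one enumerated entry per joint assignment of the three sub-variables.
assert math.prod([3, 4, 5]) == 60

# The composite spec combines a 5-class one-hot entry with a 4-class categorical
# entry, so its enumeration is the cartesian product of the two: 5 * 4 entries.
assert 5 * 4 == 20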

torchrl/data/tensor_specs.py

Lines changed: 111 additions & 2 deletions
@@ -834,6 +834,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         """
         return self.is_in(item)
 
+    @abc.abstractmethod
+    def enumerate(self) -> Any:
+        """Returns all the samples that can be obtained from the TensorSpec.
+
+        The samples will be stacked along the first dimension.
+
+        This method is only implemented for discrete specs.
+        """
+        ...
+
     def project(
         self, val: torch.Tensor | TensorDictBase
     ) -> torch.Tensor | TensorDictBase:

@@ -1271,6 +1281,11 @@ def __eq__(self, other):
                 return False
         return True
 
+    def enumerate(self) -> torch.Tensor | TensorDictBase:
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
     def __len__(self):
         return self.shape[0]
 

@@ -1732,6 +1747,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
+    def enumerate(self) -> torch.Tensor:
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
     def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         if not isinstance(index, torch.Tensor):
             raise ValueError(

@@ -2056,6 +2078,11 @@ def __init__(
             domain=domain,
         )
 
+    def enumerate(self) -> Any:
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
     def __eq__(self, other):
         return (
             type(other) == type(self)

@@ -2375,6 +2402,9 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def enumerate(self) -> Any:
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest

@@ -2611,6 +2641,9 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)
 
+    def enumerate(self) -> Any:
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]

@@ -2775,6 +2808,18 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self) -> torch.Tensor:
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 

@@ -3208,6 +3253,12 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self) -> torch.Tensor:
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
     @property
     def n(self):
         return self.space.n

@@ -3715,6 +3766,29 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
+    def enumerate(self) -> torch.Tensor:
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 

@@ -3932,6 +4006,8 @@ def to_one_hot(
 
     def to_one_hot_spec(self) -> MultiOneHot:
         """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
         nvec = [_space.n for _space in self.space]
         return MultiOneHot(
             nvec,

@@ -4606,6 +4682,33 @@ def clone(self) -> Composite:
             shape=self.shape,
         )
 
+    def enumerate(self) -> TensorDictBase:
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
     def empty(self):
         """Create a spec like self, but with no entries."""
         try:

@@ -4856,6 +4959,12 @@ def update(self, dict) -> None:
             self[key] = item
         return self
 
+    def enumerate(self) -> TensorDictBase:
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
     def __eq__(self, other):
         if not isinstance(other, StackedComposite):
             return False

@@ -5150,7 +5259,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
 
 
 @TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "

@@ -5187,7 +5296,7 @@ def _stack_specs(list_of_spec, dim, out=None):
 
 
 @Composite.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
