Skip to content

Commit c52193e

Browse files
author
Vincent Moens
committed
[BE] Ensure abstractmethods are implemented for specs
ghstack-source-id: 7b943aa Pull Request resolved: #2790 (cherry picked from commit bd78913)
1 parent 8025e4a commit c52193e

File tree

1 file changed

+164
-4
lines changed

1 file changed

+164
-4
lines changed

torchrl/data/tensor_specs.py

Lines changed: 164 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def __repr__(self):
540540

541541

542542
@dataclass(repr=False)
543-
class TensorSpec:
543+
class TensorSpec(metaclass=abc.ABCMeta):
544544
"""Parent class of the tensor meta-data containers.
545545
546546
TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
@@ -675,6 +675,11 @@ def encode(
675675
self.assert_is_in(val)
676676
return val
677677

678+
@abc.abstractmethod
679+
def __eq__(self, other: Any) -> bool:
680+
# Implement minimal version if super() is called
681+
return type(self) is type(other)
682+
678683
def __ne__(self, other):
679684
return not (self == other)
680685

@@ -734,13 +739,31 @@ def index(
734739
) -> torch.Tensor | TensorDictBase:
735740
"""Indexes the input tensor.
736741
742+
This method is to be used with specs that encode one or more categorical variables (e.g.,
743+
:class:`~torchrl.data.OneHot` or :class:`~torchrl.data.Categorical`), such that indexing of a tensor
744+
with a sample can be done without caring about the actual representation of the index.
745+
737746
Args:
738747
index (int, torch.Tensor, slice or list): index of the tensor
739748
tensor_to_index: tensor to be indexed
740749
741750
Returns:
742751
indexed tensor
743752
753+
Exanples:
754+
>>> from torchrl.data import OneHot
755+
>>> import torch
756+
>>>
757+
>>> one_hot = OneHot(n=100)
758+
>>> categ = one_hot.to_categorical_spec()
759+
>>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
760+
>>> idx_one_hot[50] = 1
761+
>>> print(one_hot.index(idx_one_hot, torch.arange(100)))
762+
tensor(50)
763+
>>> idx_categ = one_hot.to_categorical(idx_one_hot)
764+
>>> print(categ.index(idx_categ, torch.arange(100)))
765+
tensor(50)
766+
744767
"""
745768
...
746769

@@ -1302,6 +1325,31 @@ class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):
13021325
13031326
"""
13041327

1328+
def _reshape(
1329+
self,
1330+
*args,
1331+
**kwargs,
1332+
) -> Any:
1333+
raise NotImplementedError(
1334+
f"`reshape` is not implemented for {type(self).__name__} specs."
1335+
)
1336+
1337+
def cardinality(
1338+
self,
1339+
*args,
1340+
**kwargs,
1341+
) -> Any:
1342+
raise NotImplementedError(
1343+
f"`cardinality` is not implemented for {type(self).__name__} specs."
1344+
)
1345+
1346+
def index(
1347+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
1348+
) -> torch.Tensor | TensorDictBase:
1349+
raise NotImplementedError(
1350+
f"`index` is not implemented for {type(self).__name__} specs."
1351+
)
1352+
13051353
def __eq__(self, other):
13061354
if not isinstance(other, Stacked):
13071355
return False
@@ -1823,7 +1871,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
18231871
f"Only tensors are allowed for indexing using "
18241872
f"{self.__class__.__name__}.index(...)"
18251873
)
1826-
index = index.nonzero().squeeze()
1874+
index = index.nonzero(as_tuple=True)[-1]
18271875
index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
18281876
return tensor_to_index.gather(-1, index)
18291877

@@ -2142,6 +2190,11 @@ def __init__(
21422190
domain=domain,
21432191
)
21442192

2193+
def index(
2194+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
2195+
) -> torch.Tensor | TensorDictBase:
2196+
raise NotImplementedError("Indexing not implemented for Bounded.")
2197+
21452198
def enumerate(self) -> Any:
21462199
raise NotImplementedError(
21472200
f"enumerate is not implemented for spec of class {type(self).__name__}."
@@ -2478,11 +2531,19 @@ def __eq__(self, other):
24782531
eq = eq & (self.example_data == getattr(other, "example_data", None))
24792532
return eq
24802533

2534+
def _project(self) -> Any:
2535+
raise NotImplementedError("Cannot project a NonTensorSpec.")
2536+
2537+
def index(
2538+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
2539+
) -> torch.Tensor | TensorDictBase:
2540+
raise NotImplementedError("Cannot use index with a NonTensorSpec.")
2541+
24812542
def cardinality(self) -> Any:
2482-
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
2543+
raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
24832544

24842545
def enumerate(self) -> Any:
2485-
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
2546+
raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
24862547

24872548
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
24882549
if isinstance(dest, torch.dtype):
@@ -2744,6 +2805,16 @@ def __init__(
27442805
shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
27452806
)
27462807

2808+
def cardinality(self) -> int:
2809+
raise NotImplementedError(
2810+
"`cardinality` is not implemented for Unbounded specs."
2811+
)
2812+
2813+
def index(
2814+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
2815+
) -> torch.Tensor | TensorDictBase:
2816+
raise NotImplementedError("`index` is not implemented for Unbounded specs.")
2817+
27472818
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded:
27482819
if isinstance(dest, torch.dtype):
27492820
dest_dtype = dest
@@ -3515,6 +3586,14 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
35153586
out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
35163587
return out
35173588

3589+
def index(
3590+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
3591+
) -> torch.Tensor | TensorDictBase:
3592+
idx = index.expand(
3593+
tensor_to_index.shape[: -self.ndim] + torch.Size([-1] * self.ndim)
3594+
)
3595+
return tensor_to_index.gather(-1, idx)
3596+
35183597
def _project(self, val: torch.Tensor) -> torch.Tensor:
35193598
if val.dtype not in (torch.int, torch.long):
35203599
val = torch.round(val)
@@ -3851,9 +3930,50 @@ def cardinality(self) -> int:
38513930
.item()
38523931
)
38533932

3933+
def enumerate(self, use_mask: bool = False) -> List[Any]:
3934+
return [s for choice in self._choices for s in choice.enumerate()]
3935+
3936+
def _project(
3937+
self, val: torch.Tensor | TensorDictBase
3938+
) -> torch.Tensor | TensorDictBase:
3939+
raise NotImplementedError(
3940+
"_project is not implemented for Choice. If this feature is required, please raise "
3941+
"an issue on TorchRL github repo."
3942+
)
3943+
3944+
def _reshape(self, shape: torch.Size) -> T:
3945+
return self.__class__(
3946+
[choice.reshape(shape) for choice in self._choices],
3947+
)
3948+
3949+
def index(
3950+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
3951+
) -> torch.Tensor | TensorDictBase:
3952+
raise NotImplementedError(
3953+
"index is not implemented for Choice. If this feature is required, please raise "
3954+
"an issue on TorchRL github repo."
3955+
)
3956+
3957+
@property
3958+
def num_choices(self):
3959+
"""Number of choices for the spec."""
3960+
return len(self._choices)
3961+
38543962
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
38553963
return self.__class__([choice.to(dest) for choice in self._choices])
38563964

3965+
def __eq__(self, other):
3966+
if not isinstance(other, Choice):
3967+
return False
3968+
if self.num_choices != other.num_choices:
3969+
return False
3970+
return all(
3971+
(s0 == s1).all()
3972+
if isinstance(s0, torch.Tensor) or is_tensor_collection(s0)
3973+
else s0 == s1
3974+
for s0, s1 in zip(self._choices, other._choices)
3975+
)
3976+
38573977

38583978
@dataclass(repr=False)
38593979
class Binary(Categorical):
@@ -4585,6 +4705,21 @@ def shape(self, value: torch.Size):
45854705
)
45864706
self._shape = _size(value)
45874707

4708+
def _project(
4709+
self, val: torch.Tensor | TensorDictBase
4710+
) -> torch.Tensor | TensorDictBase:
4711+
cls = TensorDict
4712+
return cls.from_dict(
4713+
{k: item._project(val[k]) for k, item in self.items()},
4714+
batch_size=self.shape,
4715+
device=self.device,
4716+
)
4717+
4718+
def index(
4719+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
4720+
) -> torch.Tensor | TensorDictBase:
4721+
raise NotImplementedError("`index` is not implemented for Composite specs.")
4722+
45884723
def is_empty(self, recurse: bool = False):
45894724
"""Whether the composite spec contains specs or not.
45904725
@@ -5508,6 +5643,31 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
55085643
55095644
"""
55105645

5646+
def _reshape(
5647+
self,
5648+
*args,
5649+
**kwargs,
5650+
) -> Any:
5651+
raise NotImplementedError(
5652+
f"`reshape` is not implemented for {type(self).__name__} specs."
5653+
)
5654+
5655+
def cardinality(
5656+
self,
5657+
*args,
5658+
**kwargs,
5659+
) -> Any:
5660+
raise NotImplementedError(
5661+
f"`cardinality` is not implemented for {type(self).__name__} specs."
5662+
)
5663+
5664+
def index(
5665+
self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
5666+
) -> torch.Tensor | TensorDictBase:
5667+
raise NotImplementedError(
5668+
f"`index` is not implemented for {type(self).__name__} specs."
5669+
)
5670+
55115671
def update(self, dict) -> None:
55125672
for key, item in dict.items():
55135673
if key in self.keys() and isinstance(

0 commit comments

Comments
 (0)