
Commit bd78913

Author: Vincent Moens
[BE] Ensure abstractmethods are implemented for specs
ghstack-source-id: 7b943aa
Pull Request resolved: #2790
1 parent 67c3e9a commit bd78913

File tree: 1 file changed (+167, -4 lines)

torchrl/data/tensor_specs.py

@@ -540,7 +540,7 @@ def __repr__(self):
 
 
 @dataclass(repr=False)
-class TensorSpec:
+class TensorSpec(metaclass=abc.ABCMeta):
     """Parent class of the tensor meta-data containers.
 
     TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
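
Making `TensorSpec` an ABC means a subclass that forgets to implement an abstract method now fails at instantiation time rather than at call time. A minimal sketch of the enforced behavior, using a hypothetical stand-in class rather than the real spec:

    import abc

    class SpecBase(abc.ABC):  # hypothetical stand-in, not the real TensorSpec
        @abc.abstractmethod
        def __eq__(self, other) -> bool:
            return type(self) is type(other)

    class BrokenSpec(SpecBase):  # hypothetical subclass that forgets __eq__
        pass

    try:
        BrokenSpec()
    except TypeError as err:
        print(err)  # Can't instantiate abstract class BrokenSpec ...
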
@@ -675,6 +675,11 @@ def encode(
         self.assert_is_in(val)
         return val
 
+    @abc.abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        # Implement minimal version if super() is called
+        return type(self) is type(other)
+
     def __ne__(self, other):
         return not (self == other)
 
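The abstract `__eq__` ships a minimal body on purpose, so concrete specs can delegate the type check to `super().__eq__` before comparing their own attributes. A hedged sketch of that pattern, again with hypothetical stand-in classes:

    import abc

    class SpecBase(abc.ABC):  # hypothetical stand-in for TensorSpec
        @abc.abstractmethod
        def __eq__(self, other) -> bool:
            return type(self) is type(other)

    class MySpec(SpecBase):  # hypothetical concrete spec
        def __init__(self, n: int):
            self.n = n

        def __eq__(self, other) -> bool:
            # Type check via the abstract method's minimal body, then attributes.
            return super().__eq__(other) and self.n == other.n

    assert MySpec(3) == MySpec(3)
    assert MySpec(3) != MySpec(4)  # __ne__ falls back to negating __eq__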

@@ -734,13 +739,31 @@ def index(
     ) -> torch.Tensor | TensorDictBase:
         """Indexes the input tensor.
 
+        This method is to be used with specs that encode one or more categorical variables (e.g.,
+        :class:`~torchrl.data.OneHot` or :class:`~torchrl.data.Categorical`), such that indexing of a tensor
+        with a sample can be done without caring about the actual representation of the index.
+
         Args:
             index (int, torch.Tensor, slice or list): index of the tensor
             tensor_to_index: tensor to be indexed
 
         Returns:
             indexed tensor
 
+        Examples:
+            >>> from torchrl.data import OneHot
+            >>> import torch
+            >>>
+            >>> one_hot = OneHot(n=100)
+            >>> categ = one_hot.to_categorical_spec()
+            >>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
+            >>> idx_one_hot[50] = 1
+            >>> print(one_hot.index(idx_one_hot, torch.arange(100)))
+            tensor(50)
+            >>> idx_categ = one_hot.to_categorical(idx_one_hot)
+            >>> print(categ.index(idx_categ, torch.arange(100)))
+            tensor(50)
+
         """
         ...
 

@@ -1306,6 +1329,31 @@ class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):
 
     """
 
+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def __eq__(self, other):
         if not isinstance(other, Stacked):
             return False
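
For lazily stacked, possibly heterogeneous specs, `reshape`, `cardinality`, and `index` have no well-defined semantics, so they now fail loudly with `NotImplementedError`; the same pattern is applied to Bounded, Unbounded, NonTensor, Composite, and StackedComposite further down this diff. A hedged sketch of what a caller sees, assuming `torch.stack` over specs of different shapes returns a `Stacked` spec as in current torchrl:

    import torch
    from torchrl.data import Unbounded

    # Heterogeneous shapes force a lazy Stacked spec.
    stacked = torch.stack([Unbounded(shape=(2,)), Unbounded(shape=(3,))], dim=0)

    try:
        stacked.cardinality()
    except NotImplementedError as err:
        print(err)  # `cardinality` is not implemented for Stacked specs.
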
@@ -1829,7 +1877,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
                 f"Only tensors are allowed for indexing using "
                 f"{self.__class__.__name__}.index(...)"
             )
-        index = index.nonzero().squeeze()
+        index = index.nonzero(as_tuple=True)[-1]
         index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
         return tensor_to_index.gather(-1, index)
 
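This one-line change fixes indexing with batched one-hot samples: `nonzero().squeeze()` returns (row, col) coordinates across all dims, which `squeeze()` cannot untangle into per-sample positions, while `nonzero(as_tuple=True)[-1]` keeps only the positions along the last (category) dim. A plain-torch illustration, independent of the patch:

    import torch

    # Two one-hot samples over 4 categories.
    idx = torch.tensor([[0, 0, 1, 0],
                        [0, 1, 0, 0]], dtype=torch.bool)

    print(idx.nonzero())
    # tensor([[0, 2],
    #         [1, 1]])   coordinate pairs; ambiguous after squeeze()

    print(idx.nonzero(as_tuple=True)[-1])
    # tensor([2, 1])     one category position per sample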

@@ -2148,6 +2196,11 @@ def __init__(
             domain=domain,
         )
 
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Indexing not implemented for Bounded.")
+
     def enumerate(self, use_mask: bool = False) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
@@ -2484,11 +2537,19 @@ def __eq__(self, other):
         eq = eq & (self.example_data == getattr(other, "example_data", None))
         return eq
 
+    def _project(self) -> Any:
+        raise NotImplementedError("Cannot project a NonTensorSpec.")
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Cannot use index with a NonTensorSpec.")
+
     def cardinality(self) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
 
     def enumerate(self, use_mask: bool = False) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
 
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
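
Switching these stubs from `RuntimeError` to `NotImplementedError` is more than cosmetic: `NotImplementedError` is the conventional signal for a deliberately unsupported operation, so callers can catch it uniformly across spec types. A hedged sketch, assuming `NonTensor` can be built with its default constructor:

    from torchrl.data import NonTensor

    spec = NonTensor()
    try:
        spec.enumerate()
    except NotImplementedError:
        # Deliberately unsupported on this spec type; handle gracefully.
        print("NonTensor specs cannot be enumerated")
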
@@ -2752,6 +2813,16 @@ def __init__(
             shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def cardinality(self) -> int:
+        raise NotImplementedError(
+            "`cardinality` is not implemented for Unbounded specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Unbounded specs.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -3527,6 +3598,14 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
             out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
             return out
 
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        idx = index.expand(
+            tensor_to_index.shape[: -self.ndim] + torch.Size([-1] * self.ndim)
+        )
+        return tensor_to_index.gather(-1, idx)
+
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         if val.dtype not in (torch.int, torch.long):
             val = torch.round(val)
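
The new implementation broadcasts the categorical sample over the leading batch dims of the tensor to index, then gathers along the last dim. The mechanics in plain torch (shapes chosen for illustration; the real method derives the expand shape from the spec's `ndim`):

    import torch

    batch_values = torch.arange(12).reshape(3, 4)  # 3 batches of 4 categories
    sample = torch.tensor([2])                     # category index, shape (1,)

    idx = sample.expand(batch_values.shape[:-1] + (-1,))  # shape (3, 1)
    print(batch_values.gather(-1, idx).squeeze(-1))       # tensor([ 2,  6, 10])
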
@@ -3863,9 +3942,50 @@ def cardinality(self) -> int:
             .item()
         )
 
+    def enumerate(self, use_mask: bool = False) -> List[Any]:
+        return [s for choice in self._choices for s in choice.enumerate()]
+
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "_project is not implemented for Choice. If this feature is required, please raise "
+            "an issue on TorchRL github repo."
+        )
+
+    def _reshape(self, shape: torch.Size) -> T:
+        return self.__class__(
+            [choice.reshape(shape) for choice in self._choices],
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "index is not implemented for Choice. If this feature is required, please raise "
+            "an issue on TorchRL github repo."
+        )
+
+    @property
+    def num_choices(self):
+        """Number of choices for the spec."""
+        return len(self._choices)
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
         return self.__class__([choice.to(dest) for choice in self._choices])
 
+    def __eq__(self, other):
+        if not isinstance(other, Choice):
+            return False
+        if self.num_choices != other.num_choices:
+            return False
+        return all(
+            (s0 == s1).all()
+            if isinstance(s0, torch.Tensor) or is_tensor_collection(s0)
+            else s0 == s1
+            for s0, s1 in zip(self._choices, other._choices)
+        )
+
 
 @dataclass(repr=False)
 class Binary(Categorical):
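
With the hunk above, Choice also gains structural equality and a `num_choices` property. A hedged usage sketch, assuming `Choice` is constructed from a list of candidate specs as `self._choices` suggests:

    from torchrl.data import Categorical, Choice

    a = Choice([Categorical(10), Categorical(20)])
    b = Choice([Categorical(10), Categorical(20)])
    c = Choice([Categorical(10)])

    print(a.num_choices)  # 2
    print(a == b)         # True: same count, element-wise equal choices
    print(a == c)         # False: num_choices differ
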
@@ -4643,6 +4763,24 @@ def shape(self, value: torch.Size):
             )
         self._shape = _size(value)
 
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        if self.data_cls is None:
+            cls = TensorDict
+        else:
+            cls = self.data_cls
+        return cls.from_dict(
+            {k: item._project(val[k]) for k, item in self.items()},
+            batch_size=self.shape,
+            device=self.device,
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Composite specs.")
+
     def is_empty(self, recurse: bool = False):
         """Whether the composite spec contains specs or not.
 
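`Composite._project` now recurses key by key and rebuilds the result with the spec's `data_cls`, falling back to `TensorDict`. A hedged sketch through the public `project` entry point (assuming, as elsewhere in this file, that `project` dispatches to `_project` when a value is out of bounds):

    import torch
    from tensordict import TensorDict
    from torchrl.data import Bounded, Composite

    spec = Composite(obs=Bounded(low=0.0, high=1.0, shape=(3,)))
    val = TensorDict({"obs": torch.tensor([-1.0, 0.5, 2.0])}, batch_size=[])

    projected = spec.project(val)  # clamps "obs" into [0, 1] entry-wise
    print(projected["obs"])        # tensor([0.0000, 0.5000, 1.0000])
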
@@ -5569,6 +5707,31 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
 
     """
 
+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def update(self, dict) -> None:
         for key, item in dict.items():
             if key in self.keys() and isinstance(
