@@ -540,7 +540,7 @@ def __repr__(self):


 @dataclass(repr=False)
-class TensorSpec:
+class TensorSpec(metaclass=abc.ABCMeta):
     """Parent class of the tensor meta-data containers.

     TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
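Making `TensorSpec` abstract means it can no longer be instantiated directly; only concrete subclasses such as `Bounded` or `Categorical` can. A minimal sketch of the resulting behavior (not part of the diff; it assumes the abstract `__eq__` added in the next hunk is what blocks instantiation):

>>> from torchrl.data import TensorSpec
>>> TensorSpec()  # abstract base class: raises TypeError at instantiation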
@@ -675,6 +675,11 @@ def encode(
         self.assert_is_in(val)
         return val

+    @abc.abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        # Implement minimal version if super() is called
+        return type(self) is type(other)
+
     def __ne__(self, other):
         return not (self == other)

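The abstract `__eq__` keeps a minimal body so subclasses can delegate the type check to the parent, and `__ne__` is derived from it, so only `__eq__` needs overriding. A hypothetical subclass sketch (`MySpec` is made up for illustration; a real subclass must also implement the rest of the abstract interface):

>>> from torchrl.data import TensorSpec
>>> class MySpec(TensorSpec):
...     def __eq__(self, other):
...         # reuse the parent's minimal type comparison, then refine
...         return super().__eq__(other) and self.shape == other.shape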
@@ -734,13 +739,31 @@ def index(
     ) -> torch.Tensor | TensorDictBase:
         """Indexes the input tensor.

+        This method is to be used with specs that encode one or more categorical variables (e.g.,
+        :class:`~torchrl.data.OneHot` or :class:`~torchrl.data.Categorical`), such that indexing of a tensor
+        with a sample can be done without caring about the actual representation of the index.
+
         Args:
             index (int, torch.Tensor, slice or list): index of the tensor
             tensor_to_index: tensor to be indexed

         Returns:
             indexed tensor

+        Examples:
+            >>> from torchrl.data import OneHot
+            >>> import torch
+            >>>
+            >>> one_hot = OneHot(n=100)
+            >>> categ = one_hot.to_categorical_spec()
+            >>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
+            >>> idx_one_hot[50] = 1
+            >>> print(one_hot.index(idx_one_hot, torch.arange(100)))
+            tensor(50)
+            >>> idx_categ = one_hot.to_categorical(idx_one_hot)
+            >>> print(categ.index(idx_categ, torch.arange(100)))
+            tensor(50)
+
         """
         ...

@@ -1306,6 +1329,31 @@ class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):

     """

+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def __eq__(self, other):
         if not isinstance(other, Stacked):
             return False
@@ -1829,7 +1877,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
                 f"Only tensors are allowed for indexing using "
                 f"{self.__class__.__name__}.index(...)"
             )
-        index = index.nonzero().squeeze()
+        index = index.nonzero(as_tuple=True)[-1]
         index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
         return tensor_to_index.gather(-1, index)

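This one-line fix matters because `nonzero().squeeze()` collapses a single-hot index to a 0-D tensor, which has no `shape[-1]` and breaks the `expand`/`gather` that follow; `nonzero(as_tuple=True)[-1]` always yields a 1-D index. The difference is easy to check with plain torch:

>>> import torch
>>> idx = torch.zeros(4, dtype=torch.bool)
>>> idx[2] = True
>>> idx.nonzero().squeeze()           # 0-D tensor, shape torch.Size([])
tensor(2)
>>> idx.nonzero(as_tuple=True)[-1]    # 1-D tensor, shape torch.Size([1])
tensor([2])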
@@ -2148,6 +2196,11 @@ def __init__(
             domain=domain,
         )

+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Indexing not implemented for Bounded.")
+
     def enumerate(self, use_mask: bool = False) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
@@ -2484,11 +2537,19 @@ def __eq__(self, other):
         eq = eq & (self.example_data == getattr(other, "example_data", None))
         return eq

+    def _project(self) -> Any:
+        raise NotImplementedError("Cannot project a NonTensorSpec.")
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Cannot use index with a NonTensorSpec.")
+
     def cardinality(self) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")

     def enumerate(self, use_mask: bool = False) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")

     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
@@ -2752,6 +2813,16 @@ def __init__(
             shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
         )

+    def cardinality(self) -> int:
+        raise NotImplementedError(
+            "`cardinality` is not implemented for Unbounded specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Unbounded specs.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -3527,6 +3598,14 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
         out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
         return out

+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        idx = index.expand(
+            tensor_to_index.shape[: -self.ndim] + torch.Size([-1] * self.ndim)
+        )
+        return tensor_to_index.gather(-1, idx)
+
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         if val.dtype not in (torch.int, torch.long):
             val = torch.round(val)
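The new `Categorical.index` broadcasts the integer index over the leading batch dimensions of the tensor and gathers along the last dimension. The underlying pattern can be checked with plain torch (an illustration of the mechanics, not a call into the spec API):

>>> import torch
>>> values = torch.arange(12).reshape(3, 4)   # 3 batch entries, last dim of size 4
>>> idx = torch.tensor([1]).expand(3, 1)      # broadcast the index over the batch
>>> values.gather(-1, idx)
tensor([[1],
        [5],
        [9]])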
@@ -3863,9 +3942,50 @@ def cardinality(self) -> int:
             .item()
         )

+    def enumerate(self, use_mask: bool = False) -> List[Any]:
+        return [s for choice in self._choices for s in choice.enumerate()]
+
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "_project is not implemented for Choice. If this feature is required, please raise "
+            "an issue on TorchRL github repo."
+        )
+
+    def _reshape(self, shape: torch.Size) -> T:
+        return self.__class__(
+            [choice.reshape(shape) for choice in self._choices],
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "index is not implemented for Choice. If this feature is required, please raise "
+            "an issue on TorchRL github repo."
+        )
+
+    @property
+    def num_choices(self):
+        """Number of choices for the spec."""
+        return len(self._choices)
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
         return self.__class__([choice.to(dest) for choice in self._choices])

+    def __eq__(self, other):
+        if not isinstance(other, Choice):
+            return False
+        if self.num_choices != other.num_choices:
+            return False
+        return all(
+            (s0 == s1).all()
+            if isinstance(s0, torch.Tensor) or is_tensor_collection(s0)
+            else s0 == s1
+            for s0, s1 in zip(self._choices, other._choices)
+        )
+

 @dataclass(repr=False)
 class Binary(Categorical):
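`Choice.__eq__` compares the stored choices pairwise and reduces tensor or tensordict comparisons with `.all()`, since `==` on those is elementwise. The reduction it relies on, shown on plain tensors:

>>> import torch
>>> s0, s1 = torch.tensor([1, 2]), torch.tensor([1, 2])
>>> s0 == s1            # elementwise result, not a single bool
tensor([True, True])
>>> (s0 == s1).all()    # the reduction used inside Choice.__eq__
tensor(True)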
@@ -4643,6 +4763,24 @@ def shape(self, value: torch.Size):
         )
         self._shape = _size(value)

+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        if self.data_cls is None:
+            cls = TensorDict
+        else:
+            cls = self.data_cls
+        return cls.from_dict(
+            {k: item._project(val[k]) for k, item in self.items()},
+            batch_size=self.shape,
+            device=self.device,
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Composite specs.")
+
     def is_empty(self, recurse: bool = False):
         """Whether the composite spec contains specs or not.

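`Composite._project` projects each leaf value through its sub-spec and rebuilds the container with the spec's own batch size and device, using `data_cls` when the composite is backed by a tensorclass. A toy sketch of the rebuild step (here `clamp` merely stands in for a sub-spec's `_project`):

>>> import torch
>>> from tensordict import TensorDict
>>> projected = TensorDict.from_dict(
...     {"action": torch.tensor([2.5]).clamp(0.0, 1.0)},
...     batch_size=torch.Size([]),
... )
>>> projected["action"]
tensor([1.])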
@@ -5569,6 +5707,31 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):

     """

+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def update(self, dict) -> None:
         for key, item in dict.items():
             if key in self.keys() and isinstance(