@@ -540,7 +540,7 @@ def __repr__(self):
 
 
 @dataclass(repr=False)
-class TensorSpec:
+class TensorSpec(metaclass=abc.ABCMeta):
     """Parent class of the tensor meta-data containers.
 
     TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
@@ -675,6 +675,11 @@ def encode(
         self.assert_is_in(val)
         return val
 
+    @abc.abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        # Implement minimal version if super() is called
+        return type(self) is type(other)
+
     def __ne__(self, other):
         return not (self == other)
 
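Taken together, the two hunks above turn TensorSpec into an ABC with an abstract `__eq__`: subclasses must define equality themselves and may call `super().__eq__` for the baseline type check, while `__ne__` is derived from it. A standalone sketch of the pattern (illustrative classes, not TorchRL's actual ones):

    import abc
    from typing import Any

    class Spec(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        def __eq__(self, other: Any) -> bool:
            # minimal fallback for subclasses that call super().__eq__
            return type(self) is type(other)

        def __ne__(self, other):
            return not (self == other)

    class MySpec(Spec):
        def __init__(self, n: int):
            self.n = n

        def __eq__(self, other: Any) -> bool:
            return super().__eq__(other) and self.n == other.n

    # Spec() would raise TypeError: abstract __eq__ is unimplemented
    assert MySpec(3) == MySpec(3)
    assert MySpec(3) != MySpec(4)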
@@ -734,13 +739,31 @@ def index(
     ) -> torch.Tensor | TensorDictBase:
         """Indexes the input tensor.
 
+        This method is to be used with specs that encode one or more categorical variables (e.g.,
+        :class:`~torchrl.data.OneHot` or :class:`~torchrl.data.Categorical`), such that indexing of a tensor
+        with a sample can be done without caring about the actual representation of the index.
+
         Args:
             index (int, torch.Tensor, slice or list): index of the tensor
             tensor_to_index: tensor to be indexed
 
         Returns:
             indexed tensor
 
+        Examples:
+            >>> from torchrl.data import OneHot
+            >>> import torch
+            >>>
+            >>> one_hot = OneHot(n=100)
+            >>> categ = one_hot.to_categorical_spec()
+            >>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
+            >>> idx_one_hot[50] = 1
+            >>> print(one_hot.index(idx_one_hot, torch.arange(100)))
+            tensor(50)
+            >>> idx_categ = one_hot.to_categorical(idx_one_hot)
+            >>> print(categ.index(idx_categ, torch.arange(100)))
+            tensor(50)
+
         """
         ...
@@ -1302,6 +1325,31 @@ class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):
 
     """
 
+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def __eq__(self, other):
         if not isinstance(other, Stacked):
             return False
@@ -1823,7 +1871,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
                 f"Only tensors are allowed for indexing using "
                 f"{self.__class__.__name__}.index(...)"
             )
-        index = index.nonzero().squeeze()
+        index = index.nonzero(as_tuple=True)[-1]
         index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
         return tensor_to_index.gather(-1, index)
 
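This one-line fix is worth unpacking: `.nonzero().squeeze()` collapses a single-hit one-hot mask to a 0-d tensor, so the subsequent `index.shape[-1]` raises an IndexError, whereas `nonzero(as_tuple=True)[-1]` always returns a 1-D tensor of last-dim positions. A standalone illustration with plain torch calls (not TorchRL code):

    import torch

    mask = torch.zeros(100, dtype=torch.bool)
    mask[50] = True

    old = mask.nonzero().squeeze()         # 0-d tensor(50): squeeze drops every dim
    new = mask.nonzero(as_tuple=True)[-1]  # 1-D tensor([50]): last-dim coordinates

    print(old.shape)  # torch.Size([]) -- old.shape[-1] would raise IndexError
    print(new.shape)  # torch.Size([1]) -- safe to expand and gather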
@@ -2142,6 +2190,11 @@ def __init__(
             domain=domain,
         )
 
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Indexing not implemented for Bounded.")
+
     def enumerate(self) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
@@ -2478,11 +2531,19 @@ def __eq__(self, other):
         eq = eq & (self.example_data == getattr(other, "example_data", None))
         return eq
 
+    def _project(self) -> Any:
+        raise NotImplementedError("Cannot project a NonTensorSpec.")
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Cannot use index with a NonTensorSpec.")
+
     def cardinality(self) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
 
     def enumerate(self) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
 
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
@@ -2744,6 +2805,16 @@ def __init__(
             shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def cardinality(self) -> int:
+        raise NotImplementedError(
+            "`cardinality` is not implemented for Unbounded specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Unbounded specs.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -3515,6 +3586,14 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
         out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
         return out
 
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        idx = index.expand(
+            tensor_to_index.shape[: -self.ndim] + torch.Size([-1] * self.ndim)
+        )
+        return tensor_to_index.gather(-1, idx)
+
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         if val.dtype not in (torch.int, torch.long):
             val = torch.round(val)
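This new `index` is a broadcast-then-gather: the sample is expanded over the leading batch dimensions of `tensor_to_index` (keeping the spec's own dims as `-1`), then gathered along the last dimension. A standalone sketch of the pattern with hypothetical shapes (plain torch, assuming a spec with `ndim == 1`):

    import torch

    # a batch of 4 rows with 100 entries each
    tensor_to_index = torch.arange(400).reshape(4, 100)
    # a 1-D categorical index selecting position 50
    index = torch.tensor([50])

    # expand over the batch dim, keep the spec dim, gather along the last dim
    idx = index.expand(tensor_to_index.shape[:-1] + torch.Size([-1]))
    print(tensor_to_index.gather(-1, idx))
    # tensor([[ 50], [150], [250], [350]])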
@@ -3851,9 +3930,50 @@ def cardinality(self) -> int:
             .item()
         )
 
+    def enumerate(self, use_mask: bool = False) -> List[Any]:
+        return [s for choice in self._choices for s in choice.enumerate()]
+
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "_project is not implemented for Choice. If this feature is required, please raise "
+            "an issue on TorchRL github repo."
+        )
+
+    def _reshape(self, shape: torch.Size) -> T:
+        return self.__class__(
+            [choice.reshape(shape) for choice in self._choices],
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "index is not implemented for Choice. If this feature is required, please raise "
+            "an issue on TorchRL github repo."
+        )
+
+    @property
+    def num_choices(self):
+        """Number of choices for the spec."""
+        return len(self._choices)
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
         return self.__class__([choice.to(dest) for choice in self._choices])
 
+    def __eq__(self, other):
+        if not isinstance(other, Choice):
+            return False
+        if self.num_choices != other.num_choices:
+            return False
+        return all(
+            (s0 == s1).all()
+            if isinstance(s0, torch.Tensor) or is_tensor_collection(s0)
+            else s0 == s1
+            for s0, s1 in zip(self._choices, other._choices)
+        )
+
 
 @dataclass(repr=False)
 class Binary(Categorical):
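A usage sketch for the new `Choice` helpers, assuming `Choice` is built from a list of specs (as the constructor calls in this hunk suggest) and that `Categorical(n).enumerate()` yields `n` entries:

    from torchrl.data import Categorical, Choice

    spec = Choice([Categorical(3), Categorical(5)])
    print(spec.num_choices)                                  # 2
    print(spec == Choice([Categorical(3), Categorical(5)]))  # True
    print(spec == Choice([Categorical(3), Categorical(4)]))  # False
    # enumerate() chains each choice's enumeration into one flat list
    print(len(spec.enumerate()))                             # 3 + 5 = 8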
@@ -4585,6 +4705,21 @@ def shape(self, value: torch.Size):
             )
         self._shape = _size(value)
 
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        cls = TensorDict
+        return cls.from_dict(
+            {k: item._project(val[k]) for k, item in self.items()},
+            batch_size=self.shape,
+            device=self.device,
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Composite specs.")
+
     def is_empty(self, recurse: bool = False):
        """Whether the composite spec contains specs or not.
 
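With `_project` defined, the public `project` entry point can now repair an out-of-spec TensorDict leaf by leaf, repacking the result with the Composite's batch size and device. A hypothetical usage sketch (assuming `Bounded._project` clamps into range, which is its usual behavior):

    import torch
    from tensordict import TensorDict
    from torchrl.data import Bounded, Composite

    spec = Composite(act=Bounded(low=-1.0, high=1.0, shape=(3,)))
    val = TensorDict({"act": torch.tensor([2.0, -3.0, 0.5])}, batch_size=[])

    # project() dispatches to the _project above, key by key
    print(spec.project(val)["act"])  # clamped to [1.0, -1.0, 0.5]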
@@ -5508,6 +5643,31 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
 
     """
 
+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def update(self, dict) -> None:
         for key, item in dict.items():
             if key in self.keys() and isinstance(