@@ -834,6 +834,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
        """
        return self.is_in(item)

+    @abc.abstractmethod
+    def enumerate(self) -> Any:
+        """Returns all the samples that can be obtained from the TensorSpec.
+
+        The samples will be stacked along the first dimension.
+
+        This method is only implemented for discrete specs.
+        """
+        ...
+
    def project(
        self, val: torch.Tensor | TensorDictBase
    ) -> torch.Tensor | TensorDictBase:
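For reviewers, a hedged usage sketch of the new contract (the `OneHot` and `Categorical` class names are assumptions, not shown in this hunk): `enumerate()` returns every admissible sample, stacked along the first dimension.

```python
# Hypothetical usage; only the stacking contract comes from the hunk above.
from torchrl.data import Categorical, OneHot

print(Categorical(4).enumerate())  # tensor([0, 1, 2, 3]): one entry per category
print(OneHot(3).enumerate())       # shape (3, 3): one one-hot vector per category
```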
@@ -1271,6 +1281,11 @@ def __eq__(self, other):
                return False
        return True

+    def enumerate(self) -> torch.Tensor | TensorDictBase:
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
    def __len__(self):
        return self.shape[0]
@@ -1732,6 +1747,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
            return np.array(vals).reshape(tuple(val.shape))
        return val

+    def enumerate(self) -> torch.Tensor:
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
    def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
        if not isinstance(index, torch.Tensor):
            raise ValueError(
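A plain-torch sketch of what the `eye`/`expand`/`permute` chain above produces, for a hypothetical one-hot spec with `n=4` and shape `(2, 4)` (both values are made up for illustration):

```python
import torch

n, shape = 4, (2, 4)  # assumed spec parameters; shape[-1] == n
enum = (
    torch.eye(n)
    .expand(*shape, n)                        # (2, 4, 4): identity broadcast over the batch
    .permute(-2, *range(len(shape) - 1), -1)  # (4, 2, 4): enumeration dim moved first
)
print(enum.shape)  # torch.Size([4, 2, 4]); enum[i] repeats the i-th one-hot over the batch
```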
@@ -2056,6 +2078,11 @@ def __init__(
            domain=domain,
        )

+    def enumerate(self) -> Any:
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
    def __eq__(self, other):
        return (
            type(other) == type(self)
@@ -2375,6 +2402,9 @@ def __init__(
            shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
        )

+    def enumerate(self) -> Any:
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
@@ -2611,6 +2641,9 @@ def is_in(self, val: torch.Tensor) -> bool:
    def _project(self, val: torch.Tensor) -> torch.Tensor:
        return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

+    def enumerate(self) -> Any:
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
@@ -2775,6 +2808,18 @@ def __init__(
        )
        self.update_mask(mask)

+    def enumerate(self) -> torch.Tensor:
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
    def update_mask(self, mask):
        """Sets a mask to prevent some of the possible outcomes when a sample is taken.
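A rough plain-torch equivalent of the concatenation above for a hypothetical `nvec=[2, 3]`: the six categorical combinations are converted group by group with `one_hot` and concatenated into width-5 vectors.

```python
import torch

nvec = [2, 3]  # assumed composition of the one-hot groups
combos = torch.cartesian_prod(*(torch.arange(n) for n in nvec))  # (6, 2), as the categorical enumeration would give
enums = torch.cat(
    [torch.nn.functional.one_hot(col, n) for n, col in zip(nvec, combos.unbind(-1))],
    -1,
)
print(enums.shape)  # torch.Size([6, 5])
```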
@@ -3208,6 +3253,12 @@ def __init__(
        )
        self.update_mask(mask)

+    def enumerate(self) -> torch.Tensor:
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
    @property
    def n(self):
        return self.space.n
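A quick plain-torch check of the `arange`/`view`/`expand` logic above, with an assumed `n=3` and spec shape `(2,)`: each enumerated sample repeats one category across the spec's shape.

```python
import torch

n, shape = 3, (2,)  # assumed spec parameters
arange = torch.arange(n).view(-1, *(1,) * len(shape))  # (3, 1)
print(arange.expand(n, *shape))
# tensor([[0, 0],
#         [1, 1],
#         [2, 2]])
```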
@@ -3715,6 +3766,29 @@ def __init__(
        self.update_mask(mask)
        self.remove_singleton = remove_singleton

+    def enumerate(self) -> torch.Tensor:
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
    def update_mask(self, mask):
        """Sets a mask to prevent some of the possible outcomes when a sample is taken.
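A plain-torch sketch of the unmasked path above for a hypothetical homogeneous `nvec=[2, 3]`: `meshgrid` plus `stack` yields all six index combinations before the final `view`/`expand` restores the spec's batch dims.

```python
import torch

nvec = [2, 3]  # assumed, homogeneous nvec
grids = torch.meshgrid(*(torch.arange(n) for n in nvec), indexing="ij")
combos = torch.stack([g.reshape(-1) for g in grids], dim=-1)  # (6, 2)
print(combos.tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
```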
@@ -3932,6 +4006,8 @@ def to_one_hot(

    def to_one_hot_spec(self) -> MultiOneHot:
        """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
        nvec = [_space.n for _space in self.space]
        return MultiOneHot(
            nvec,
@@ -4606,6 +4682,33 @@ def clone(self) -> Composite:
            shape=self.shape,
        )

+    def enumerate(self) -> TensorDictBase:
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
    def empty(self):
        """Create a spec like self, but with no entries."""
        try:
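A hedged usage sketch of the composite enumeration above (the key names and the `Categorical` sub-spec are illustrative assumptions): the result is the cartesian product of the sub-specs' enumerations, packed in a TensorDict.

```python
# Assumes torchrl.data exposes Composite and Categorical with this patch applied.
from torchrl.data import Categorical, Composite

spec = Composite(action=Categorical(2), mode=Categorical(3))
samples = spec.enumerate()          # TensorDict with batch_size [6]
print(samples["action"].tolist())   # e.g. [0, 0, 0, 1, 1, 1]
print(samples["mode"].tolist())     # e.g. [0, 1, 2, 0, 1, 2]
```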
@@ -4856,6 +4959,12 @@ def update(self, dict) -> None:
                self[key] = item
        return self

+    def enumerate(self) -> TensorDictBase:
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
    def __eq__(self, other):
        if not isinstance(other, StackedComposite):
            return False
@@ -5150,7 +5259,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:


@TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
    if out is not None:
        raise NotImplementedError(
            "In-place spec modification is not a feature of torchrl, hence "
@@ -5187,7 +5296,7 @@ def _stack_specs(list_of_spec, dim, out=None):


@Composite.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
    if out is not None:
        raise NotImplementedError(
            "In-place spec modification is not a feature of torchrl, hence "