@@ -869,12 +869,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         return self.is_in(item)
 
     @abc.abstractmethod
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         """Returns all the samples that can be obtained from the TensorSpec.
 
         The samples will be stacked along the first dimension.
 
         This method is only implemented for discrete specs.
+
+        Args:
+            use_mask (bool, optional): if ``True`` and the spec has a mask,
+                entries disallowed by the mask are excluded. Defaults to ``False``.
         """
         ...
 
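
A minimal usage sketch of the new flag, assuming the `Categorical` leaf spec and its `update_mask` API from `torchrl.data` (neither is shown in this diff):

    import torch
    from torchrl.data import Categorical

    spec = Categorical(4)
    spec.update_mask(torch.tensor([True, False, True, True]))

    spec.enumerate()               # tensor([0, 1, 2, 3]): mask ignored (default)
    spec.enumerate(use_mask=True)  # tensor([0, 2, 3]): masked-out value dropped
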
@@ -1315,9 +1319,9 @@ def __eq__(self, other):
                 return False
         return True
 
-    def enumerate(self) -> torch.Tensor | TensorDictBase:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor | TensorDictBase:
         return torch.stack(
-            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+            [spec.enumerate(use_mask) for spec in self._specs], dim=self.stack_dim + 1
         )
 
     def __len__(self):
@@ -1810,7 +1814,9 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
+        if use_mask:
+            raise NotImplementedError
         return (
             torch.eye(self.n, dtype=self.dtype, device=self.device)
             .expand(*self.shape, self.n)
@@ -2142,7 +2148,7 @@ def __init__(
             domain=domain,
         )
 
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
         )
@@ -2481,7 +2487,7 @@ def __eq__(self, other):
     def cardinality(self) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
 
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
 
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
@@ -2779,7 +2785,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
             val.shape[: -self.ndim] + self.shape
         )
 
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         raise NotImplementedError("enumerate cannot be called with continuous specs.")
 
     def expand(self, *shape):
@@ -2951,9 +2957,9 @@ def __init__(
     def cardinality(self) -> int:
         return torch.as_tensor(self.nvec).prod()
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
         nvec = self.nvec
-        enum_disc = self.to_categorical_spec().enumerate()
+        enum_disc = self.to_categorical_spec().enumerate(use_mask)
         enums = torch.cat(
             [
                 torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
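
The re-encoding above (each column of the categorical enumeration is turned back into a one-hot block, and the blocks are concatenated) can be illustrated with plain tensors; the values below are hypothetical:

    import torch

    nvec = [2, 3]                               # sizes of the two sub-dimensions
    enum_disc = torch.tensor([[0, 1], [1, 2]])  # two categorical samples
    enums = torch.cat(
        [
            torch.nn.functional.one_hot(col, nv).to(torch.long)
            for col, nv in zip(enum_disc.unbind(-1), nvec)
        ],
        -1,
    )
    # enums == tensor([[1, 0, 0, 1, 0],
    #                  [0, 1, 0, 0, 1]])  # one-hot blocks of width 2 and 3
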
@@ -3417,14 +3423,18 @@ def __init__(
     def _undefined_n(self):
         return self.space.n < 0
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
         dtype = self.dtype
         if dtype is torch.bool:
             dtype = torch.uint8
-        arange = torch.arange(self.n, dtype=dtype, device=self.device)
+        n = self.n
+        arange = torch.arange(n, dtype=dtype, device=self.device)
+        if use_mask and self.mask is not None:
+            arange = arange[self.mask]
+            n = arange.shape[0]
         if self.ndim:
             arange = arange.view(-1, *(1,) * self.ndim)
-        return arange.expand(self.n, *self.shape)
+        return arange.expand(n, *self.shape)
 
     @property
     def n(self):
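
Standalone, the masked branch reduces to the following plain-tensor sketch (hypothetical values; note that `arange[self.mask]` requires one boolean per category):

    import torch

    n, shape = 5, (3,)                            # 5 categories, spec shape (3,)
    mask = torch.tensor([True, False, True, False, True])

    arange = torch.arange(n)
    arange = arange[mask]                         # valid categories: [0, 2, 4]
    n = arange.shape[0]                           # 3 samples remain
    arange = arange.view(-1, *(1,) * len(shape))  # (3, 1), broadcastable
    print(arange.expand(n, *shape))               # (3, 3): one row per category
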
@@ -4088,7 +4098,9 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
+        if use_mask:
+            raise NotImplementedError()
         if self.mask is not None:
             raise RuntimeError(
                 "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
@@ -5136,13 +5148,15 @@ def cardinality(self) -> int:
             n = 0
         return n
 
-    def enumerate(self) -> TensorDictBase:
+    def enumerate(self, use_mask: bool = False) -> TensorDictBase:
         # We are going to use meshgrid to create samples of all the subspecs in here
         # but first let's get rid of the batch size, we'll put it back later
         self_without_batch = self
         while self_without_batch.ndim:
             self_without_batch = self_without_batch[0]
-        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        samples = {
+            key: spec.enumerate(use_mask) for key, spec in self_without_batch.items()
+        }
         if self.data_cls is not None:
             cls = self.data_cls
         else:
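
The meshgrid step the comment refers to (a cross product of the per-leaf enumerations) looks roughly like this with plain tensors; the two leaves are hypothetical:

    import torch

    a = torch.arange(2)    # enumeration of a first leaf
    b = torch.arange(3)    # enumeration of a second leaf

    # Every value of one leaf is paired with every value of the other,
    # giving 2 * 3 = 6 joint samples for the composite spec.
    ga, gb = torch.meshgrid(a, b, indexing="ij")
    print(ga.reshape(-1))  # tensor([0, 0, 0, 1, 1, 1])
    print(gb.reshape(-1))  # tensor([0, 1, 2, 0, 1, 2])
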
@@ -5566,10 +5580,10 @@ def update(self, dict) -> None:
                 self[key] = item
         return self
 
-    def enumerate(self) -> TensorDictBase:
+    def enumerate(self, use_mask: bool = False) -> TensorDictBase:
         dim = self.stack_dim
         return LazyStackedTensorDict.maybe_dense_stack(
-            [spec.enumerate() for spec in self._specs], dim + 1
+            [spec.enumerate(use_mask) for spec in self._specs], dim + 1
         )
 
     def __eq__(self, other):