@@ -618,9 +618,28 @@ def __eq__(self, other):
618
618
and self .use_register == other .use_register
619
619
)
620
620
621
- def to_categorical (self ) -> DiscreteTensorSpec :
621
def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
    """Turn a one-hot encoded tensor into its categorical (index) form.

    Args:
        val (torch.Tensor): one-hot tensor to convert.
        safe (bool): when ``True``, ``val`` is first validated with
            ``self.assert_is_in``.

    Returns:
        The tensor of indices obtained by taking the argmax of ``val``
        over its last dimension.
    """
    if safe:
        self.assert_is_in(val)
    # The position of the maximum along the last axis is the category index.
    categories = val.argmax(dim=-1)
    return categories
635
+
636
def to_categorical_spec(self) -> DiscreteTensorSpec:
    """Build the categorical spec equivalent to this one-hot spec.

    The number of categories, device and dtype are carried over; the last
    (one-hot) dimension is removed from the shape.
    """
    num_categories = self.space.n
    spec_kwargs = {
        "device": self.device,
        "dtype": self.dtype,
        # Drop the trailing dimension of the shape.
        "shape": self.shape[:-1],
    }
    return DiscreteTensorSpec(num_categories, **spec_kwargs)
625
644
626
645
@@ -1184,13 +1203,6 @@ def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]:
1184
1203
return None
1185
1204
return val .split (split_sizes , dim = - 1 )
1186
1205
1187
- def to_numpy (self , val : torch .Tensor , safe : bool = True ) -> np .ndarray :
1188
- if safe :
1189
- self .assert_is_in (val )
1190
- vals = self ._split (val )
1191
- out = torch .stack ([val .argmax (- 1 ) for val in vals ], - 1 ).numpy ()
1192
- return out
1193
-
1194
1206
def index (self , index : INDEX_TYPING , tensor_to_index : torch .Tensor ) -> torch .Tensor :
1195
1207
if not isinstance (index , torch .Tensor ):
1196
1208
raise ValueError (
@@ -1219,8 +1231,24 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
1219
1231
vals = self ._split (val )
1220
1232
return torch .cat ([super ()._project (_val ) for _val in vals ], - 1 )
1221
1233
1222
- def to_categorical (self ) -> MultiDiscreteTensorSpec :
1234
def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
    """Converts a given multi one-hot tensor to categorical format.

    Args:
        val (torch.Tensor): One-hot tensor to convert to categorical format.
        safe (bool): boolean value indicating whether a check should be
            performed on the value against the domain of the spec.

    Returns:
        The categorical tensor: the argmax over the last dimension of each
        chunk returned by ``self._split``, stacked along a new last dimension.

    Raises:
        ValueError: if ``self._split`` cannot split ``val`` (it returns ``None``).
    """
    if safe:
        self.assert_is_in(val)
    chunks = self._split(val)
    if chunks is None:
        # ``_split`` signals failure with ``None``; report it explicitly
        # instead of failing with "'NoneType' object is not iterable".
        raise ValueError(
            f"The value of shape {tuple(val.shape)} could not be split along "
            f"its last dimension to match the spec."
        )
    # A distinct loop name avoids shadowing the ``val`` argument.
    return torch.stack([chunk.argmax(-1) for chunk in chunks], -1)
1249
+
1250
+ def to_categorical_spec (self ) -> MultiDiscreteTensorSpec :
1251
+ """Converts the spec to the equivalent categorical spec."""
1224
1252
return MultiDiscreteTensorSpec (
1225
1253
[_space .n for _space in self .space ],
1226
1254
device = self .device ,
@@ -1321,12 +1349,23 @@ def __eq__(self, other):
1321
1349
def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
    """Delegate to the parent class's ``to_numpy`` and return its result.

    Args:
        val (TensorDict): value forwarded unchanged to the parent method.
        safe (bool): flag forwarded unchanged to the parent method.
    """
    converted = super().to_numpy(val, safe)
    return converted
1323
1351
1324
- def to_onehot (self ) -> OneHotDiscreteTensorSpec :
1325
- # if len(self.shape) > 1:
1326
- # raise RuntimeError(
1327
- # f"DiscreteTensorSpec with shape that has several dimensions can't be converted to "
1328
- # f"OneHotDiscreteTensorSpec. Got shape={self.shape}."
1329
- # )
1352
def to_one_hot(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
    """Encodes a discrete tensor from the spec domain into its one-hot counterpart.

    Args:
        val (torch.Tensor): Tensor of category indices to one-hot encode.
        safe (bool): boolean value indicating whether a check should be
            performed on the value against the domain of the spec.

    Returns:
        The one-hot encoded tensor, i.e. ``val`` with an extra trailing
        dimension of size ``self.space.n``.
    """
    if safe:
        self.assert_is_in(val)
    # ``one_hot`` only accepts int64 index tensors. Widen other integer
    # dtypes so that e.g. int32 values can be encoded; floating point and
    # complex inputs are left untouched so they keep failing loudly.
    if val.dtype != torch.long and not (val.is_floating_point() or val.is_complex()):
        val = val.long()
    return torch.nn.functional.one_hot(val, self.space.n)
1366
+
1367
+ def to_one_hot_spec (self ) -> OneHotDiscreteTensorSpec :
1368
+ """Converts the spec to the equivalent one-hot spec."""
1330
1369
shape = [* self .shape , self .space .n ]
1331
1370
return OneHotDiscreteTensorSpec (
1332
1371
n = self .space .n , shape = shape , device = self .device , dtype = self .dtype
@@ -1488,17 +1527,41 @@ def is_in(self, val: torch.Tensor) -> bool:
1488
1527
)
1489
1528
if self .dtype != val .dtype or len (self .shape ) > val .ndim or val_have_wrong_dim :
1490
1529
return False
1530
+ val_device = val .device
1531
+ return (
1532
+ (
1533
+ (val >= torch .zeros (self .nvec .size (), device = val_device ))
1534
+ & (val < self .nvec .to (val_device ))
1535
+ )
1536
+ .all ()
1537
+ .item ()
1538
+ )
1491
1539
1492
- return ((val >= torch .zeros (self .nvec .size ())) & (val < self .nvec )).all ().item ()
1540
def to_one_hot(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
    """Encodes a discrete tensor from the spec domain into its one-hot counterpart.

    Args:
        val (torch.Tensor): Tensor to one-hot encode. Its last dimension is
            indexed once per entry of ``self.nvec``.
        safe (bool): boolean value indicating whether a check should be
            performed on the value against the domain of the spec.

    Returns:
        The one-hot encoded tensor: the per-entry one-hot encodings
        concatenated along the last dimension, moved to ``self.device``.
    """
    if safe:
        self.assert_is_in(val)
    # ``one_hot`` only accepts int64 index tensors. Widen other integer
    # dtypes; floating point and complex inputs are left untouched so they
    # keep failing loudly.
    if val.dtype != torch.long and not (val.is_floating_point() or val.is_complex()):
        val = val.long()
    encoded = [
        # ``num_classes`` is documented as an int, so convert each entry of ``nvec``.
        torch.nn.functional.one_hot(val[..., i], int(n))
        for i, n in enumerate(self.nvec)
    ]
    return torch.cat(encoded, -1).to(self.device)
1562
+
1563
+ def to_one_hot_spec (self ) -> MultiOneHotDiscreteTensorSpec :
1564
+ """Converts the spec to the equivalent one-hot spec."""
1502
1565
nvec = [_space .n for _space in self .space ]
1503
1566
return MultiOneHotDiscreteTensorSpec (
1504
1567
nvec ,
0 commit comments