@@ -1239,6 +1239,9 @@ class GAE(ValueEstimatorBase):
1239
1239
:meth:`~.value_estimate`.
1240
1240
Negative dimensions are considered with respect to the input
1241
1241
tensordict.
1242
+ auto_reset_env (bool, optional): if ``True``, the last ``"next"`` state
1243
+ of the episode isn't valid, so the GAE calculation will use the ``value``
1244
+ instead of ``next_value`` to bootstrap truncated episodes.
1242
1245
1243
1246
GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
1244
1247
return a :obj:`"value_target"` entry with the return value that is to be used
@@ -1270,6 +1273,7 @@ def __init__(
1270
1273
shifted : bool = False ,
1271
1274
device : torch .device | None = None ,
1272
1275
time_dim : int | None = None ,
1276
+ auto_reset_env : bool = False ,
1273
1277
):
1274
1278
super ().__init__ (
1275
1279
shifted = shifted ,
@@ -1296,6 +1300,7 @@ def __init__(
1296
1300
self .average_gae = average_gae
1297
1301
self .vectorized = vectorized
1298
1302
self .time_dim = time_dim
1303
+ self .auto_reset_env = auto_reset_env
1299
1304
1300
1305
@property
1301
1306
def vectorized (self ):
@@ -1430,6 +1435,12 @@ def forward(
1430
1435
done = tensordict .get (("next" , self .tensor_keys .done ))
1431
1436
terminated = tensordict .get (("next" , self .tensor_keys .terminated ), default = done )
1432
1437
time_dim = self ._get_time_dim (time_dim , tensordict )
1438
+
1439
+ if self .auto_reset_env :
1440
+ truncated = tensordict .get (("next" , "truncated" ))
1441
+ if truncated .any ():
1442
+ reward += gamma * value * truncated
1443
+
1433
1444
if self .vectorized :
1434
1445
adv , value_target = vec_generalized_advantage_estimate (
1435
1446
gamma ,
@@ -1438,7 +1449,7 @@ def forward(
1438
1449
next_value ,
1439
1450
reward ,
1440
1451
done = done ,
1441
- terminated = terminated ,
1452
+ terminated = terminated if not self . auto_reset_env else done ,
1442
1453
time_dim = time_dim ,
1443
1454
)
1444
1455
else :
@@ -1449,7 +1460,7 @@ def forward(
1449
1460
next_value ,
1450
1461
reward ,
1451
1462
done = done ,
1452
- terminated = terminated ,
1463
+ terminated = terminated if not self . auto_reset_env else done ,
1453
1464
time_dim = time_dim ,
1454
1465
)
1455
1466
0 commit comments