@@ -126,6 +126,10 @@ class SACLoss(LossModule):
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
+            valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
+            shape of the data and that masking the data results in a valid data structure. Among other things, this may
+            not be true in MARL settings or when using RNNs. Defaults to ``False``.

    Examples:
        >>> import torch
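To make the broadcast assumption in the new docstring concrete, here is a small standalone sketch with plain tensors and illustrative shapes (not the loss module's internals):

import torch

batch = 8
obs = torch.randn(batch, 3)                     # next observations
done = torch.zeros(batch, 1, dtype=torch.bool)
done[2] = done[5] = True                        # two terminal next states

obs_valid = obs[~done.squeeze(-1)]              # only non-terminal states reach the actor
assert obs_valid.shape == (batch - 2, 3)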
@@ -320,6 +324,7 @@ def __init__(
        priority_key: str = None,
        separate_losses: bool = False,
        reduction: str = None,
+        skip_done_states: bool = False,
    ) -> None:
        self._in_keys = None
        self._out_keys = None
@@ -418,6 +423,7 @@ def __init__(
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
        self._make_vmap()
        self.reduction = reduction
+        self.skip_done_states = skip_done_states

    def _make_vmap(self):
        self._vmap_qnetworkN0 = _vmap_func(
@@ -712,36 +718,44 @@ def _compute_target_v2(self, tensordict) -> Tensor:
            ExplorationType.RANDOM
        ), self.actor_network_params.to_module(self.actor_network):
            next_tensordict = tensordict.get("next").copy()
-            # Check done state and avoid passing these to the actor
-            done = next_tensordict.get(self.tensor_keys.done)
-            if done is not None and done.any():
-                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
-            else:
-                next_tensordict_select = next_tensordict
-            next_dist = self.actor_network.get_dist(next_tensordict_select)
-            next_action = next_dist.rsample()
-            next_sample_log_prob = compute_log_prob(
-                next_dist, next_action, self.tensor_keys.log_prob
-            )
-            if next_tensordict_select is not next_tensordict:
-                mask = ~done.squeeze(-1)
-                if mask.ndim < next_action.ndim:
-                    mask = expand_right(
-                        mask, (*mask.shape, *next_action.shape[mask.ndim :])
-                    )
-                next_action = next_action.new_zeros(mask.shape).masked_scatter_(
-                    mask, next_action
+            if self.skip_done_states:
+                # Check done state and avoid passing these to the actor
+                done = next_tensordict.get(self.tensor_keys.done)
+                if done is not None and done.any():
+                    next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+                else:
+                    next_tensordict_select = next_tensordict
+                next_dist = self.actor_network.get_dist(next_tensordict_select)
+                next_action = next_dist.rsample()
+                next_sample_log_prob = compute_log_prob(
+                    next_dist, next_action, self.tensor_keys.log_prob
                )
-                mask = ~done.squeeze(-1)
-                if mask.ndim < next_sample_log_prob.ndim:
-                    mask = expand_right(
-                        mask,
-                        (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                if next_tensordict_select is not next_tensordict:
+                    mask = ~done.squeeze(-1)
+                    if mask.ndim < next_action.ndim:
+                        mask = expand_right(
+                            mask, (*mask.shape, *next_action.shape[mask.ndim :])
+                        )
+                    next_action = next_action.new_zeros(mask.shape).masked_scatter_(
+                        mask, next_action
                    )
-                next_sample_log_prob = next_sample_log_prob.new_zeros(
-                    mask.shape
-                ).masked_scatter_(mask, next_sample_log_prob)
-            next_tensordict.set(self.tensor_keys.action, next_action)
+                    mask = ~done.squeeze(-1)
+                    if mask.ndim < next_sample_log_prob.ndim:
+                        mask = expand_right(
+                            mask,
+                            (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                        )
+                    next_sample_log_prob = next_sample_log_prob.new_zeros(
+                        mask.shape
+                    ).masked_scatter_(mask, next_sample_log_prob)
+                next_tensordict.set(self.tensor_keys.action, next_action)
+            else:
+                next_dist = self.actor_network.get_dist(next_tensordict)
+                next_action = next_dist.rsample()
+                next_tensordict.set(self.tensor_keys.action, next_action)
+                next_sample_log_prob = compute_log_prob(
+                    next_dist, next_action, self.tensor_keys.log_prob
+                )

            # get q-values
            next_tensordict_expand = self._vmap_qnetworkN0(
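The masking branch above writes actor outputs computed on the reduced batch back into full-size tensors. A standalone sketch of that masked_scatter_ pattern with illustrative shapes (the expand_right import path is assumed to be tensordict.utils, the helper referenced in this diff):

import torch
from tensordict.utils import expand_right  # import path assumed

batch, act_dim = 6, 2
done = torch.tensor([False, True, False, False, True, False]).unsqueeze(-1)
mask = ~done.squeeze(-1)

# pretend these actions were sampled only for the non-terminal next states
next_action_valid = torch.randn(int(mask.sum()), act_dim)

mask_exp = expand_right(mask, (*mask.shape, act_dim))  # (batch, act_dim)
next_action = next_action_valid.new_zeros(mask_exp.shape).masked_scatter_(
    mask_exp, next_action_valid
)
assert next_action.shape == (batch, act_dim)
assert (next_action[done.squeeze(-1)] == 0).all()  # terminal rows stay zero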
@@ -877,6 +891,10 @@ class DiscreteSACLoss(LossModule):
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
+            valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
+            shape of the data and that masking the data results in a valid data structure. Among other things, this may
+            not be true in MARL settings or when using RNNs. Defaults to ``False``.

    Examples:
        >>> import torch
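As a concrete illustration of the MARL caveat in the docstring above, a standalone sketch with per-agent done flags (shapes are illustrative only):

import torch

batch, n_agents = 4, 2
obs = torch.randn(batch, n_agents, 3)
done = torch.zeros(batch, n_agents, 1, dtype=torch.bool)
done[0, 1] = True              # only one agent of the first sample is done

# boolean indexing with a (batch, n_agents) mask flattens the agent dimension
obs_masked = obs[~done.squeeze(-1)]
print(obs_masked.shape)        # torch.Size([7, 3]) -- the per-agent structure is lost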
@@ -1051,6 +1069,7 @@ def __init__(
        priority_key: str = None,
        separate_losses: bool = False,
        reduction: str = None,
+        skip_done_states: bool = False,
    ):
        if reduction is None:
            reduction = "mean"
@@ -1133,6 +1152,7 @@ def __init__(
        )
        self._make_vmap()
        self.reduction = reduction
+        self.skip_done_states = skip_done_states

    def _make_vmap(self):
        self._vmap_qnetworkN0 = _vmap_func(
@@ -1218,35 +1238,58 @@ def _compute_target(self, tensordict) -> Tensor:
        with torch.no_grad():
            next_tensordict = tensordict.get("next").clone(False)

-            done = next_tensordict.get(self.tensor_keys.done)
-            if done is not None and done.any():
-                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
-            else:
-                next_tensordict_select = next_tensordict
+            if self.skip_done_states:
+                done = next_tensordict.get(self.tensor_keys.done)
+                if done is not None and done.any():
+                    next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+                else:
+                    next_tensordict_select = next_tensordict

-            # get probs and log probs for actions computed from "next"
-            with self.actor_network_params.to_module(self.actor_network):
-                next_dist = self.actor_network.get_dist(next_tensordict_select)
-                next_log_prob = next_dist.logits
-                next_prob = next_log_prob.exp()
+                # get probs and log probs for actions computed from "next"
+                with self.actor_network_params.to_module(self.actor_network):
+                    next_dist = self.actor_network.get_dist(next_tensordict_select)
+                    next_log_prob = next_dist.logits
+                    next_prob = next_log_prob.exp()

-            # get q-values for all actions
-            next_tensordict_expand = self._vmap_qnetworkN0(
-                next_tensordict_select, self.target_qvalue_network_params
-            )
-            next_action_value = next_tensordict_expand.get(
-                self.tensor_keys.action_value
-            )
+                # get q-values for all actions
+                next_tensordict_expand = self._vmap_qnetworkN0(
+                    next_tensordict_select, self.target_qvalue_network_params
+                )
+                next_action_value = next_tensordict_expand.get(
+                    self.tensor_keys.action_value
+                )

-            # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
-            next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
-            # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
-            next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
-            if next_tensordict_select is not next_tensordict:
-                mask = ~done
-                next_state_value = next_state_value.new_zeros(
-                    mask.shape
-                ).masked_scatter_(mask, next_state_value)
+                # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
+                next_state_value = (
+                    next_action_value.min(0)[0] - self._alpha * next_log_prob
+                )
+                # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
+                next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
+                if next_tensordict_select is not next_tensordict:
+                    mask = ~done
+                    next_state_value = next_state_value.new_zeros(
+                        mask.shape
+                    ).masked_scatter_(mask, next_state_value)
+            else:
+                # get probs and log probs for actions computed from "next"
+                with self.actor_network_params.to_module(self.actor_network):
+                    next_dist = self.actor_network.get_dist(next_tensordict)
+                    next_prob = next_dist.probs
+                    next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
+
+                # get q-values for all actions
+                next_tensordict_expand = self._vmap_qnetworkN0(
+                    next_tensordict, self.target_qvalue_network_params
+                )
+                next_action_value = next_tensordict_expand.get(
+                    self.tensor_keys.action_value
+                )
+                # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
+                next_state_value = (
+                    next_action_value.min(0)[0] - self._alpha * next_log_prob
+                )
+                # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
+                next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)

            tensordict.set(
                ("next", self.value_estimator.tensor_keys.value), next_state_value