
Commit de61e4d

[BugFix] skip_done_states in SAC

Author: Vincent Moens
ghstack-source-id: 39d9736
Pull Request resolved: #2613
Parent: 90c8e40

2 files changed (+99, -54 lines)

test/test_cost.py (2 additions, 0 deletions)

@@ -4493,6 +4493,7 @@ def test_sac_terminating(
             actor_network=actor,
             qvalue_network=qvalue,
             value_network=value,
+            skip_done_states=True,
         )
         loss.set_keys(
             action=action_key,
@@ -5204,6 +5205,7 @@ def test_discrete_sac_terminating(
             qvalue_network=qvalue,
             num_actions=actor.spec[action_key].space.n,
             action_space="one-hot",
+            skip_done_states=True,
         )
         loss.set_keys(
             action=action_key,
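For context, a minimal sketch of how the new flag is passed at construction time. The actor, qvalue and value modules below are hypothetical placeholders standing in for the test fixtures, not the objects built in the test suite:

# Minimal usage sketch (hypothetical modules, not the test fixtures themselves).
from torchrl.objectives import SACLoss

loss = SACLoss(
    actor_network=actor,      # probabilistic actor writing the "action" key
    qvalue_network=qvalue,    # Q-value TensorDictModule
    value_network=value,      # optional state-value network
    skip_done_states=True,    # flag added in this commit; defaults to False
)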

torchrl/objectives/sac.py (97 additions, 54 deletions)

@@ -126,6 +126,10 @@ class SACLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
+            valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
+            shape of the data and that masking the data results in a valid data structure. Among other things, this may
+            not be true in MARL settings or when using RNNs. Defaults to ``False``.

     Examples:
         >>> import torch
@@ -320,6 +324,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        skip_done_states: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -418,6 +423,7 @@ def __init__(
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._make_vmap()
         self.reduction = reduction
+        self.skip_done_states = skip_done_states

     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
@@ -712,36 +718,44 @@ def _compute_target_v2(self, tensordict) -> Tensor:
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
             next_tensordict = tensordict.get("next").copy()
-            # Check done state and avoid passing these to the actor
-            done = next_tensordict.get(self.tensor_keys.done)
-            if done is not None and done.any():
-                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
-            else:
-                next_tensordict_select = next_tensordict
-            next_dist = self.actor_network.get_dist(next_tensordict_select)
-            next_action = next_dist.rsample()
-            next_sample_log_prob = compute_log_prob(
-                next_dist, next_action, self.tensor_keys.log_prob
-            )
-            if next_tensordict_select is not next_tensordict:
-                mask = ~done.squeeze(-1)
-                if mask.ndim < next_action.ndim:
-                    mask = expand_right(
-                        mask, (*mask.shape, *next_action.shape[mask.ndim :])
-                    )
-                next_action = next_action.new_zeros(mask.shape).masked_scatter_(
-                    mask, next_action
-                )
-                mask = ~done.squeeze(-1)
-                if mask.ndim < next_sample_log_prob.ndim:
-                    mask = expand_right(
-                        mask,
-                        (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
-                    )
-                next_sample_log_prob = next_sample_log_prob.new_zeros(
-                    mask.shape
-                ).masked_scatter_(mask, next_sample_log_prob)
-                next_tensordict.set(self.tensor_keys.action, next_action)
+            if self.skip_done_states:
+                # Check done state and avoid passing these to the actor
+                done = next_tensordict.get(self.tensor_keys.done)
+                if done is not None and done.any():
+                    next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+                else:
+                    next_tensordict_select = next_tensordict
+                next_dist = self.actor_network.get_dist(next_tensordict_select)
+                next_action = next_dist.rsample()
+                next_sample_log_prob = compute_log_prob(
+                    next_dist, next_action, self.tensor_keys.log_prob
+                )
+                if next_tensordict_select is not next_tensordict:
+                    mask = ~done.squeeze(-1)
+                    if mask.ndim < next_action.ndim:
+                        mask = expand_right(
+                            mask, (*mask.shape, *next_action.shape[mask.ndim :])
+                        )
+                    next_action = next_action.new_zeros(mask.shape).masked_scatter_(
+                        mask, next_action
+                    )
+                    mask = ~done.squeeze(-1)
+                    if mask.ndim < next_sample_log_prob.ndim:
+                        mask = expand_right(
+                            mask,
+                            (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                        )
+                    next_sample_log_prob = next_sample_log_prob.new_zeros(
+                        mask.shape
+                    ).masked_scatter_(mask, next_sample_log_prob)
+                    next_tensordict.set(self.tensor_keys.action, next_action)
+            else:
+                next_dist = self.actor_network.get_dist(next_tensordict)
+                next_action = next_dist.rsample()
+                next_tensordict.set(self.tensor_keys.action, next_action)
+                next_sample_log_prob = compute_log_prob(
+                    next_dist, next_action, self.tensor_keys.log_prob
+                )

             # get q-values
             next_tensordict_expand = self._vmap_qnetworkN0(
@@ -877,6 +891,10 @@ class DiscreteSACLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of
             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
+            valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
+            shape of the data and that masking the data results in a valid data structure. Among other things, this may
+            not be true in MARL settings or when using RNNs. Defaults to ``False``.

     Examples:
         >>> import torch
@@ -1051,6 +1069,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        skip_done_states: bool = False,
     ):
         if reduction is None:
             reduction = "mean"
@@ -1133,6 +1152,7 @@ def __init__(
         )
         self._make_vmap()
         self.reduction = reduction
+        self.skip_done_states = skip_done_states

     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
@@ -1218,35 +1238,58 @@ def _compute_target(self, tensordict) -> Tensor:
         with torch.no_grad():
             next_tensordict = tensordict.get("next").clone(False)

-            done = next_tensordict.get(self.tensor_keys.done)
-            if done is not None and done.any():
-                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
-            else:
-                next_tensordict_select = next_tensordict
-
-            # get probs and log probs for actions computed from "next"
-            with self.actor_network_params.to_module(self.actor_network):
-                next_dist = self.actor_network.get_dist(next_tensordict_select)
-                next_log_prob = next_dist.logits
-                next_prob = next_log_prob.exp()
-
-            # get q-values for all actions
-            next_tensordict_expand = self._vmap_qnetworkN0(
-                next_tensordict_select, self.target_qvalue_network_params
-            )
-            next_action_value = next_tensordict_expand.get(
-                self.tensor_keys.action_value
-            )
-
-            # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
-            next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
-            # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
-            next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
-            if next_tensordict_select is not next_tensordict:
-                mask = ~done
-                next_state_value = next_state_value.new_zeros(
-                    mask.shape
-                ).masked_scatter_(mask, next_state_value)
+            if self.skip_done_states:
+                done = next_tensordict.get(self.tensor_keys.done)
+                if done is not None and done.any():
+                    next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+                else:
+                    next_tensordict_select = next_tensordict
+
+                # get probs and log probs for actions computed from "next"
+                with self.actor_network_params.to_module(self.actor_network):
+                    next_dist = self.actor_network.get_dist(next_tensordict_select)
+                    next_log_prob = next_dist.logits
+                    next_prob = next_log_prob.exp()
+
+                # get q-values for all actions
+                next_tensordict_expand = self._vmap_qnetworkN0(
+                    next_tensordict_select, self.target_qvalue_network_params
+                )
+                next_action_value = next_tensordict_expand.get(
+                    self.tensor_keys.action_value
+                )
+
+                # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
+                next_state_value = (
+                    next_action_value.min(0)[0] - self._alpha * next_log_prob
+                )
+                # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
+                next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
+                if next_tensordict_select is not next_tensordict:
+                    mask = ~done
+                    next_state_value = next_state_value.new_zeros(
+                        mask.shape
+                    ).masked_scatter_(mask, next_state_value)
+            else:
+                # get probs and log probs for actions computed from "next"
+                with self.actor_network_params.to_module(self.actor_network):
+                    next_dist = self.actor_network.get_dist(next_tensordict)
+                    next_prob = next_dist.probs
+                    next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
+
+                # get q-values for all actions
+                next_tensordict_expand = self._vmap_qnetworkN0(
+                    next_tensordict, self.target_qvalue_network_params
+                )
+                next_action_value = next_tensordict_expand.get(
+                    self.tensor_keys.action_value
+                )
+                # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
+                next_state_value = (
+                    next_action_value.min(0)[0] - self._alpha * next_log_prob
+                )
+                # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
+                next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)

             tensordict.set(
                 ("next", self.value_estimator.tensor_keys.value), next_state_value
