Skip to content

Commit f121f4d

Browse files
author
Vincent Moens
committed
[BugFix] Enable ndim done states in GAE with shifted=True
ghstack-source-id: fb9fd48; Pull-Request resolved: #2962
1 parent 8edc29c commit f121f4d

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

test/test_cost.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14247,6 +14247,46 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
1424714247

1424814248

1424914249
class TestValues:
14250+
def test_gae_multi_done(self):
14251+
14252+
# constants
14253+
batch_size = 10
14254+
seq_size = 5
14255+
n_dims = batch_size
14256+
gamma = 0.99
14257+
lmbda = 0.98
14258+
14259+
env = SerialEnv(
14260+
batch_size, [functools.partial(GymEnv, "CartPole-v1")] * batch_size
14261+
)
14262+
obs_size = env.full_observation_spec[env.observation_keys[0]].shape[-1]
14263+
14264+
td = env.rollout(seq_size, break_when_any_done=False)
14265+
# make the magic happen: swap dims and create an artificial ndim done state
14266+
done = td["next", "done"].transpose(0, -1)
14267+
terminated = td["next", "terminated"].transpose(0, -1)
14268+
reward = td["next", "reward"].transpose(0, -1)
14269+
td = td[:1]
14270+
td["next", "done"] = done
14271+
td["next", "terminated"] = terminated
14272+
td["next", "reward"] = reward
14273+
14274+
critic = TensorDictModule(
14275+
nn.Linear(obs_size, n_dims),
14276+
in_keys=[("observation",)],
14277+
out_keys=[("state_value",)],
14278+
)
14279+
14280+
gae_shifted = GAE(gamma=gamma, lmbda=lmbda, value_network=critic, shifted=True)
14281+
gae_no_shifted = GAE(
14282+
gamma=gamma, lmbda=lmbda, value_network=critic, shifted=False
14283+
)
14284+
14285+
torch.testing.assert_close(
14286+
gae_shifted(td.clone())["advantage"],
14287+
gae_no_shifted(td.clone())["advantage"],
14288+
)
14289+
1425014290
@pytest.mark.skipif(not _has_gym, reason="requires gym")
1425114291
@pytest.mark.parametrize("module", ["lstm", "gru"])
1425214292
def test_gae_recurrent(self, module):

torchrl/objectives/value/advantages.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,14 @@ def _call_value_nets(
460460
data_copy = data.copy()
461461
# we are going to modify the done so let's clone it
462462
done = data_copy["next", "done"].clone()
463-
464463
# Mark the last step of every sequence as done. We do this because flattening would cause the trajectories
465464
# of different batches to be merged.
466465
done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
466+
truncated = data_copy.get(("next", "truncated"), done)
467+
if truncated is not done:
468+
truncated[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
467469
data_copy["next", "done"] = done
470+
data_copy["next", "truncated"] = truncated
468471
# Reshape to -1 because we cannot guarantee that all dims have the same number of done states
469472
with data_copy.view(-1) as data_copy_view:
470473
# Interleave next data when done
@@ -482,7 +485,11 @@ def _call_value_nets(
482485
# done = [0, 0, 1, 0, 1, 0, 1]
483486
# done_cs = [0, 0, 0, 1, 1, 2, 2]
484487
# indices = [0, 1, 2, 4, 5, 7, 8]
485-
done_view = data_copy_view["next", "done"].squeeze(-1)
488+
done_view = data_copy_view["next", "done"]
489+
if done_view.shape[-1] == 1:
490+
done_view = done_view.squeeze(-1)
491+
else:
492+
done_view = done_view.any(-1)
486493
done_cs = done_view.cumsum(0)
487494
done_cs = torch.cat([done_cs.new_zeros((1,)), done_cs[:-1]], dim=0)
488495
indices = torch.arange(done_cs.shape[0], device=done_cs.device)

0 commit comments

Comments (0)