We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f4709c1 commit 21eeca4Copy full SHA for 21eeca4
torchrl/data/replay_buffers/samplers.py
@@ -1485,13 +1485,13 @@ def _get_index(
1485
truncated[seq_length.cumsum(0) - 1] = 1
1486
index = index.to(torch.long).unbind(-1)
1487
st_index = storage[index]
1488
- try:
1489
- done = st_index[done_key] | truncated
1490
- except KeyError:
+ done = st_index.get(done_key, default=None)
+ if done is None:
1491
done = truncated.clone()
1492
1493
- terminated = st_index[terminated_key]
1494
+ else:
+ done = done | truncated
+ terminated = st_index.get(terminated_key, default=None)
+ if terminated is None:
1495
terminated = torch.zeros_like(truncated)
1496
return index, {
1497
truncated_key: truncated,
0 commit comments