Skip to content

Commit 21eeca4

Browse files
author
Vincent Moens
committed
[BugFix] Avoid KeyError in slice sampler (for compile)
ghstack-source-id: 6e2a303 Pull Request resolved: #2670
1 parent f4709c1 commit 21eeca4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torchrl/data/replay_buffers/samplers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,13 +1485,13 @@ def _get_index(
14851485
truncated[seq_length.cumsum(0) - 1] = 1
14861486
index = index.to(torch.long).unbind(-1)
14871487
st_index = storage[index]
1488-
try:
1489-
done = st_index[done_key] | truncated
1490-
except KeyError:
1488+
done = st_index.get(done_key, default=None)
1489+
if done is None:
14911490
done = truncated.clone()
1492-
try:
1493-
terminated = st_index[terminated_key]
1494-
except KeyError:
1491+
else:
1492+
done = done | truncated
1493+
terminated = st_index.get(terminated_key, default=None)
1494+
if terminated is None:
14951495
terminated = torch.zeros_like(truncated)
14961496
return index, {
14971497
truncated_key: truncated,

0 commit comments

Comments
 (0)