
Commit 995c09f

svekars and vmoens authored
Fix DQN w RNN tutorial (#3462)
* Fix DQN w RNN tutorial
* bump torchrl and tensordict req (#3474)

---------

Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 967a266 commit 995c09f

File tree

2 files changed: +12 -6 lines

  .ci/docker/requirements.txt
  intermediate_source/dqn_with_rnn_tutorial.py

.ci/docker/requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -38,8 +38,8 @@ tensorboard
 jinja2==3.1.3
 pytorch-lightning
 torchx
-torchrl==0.7.2
-tensordict==0.7.2
+torchrl==0.9.2
+tensordict==0.9.1
 # For ax_multiobjective_nas_tutorial.py
 ax-platform>=0.4.0,<0.5.0
 nbformat>=5.9.2
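
The version bump is what makes the new set_recurrent_mode import used below available. A minimal sanity check, assuming the updated requirements have been installed (the expected versions are the ones pinned in this diff):

    # Verify the environment picked up the bumped pins.
    import tensordict
    import torchrl

    assert torchrl.__version__.startswith("0.9"), torchrl.__version__
    assert tensordict.__version__.startswith("0.9"), tensordict.__version__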

intermediate_source/dqn_with_rnn_tutorial.py

Lines changed: 10 additions & 4 deletions
@@ -342,7 +342,9 @@
 # will return a new instance of the LSTM (with shared weights) that will
 # assume that the input data is sequential in nature.
 #
-policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
+from torchrl.modules import set_recurrent_mode
+
+policy = Seq(feature, lstm, mlp, qval)
 
 ######################################################################
 # Because we still have a couple of uninitialized parameters we should
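
This hunk tracks an API change: in the older pinned torchrl, lstm.set_recurrent_mode(True) returned a copy of the module (with shared weights) locked to sequence mode, while in the bumped release set_recurrent_mode is imported from torchrl.modules and used as a context manager at call time (see the training-loop hunk below). A small sketch of the new pattern, with hypothetical stand-ins for the tutorial's feature / mlp / qval modules:

    import torch
    from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
    from torchrl.modules import LSTMModule, set_recurrent_mode

    # Hypothetical stand-ins for the tutorial's feature / mlp / qval modules.
    feature = Mod(torch.nn.Linear(4, 32), in_keys=["observation"], out_keys=["embed"])
    lstm = LSTMModule(input_size=32, hidden_size=32, in_key="embed", out_key="embed")
    head = Mod(torch.nn.Linear(32, 2), in_keys=["embed"], out_keys=["action_value"])

    # New style: the policy is built from the plain LSTM module...
    policy = Seq(feature, lstm, head)

    # ...and sequence (recurrent) processing is opted into only where needed:
    with set_recurrent_mode(True):
        pass  # e.g. loss_vals = loss_fn(sampled_trajectories), as below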
@@ -389,7 +391,10 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
-collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
+
+collector = SyncDataCollector(
+    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
+)
 rb = TensorDictReplayBuffer(
     storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
 )
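
The context comment keeps total_frames small so the tutorial runs quickly; per its own note, a real run would collect on the order of 1M frames. An illustrative variant, assuming the env, stoch_policy, and device objects from the surrounding tutorial:

    # Illustrative only: production-scale collection per the tutorial's note.
    collector = SyncDataCollector(
        env, stoch_policy, frames_per_batch=50, total_frames=1_000_000, device=device
    )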
@@ -422,7 +427,8 @@
     rb.extend(data.unsqueeze(0).to_tensordict().cpu())
     for _ in range(utd):
         s = rb.sample().to(device, non_blocking=True)
-        loss_vals = loss_fn(s)
+        with set_recurrent_mode(True):
+            loss_vals = loss_fn(s)
         loss_vals["loss"].backward()
         optim.step()
         optim.zero_grad()
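
During collection the policy acts one step at a time, so the LSTM stays in its default step-wise mode; only the loss forward pass, which consumes whole sampled sub-trajectories, runs under the context manager. A compact restatement of the updated inner step, assuming the tutorial's rb, loss_fn, optim, and device objects:

    def update_step():
        # Sample a batch of stored sub-trajectories from the replay buffer.
        s = rb.sample().to(device, non_blocking=True)
        # Unroll the LSTM over the time dimension only for the loss computation.
        with set_recurrent_mode(True):
            loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()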
@@ -464,5 +470,5 @@
 #
 # Further Reading
 # ---------------
-# 
+#
 # - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.

0 commit comments