|
342 | 342 | # will return a new instance of the LSTM (with shared weights) that will
|
343 | 343 | # assume that the input data is sequential in nature.
|
344 | 344 | #
|
345 | | -policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval) |
| 345 | +from torchrl.modules import set_recurrent_mode |
| 346 | + |
| 347 | +policy = Seq(feature, lstm, mlp, qval) |
346 | 348 |
|
347 | 349 | ######################################################################
|
348 | 350 | # Because we still have a couple of uninitialized parameters we should
|
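For readers skimming the diff, here is a minimal sketch of how the two APIs differ (the names feature, lstm, mlp, qval and the Seq alias come from the surrounding tutorial; everything else is illustrative, not the tutorial's exact code):

from torchrl.modules import set_recurrent_mode

# Old API (removed above): wire a recurrent copy of the LSTM directly into the
# policy; the copy shares weights but always treats its input as a sequence.
# policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)

# New API: keep a single policy and toggle the behaviour from the outside.
policy = Seq(feature, lstm, mlp, qval)

with set_recurrent_mode(True):
    # inside this block the LSTM processes the whole time dimension of its
    # input (e.g. when computing the loss over sampled trajectories)
    ...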
|
389 | 391 | # For the sake of efficiency, we're only running a few thousand iterations
|
390 | 392 | # here. In a real setting, the total number of frames should be set to 1M.
|
391 | 393 | #
|
392 | | -collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device) |
| 394 | + |
| 395 | +collector = SyncDataCollector( |
| 396 | + env, stoch_policy, frames_per_batch=50, total_frames=200, device=device |
| 397 | +) |
393 | 398 | rb = TensorDictReplayBuffer(
|
394 | 399 | storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
|
395 | 400 | )
|
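As the comment above notes, a real run would use roughly 1M frames rather than the 200 collected in this demo. A hedged sketch of that full-scale configuration, keeping every other argument from the snippet unchanged:

# Full-scale variant suggested by the comment above: only total_frames changes.
collector = SyncDataCollector(
    env, stoch_policy, frames_per_batch=50, total_frames=1_000_000, device=device
)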
|
422 | 427 | rb.extend(data.unsqueeze(0).to_tensordict().cpu())
|
423 | 428 | for _ in range(utd):
|
424 | 429 | s = rb.sample().to(device, non_blocking=True)
|
425 | | - loss_vals = loss_fn(s) |
| 430 | + with set_recurrent_mode(True): |
| 431 | + loss_vals = loss_fn(s) |
426 | 432 | loss_vals["loss"].backward()
|
427 | 433 | optim.step()
|
428 | 434 | optim.zero_grad()
|
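A short sketch of what the context manager does in the inner loop above (the names s and loss_fn are those used in the hunk; the behavioural notes follow the tutorial's description of recurrent mode):

# Outside any set_recurrent_mode block the policy runs one step at a time,
# which is what the collector needs while gathering data.
# The sampled batch s holds whole sub-trajectories, so the loss is evaluated
# with recurrent mode switched on for the duration of the block only.
with set_recurrent_mode(True):
    loss_vals = loss_fn(s)
# The previous mode is restored automatically when the block exits.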
|
464 | 470 | #
|
465 | 471 | # Further Reading
|
466 | 472 | # ---------------
|
467 | | -# |
| 473 | +# |
468 | 474 | # - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.
|