
Commit 75f113f

Author: Vincent Moens
[Doc] Fix tutorials (#2768)
1 parent 85d1e70 commit 75f113f

5 files changed: +11, -15 lines


.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: ["3.10"]
+        python_version: ["3.9"]
         cuda_arch_version: ["12.4"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@ jobs:
       bash ./miniconda.sh -b -f -p "${conda_dir}"
       eval "$(${conda_dir}/bin/conda shell.bash hook)"
       printf "* Creating a test environment\n"
-      conda create --prefix "${env_dir}" -y python=3.10
+      conda create --prefix "${env_dir}" -y python=3.9
       printf "* Activating\n"
       conda activate "${env_dir}"

docs/source/conf.py

Lines changed: 1 addition & 2 deletions
@@ -94,8 +94,7 @@
     "filename_pattern": "reference/generated/tutorials/",  # files to parse
     "notebook_images": "reference/generated/tutorials/media/",  # images to parse
     "download_all_examples": True,
-    "abort_on_example_error": False,
-    "only_warn_on_example_error": True,
+    "abort_on_example_error": True,
     "show_memory": True,
     "capture_repr": ("_repr_html_", "__repr__"),  # capture representations
 }
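With ``abort_on_example_error`` now set to ``True`` (and the ``only_warn_on_example_error`` escape hatch removed), sphinx-gallery stops the docs build on the first failing tutorial instead of merely warning. A minimal sketch of the resulting configuration, using only the keys visible in the diff above:

# docs/source/conf.py (sketch)
sphinx_gallery_conf = {
    "filename_pattern": "reference/generated/tutorials/",  # files to parse
    "download_all_examples": True,
    "abort_on_example_error": True,  # a tutorial error now fails the build
}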

torchrl/trainers/trainers.py

Lines changed: 3 additions & 7 deletions
@@ -640,7 +640,7 @@ class ReplayBufferTrainer(TrainerHookBase):
         memmap (bool, optional): if ``True``, a memmap tensordict is created.
             Default is ``False``.
         device (device, optional): device where the samples must be placed.
-            Default is ``cpu``.
+            Default to ``None``.
         flatten_tensordicts (bool, optional): if ``True``, the tensordicts will be
             flattened (or equivalently masked with the valid mask obtained from
             the collector) before being passed to the replay buffer. Otherwise,
@@ -666,7 +666,7 @@ def __init__(
         replay_buffer: TensorDictReplayBuffer,
         batch_size: Optional[int] = None,
         memmap: bool = False,
-        device: DEVICE_TYPING = "cpu",
+        device: DEVICE_TYPING | None = None,
         flatten_tensordicts: bool = False,
         max_dims: Optional[Sequence[int]] = None,
     ) -> None:
@@ -695,15 +695,11 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase:
             pads += [0, pad_value]
             batch = pad(batch, pads)
         batch = batch.cpu()
-        if self.memmap:
-            # We can already place the tensords on the device if they're memmap,
-            # as this is a lazy op
-            batch = batch.memmap_().to(self.device)
         self.replay_buffer.extend(batch)

     def sample(self, batch: TensorDictBase) -> TensorDictBase:
         sample = self.replay_buffer.sample(batch_size=self.batch_size)
-        return sample.to(self.device, non_blocking=True)
+        return sample.to(self.device) if self.device is not None else sample

     def update_priority(self, batch: TensorDictBase) -> None:
         self.replay_buffer.update_tensordict_priority(batch)
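Net effect of these changes: with the new default ``device=None``, ``ReplayBufferTrainer.sample()`` returns batches exactly as stored, and a device transfer happens only when one is requested explicitly; ``extend()`` no longer moves memmapped batches either. A minimal sketch of both modes (buffer size, batch size, and the ``"cuda"`` device are illustrative, not part of the commit):

import torch
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.trainers import ReplayBufferTrainer

rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), batch_size=32)
# device=None (new default): sample() returns the tensordict as stored
hook = ReplayBufferTrainer(rb, flatten_tensordicts=True)
# explicit device: sample() calls .to("cuda") on the sampled batch
if torch.cuda.is_available():
    hook_cuda = ReplayBufferTrainer(rb, flatten_tensordicts=True, device="cuda")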

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 3 additions & 2 deletions
@@ -380,11 +380,12 @@ def make_model(dummy_env):
 # time must always have the same shape.


-def get_replay_buffer(buffer_size, n_optim, batch_size):
+def get_replay_buffer(buffer_size, n_optim, batch_size, device):
     replay_buffer = TensorDictReplayBuffer(
         batch_size=batch_size,
         storage=LazyMemmapStorage(buffer_size),
         prefetch=n_optim,
+        transform=lambda td: td.to(device),
     )
     return replay_buffer

@@ -660,7 +661,7 @@ def get_loss_module(actor, gamma):
 # requires 3 hooks (``extend``, ``sample`` and ``update_priority``) which
 # can be cumbersome to implement.
 buffer_hook = ReplayBufferTrainer(
-    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
+    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size, device=device),
     flatten_tensordicts=True,
 )
 buffer_hook.register(trainer)
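The tutorial now keeps storage on CPU via ``LazyMemmapStorage`` and lets the new ``transform`` move each sampled batch to the training device. A short sketch of the intended flow, assuming the tutorial's imports and the ``get_replay_buffer`` defined above (sizes are illustrative):

import torch
from tensordict import TensorDict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rb = get_replay_buffer(buffer_size=1000, n_optim=4, batch_size=32, device=device)
rb.extend(TensorDict({"obs": torch.randn(64, 4)}, batch_size=[64]))  # stored on CPU memmap
batch = rb.sample()  # the lambda transform runs at sampling time: batch lands on `device`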

tutorials/sphinx-tutorials/pretrained_models.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@
 import torch.cuda
 from tensordict.nn import TensorDictSequential
 from torch import nn
-from torchrl.envs import R3MTransform, TransformedEnv
+from torchrl.envs import Compose, R3MTransform, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import Actor

@@ -115,7 +115,7 @@
 from torchrl.data import LazyMemmapStorage, ReplayBuffer

 storage = LazyMemmapStorage(1000)
-rb = ReplayBuffer(storage=storage, transform=r3m)
+rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))

 ##############################################################################
 # We can now collect the data (random rollouts for our purpose) and fill the replay
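``ReplayBuffer`` takes a single ``transform``, so ``Compose`` chains the two sampling-time steps: first move the sampled tensordict to the device, then run the R3M embedding on it. Assuming the tutorial's ``device`` and ``r3m`` objects, sampling is unchanged:

batch = rb.sample(10)  # moved to `device` first, then embedded by `r3m`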
