Commit d5d787a

Author: Vincent Moens
[Doc] Fix tutorials (#2768)
(cherry picked from commit 75f113f)
1 parent 3ff13ff commit d5d787a

7 files changed: +24 −15 lines


.github/workflows/docs.yml

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: ["3.10"]
+        python_version: ["3.9"]
         cuda_arch_version: ["12.4"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@ jobs:
       bash ./miniconda.sh -b -f -p "${conda_dir}"
       eval "$(${conda_dir}/bin/conda shell.bash hook)"
       printf "* Creating a test environment\n"
-      conda create --prefix "${env_dir}" -y python=3.10
+      conda create --prefix "${env_dir}" -y python=3.9
       printf "* Activating\n"
       conda activate "${env_dir}"
@@ -107,6 +107,7 @@ jobs:
       cd ..

       # 11. Build doc
+      export MAX_IDLE_COUNT=180 # Max 180 secs before killing an unresponsive collector
       cd ./docs
       # timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
       # bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
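The new MAX_IDLE_COUNT export caps how long the doc build waits on a hung data collector. As a sketch of the same guard set from Python (assuming, per the workflow comment above, that torchrl collectors read this variable from the process environment):

import os

# Assumption: torchrl collectors honor the MAX_IDLE_COUNT environment variable,
# as the workflow comment suggests; it must be set before any collector starts.
os.environ["MAX_IDLE_COUNT"] = "180"  # give up on an unresponsive collector after ~180 s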

docs/source/conf.py

Lines changed: 2 additions & 2 deletions
@@ -94,10 +94,10 @@
     "filename_pattern": "reference/generated/tutorials/",  # files to parse
     "notebook_images": "reference/generated/tutorials/media/",  # images to parse
     "download_all_examples": True,
-    "abort_on_example_error": False,
-    "only_warn_on_example_error": True,
+    "abort_on_example_error": True,
     "show_memory": True,
     "capture_repr": ("_repr_html_", "__repr__"),  # capture representations
+    "write_computation_times": True,
 }

 napoleon_use_ivar = True
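Flipping abort_on_example_error to True (and dropping the now-contradictory only_warn_on_example_error) makes any tutorial failure abort the Sphinx build rather than be demoted to a warning. A minimal sketch of the relevant sphinx-gallery options, restricted to the keys this diff touches:

# Minimal sphinx_gallery_conf sketch; both keys are standard sphinx-gallery
# options, and the values mirror the diff above.
sphinx_gallery_conf = {
    "abort_on_example_error": True,   # fail the build on the first tutorial error
    "write_computation_times": True,  # emit per-example timing pages
}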

torchrl/envs/utils.py

Lines changed: 2 additions & 0 deletions
@@ -819,6 +819,8 @@ def check_env_specs(
         spec = Composite(shape=env.batch_size, device=env.device)
         td = last_td.select(*spec.keys(True, True), strict=True)
         if not spec.contains(td):
+            for k, v in spec.items(True):
+                assert v.contains(td[k]), f"{k} is not in {v} (val: {td[k]})"
             raise AssertionError(
                 f"spec check failed at root for spec {name}={spec} and data {td}."
             )
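The added loop walks the composite spec key by key before the generic AssertionError fires, so a failing check now names the first offending key and value instead of only dumping the whole spec. A minimal usage sketch (the environment is illustrative):

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs

env = GymEnv("Pendulum-v1")  # any TorchRL env works here
check_env_specs(env)  # on mismatch, reports the first key whose value leaves its spec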

torchrl/trainers/trainers.py

Lines changed: 3 additions & 7 deletions
@@ -640,7 +640,7 @@ class ReplayBufferTrainer(TrainerHookBase):
         memmap (bool, optional): if ``True``, a memmap tensordict is created.
             Default is ``False``.
         device (device, optional): device where the samples must be placed.
-            Default is ``cpu``.
+            Default to ``None``.
         flatten_tensordicts (bool, optional): if ``True``, the tensordicts will be
             flattened (or equivalently masked with the valid mask obtained from
             the collector) before being passed to the replay buffer. Otherwise,
@@ -666,7 +666,7 @@ def __init__(
         replay_buffer: TensorDictReplayBuffer,
         batch_size: Optional[int] = None,
         memmap: bool = False,
-        device: DEVICE_TYPING = "cpu",
+        device: DEVICE_TYPING | None = None,
         flatten_tensordicts: bool = False,
         max_dims: Optional[Sequence[int]] = None,
     ) -> None:
@@ -695,15 +695,11 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase:
             pads += [0, pad_value]
             batch = pad(batch, pads)
         batch = batch.cpu()
-        if self.memmap:
-            # We can already place the tensords on the device if they're memmap,
-            # as this is a lazy op
-            batch = batch.memmap_().to(self.device)
         self.replay_buffer.extend(batch)

     def sample(self, batch: TensorDictBase) -> TensorDictBase:
         sample = self.replay_buffer.sample(batch_size=self.batch_size)
-        return sample.to(self.device, non_blocking=True)
+        return sample.to(self.device) if self.device is not None else sample

     def update_priority(self, batch: TensorDictBase) -> None:
         self.replay_buffer.update_tensordict_priority(batch)
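ReplayBufferTrainer now defaults device to None, so samples stay wherever the buffer (or its transforms) put them, and the eager memmap_().to(device) move in extend() is gone. A hedged construction sketch (storage capacity and batch size are illustrative):

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.trainers import ReplayBufferTrainer

rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100_000), batch_size=256)
# device=None (the new default): sample() returns batches as stored, with no
# extra .to() copy; pass an explicit device to recover the old behavior.
buffer_hook = ReplayBufferTrainer(rb, flatten_tensordicts=True)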

tutorials/sphinx-tutorials/coding_ddpg.py

Lines changed: 6 additions & 0 deletions
@@ -1185,6 +1185,12 @@ def ceil_div(x, y):
 collector.shutdown()
 del collector

+try:
+    parallel_env.close()
+    del parallel_env
+except Exception:
+    pass
+
 ###############################################################################
 # Experiment results
 # ------------------
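Closing the ParallelEnv explicitly keeps worker processes from lingering after the tutorial run; the try/except makes the teardown harmless if the env was already closed. The same pattern in isolation (the two-worker Pendulum maker is illustrative):

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

parallel_env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))  # illustrative
try:
    parallel_env.close()  # shut down the worker processes
    del parallel_env
except Exception:
    pass  # already closed or never started: nothing to do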

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 6 additions & 2 deletions
@@ -380,11 +380,12 @@ def make_model(dummy_env):
 # time must always have the same shape.


-def get_replay_buffer(buffer_size, n_optim, batch_size):
+def get_replay_buffer(buffer_size, n_optim, batch_size, device):
     replay_buffer = TensorDictReplayBuffer(
         batch_size=batch_size,
         storage=LazyMemmapStorage(buffer_size),
         prefetch=n_optim,
+        transform=lambda td: td.to(device),
     )
     return replay_buffer

@@ -660,7 +661,7 @@ def get_loss_module(actor, gamma):
 # requires 3 hooks (``extend``, ``sample`` and ``update_priority``) which
 # can be cumbersome to implement.
 buffer_hook = ReplayBufferTrainer(
-    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
+    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size, device=device),
     flatten_tensordicts=True,
 )
 buffer_hook.register(trainer)
@@ -750,6 +751,9 @@ def print_csv_files_in_folder(folder_path):

 print_csv_files_in_folder(logger.experiment.log_dir)

+trainer.shutdown()
+del trainer
+
 ###############################################################################
 # Conclusion and possible improvements
 # ------------------------------------
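get_replay_buffer now receives a device and installs a callable transform that moves every sampled batch onto it, so the training loop no longer needs a manual .to(device). The pattern as a self-contained sketch (capacity and batch size are illustrative):

import torch
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000),  # data itself stays memmapped on disk
    batch_size=128,
    transform=lambda td: td.to(device),  # applied to each sampled batch
)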

tutorials/sphinx-tutorials/pretrained_models.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@
 import torch.cuda
 from tensordict.nn import TensorDictSequential
 from torch import nn
-from torchrl.envs import R3MTransform, TransformedEnv
+from torchrl.envs import Compose, R3MTransform, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import Actor

@@ -115,7 +115,7 @@
 from torchrl.data import LazyMemmapStorage, ReplayBuffer

 storage = LazyMemmapStorage(1000)
-rb = ReplayBuffer(storage=storage, transform=r3m)
+rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))

 ##############################################################################
 # We can now collect the data (random rollouts for our purpose) and fill the replay
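Because the stored frames live in CPU memmap storage while the R3M weights live on device, the buffer's transform must move each sampled batch before embedding it; Compose chains the device move with r3m. A hedged sketch of the composed pipeline (model name and in_keys follow the tutorial; download=False here only to keep the sketch light):

import torch
from torchrl.data import LazyMemmapStorage, ReplayBuffer
from torchrl.envs import Compose, R3MTransform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
r3m = R3MTransform("resnet50", in_keys=["pixels"], download=False)
storage = LazyMemmapStorage(1000)
rb = ReplayBuffer(
    storage=storage,
    transform=Compose(lambda td: td.to(device), r3m),  # move to device, then embed
)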
