Skip to content

Commit 6277226

Browse files
author
Vincent Moens
authored
[Doc] Make tutos runnable without colab (#1826)
1 parent 79374d8 commit 6277226

File tree

12 files changed

+156
-11
lines changed

12 files changed

+156
-11
lines changed

torchrl/objectives/ppo.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,41 +345,46 @@ def functional(self):
345345

346346
@property
347347
def actor(self):
348-
logging.warning(
348+
warnings.warn(
349349
f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This "
350-
"link will be removed in v0.4."
350+
"link will be removed in v0.4.",
351+
category=DeprecationWarning,
351352
)
352353
return self.actor_network
353354

354355
@property
355356
def critic(self):
356-
logging.warning(
357+
warnings.warn(
357358
f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This "
358-
"link will be removed in v0.4."
359+
"link will be removed in v0.4.",
360+
category=DeprecationWarning,
359361
)
360362
return self.critic_network
361363

362364
@property
363365
def actor_params(self):
364366
logging.warning(
365367
f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This "
366-
"link will be removed in v0.4."
368+
"link will be removed in v0.4.",
369+
category=DeprecationWarning,
367370
)
368371
return self.actor_network_params
369372

370373
@property
371374
def critic_params(self):
372-
logging.warning(
375+
warnings.warn(
373376
f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This "
374-
"link will be removed in v0.4."
377+
"link will be removed in v0.4.",
378+
category=DeprecationWarning,
375379
)
376380
return self.critic_network_params
377381

378382
@property
379383
def target_critic_params(self):
380-
logging.warning(
384+
warnings.warn(
381385
f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This "
382-
"link will be removed in v0.4."
386+
"link will be removed in v0.4.",
387+
category=DeprecationWarning,
383388
)
384389
return self.target_critic_network_params
385390

tutorials/sphinx-tutorials/coding_ddpg.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@
5757
from typing import Tuple
5858

5959
warnings.filterwarnings("ignore")
60+
from torch import multiprocessing
61+
62+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
63+
# `__main__` method call, but for the easy of reading the code switch to fork
64+
# which is also a default spawn method in Google's Colaboratory
65+
try:
66+
multiprocessing.set_start_method("fork")
67+
except RuntimeError:
68+
assert multiprocessing.get_start_method() == "fork"
69+
6070
# sphinx_gallery_end_ignore
6171

6272
import torch.cuda

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@
8787
import warnings
8888

8989
warnings.filterwarnings("ignore")
90+
91+
from torch import multiprocessing
92+
93+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
94+
# `__main__` method call, but for the easy of reading the code switch to fork
95+
# which is also a default spawn method in Google's Colaboratory
96+
try:
97+
multiprocessing.set_start_method("fork")
98+
except RuntimeError:
99+
assert multiprocessing.get_start_method() == "fork"
100+
101+
90102
# sphinx_gallery_end_ignore
91103

92104
import os

tutorials/sphinx-tutorials/coding_ppo.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@
104104
# description and more about the algorithm itself.
105105
#
106106

107+
# sphinx_gallery_start_ignore
108+
import warnings
109+
110+
warnings.filterwarnings("ignore")
111+
from torch import multiprocessing
112+
113+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
114+
# `__main__` method call, but for the easy of reading the code switch to fork
115+
# which is also a default spawn method in Google's Colaboratory
116+
try:
117+
multiprocessing.set_start_method("fork")
118+
except RuntimeError:
119+
assert multiprocessing.get_start_method() == "fork"
120+
121+
# sphinx_gallery_end_ignore
122+
107123
from collections import defaultdict
108124

109125
import matplotlib.pyplot as plt

tutorials/sphinx-tutorials/dqn_with_rnn.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@
6868
# -----
6969
#
7070

71+
# sphinx_gallery_start_ignore
72+
import warnings
73+
74+
warnings.filterwarnings("ignore")
75+
from torch import multiprocessing
76+
77+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
78+
# `__main__` method call, but for the easy of reading the code switch to fork
79+
# which is also a default spawn method in Google's Colaboratory
80+
try:
81+
multiprocessing.set_start_method("fork")
82+
except RuntimeError:
83+
assert multiprocessing.get_start_method() == "fork"
84+
85+
# sphinx_gallery_end_ignore
86+
7187
import torch
7288
import tqdm
7389
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

tutorials/sphinx-tutorials/multi_task.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
import warnings
1414

1515
warnings.filterwarnings("ignore")
16+
17+
from torch import multiprocessing
18+
19+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
20+
# `__main__` method call, but for the easy of reading the code switch to fork
21+
# which is also a default spawn method in Google's Colaboratory
22+
try:
23+
multiprocessing.set_start_method("fork")
24+
except RuntimeError:
25+
assert multiprocessing.get_start_method() == "fork"
26+
1627
# sphinx_gallery_end_ignore
1728

1829
import torch

tutorials/sphinx-tutorials/multiagent_ppo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,8 @@
659659
with torch.no_grad():
660660
GAE(
661661
tensordict_data,
662-
params=loss_module.critic_params,
663-
target_params=loss_module.target_critic_params,
662+
params=loss_module.critic_network_params,
663+
target_params=loss_module.target_critic_network_params,
664664
) # Compute GAE and add it to the data
665665

666666
data_view = tensordict_data.reshape(-1) # Flatten the batch size to shuffle data

tutorials/sphinx-tutorials/pendulum.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,23 @@
7373
# simulation graph.
7474
# * Finally, we will train a simple policy to solve the system we implemented.
7575
#
76+
77+
# sphinx_gallery_start_ignore
78+
import warnings
79+
80+
warnings.filterwarnings("ignore")
81+
from torch import multiprocessing
82+
83+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
84+
# `__main__` method call, but for the easy of reading the code switch to fork
85+
# which is also a default spawn method in Google's Colaboratory
86+
try:
87+
multiprocessing.set_start_method("fork")
88+
except RuntimeError:
89+
assert multiprocessing.get_start_method() == "fork"
90+
91+
# sphinx_gallery_end_ignore
92+
7693
from collections import defaultdict
7794
from typing import Optional
7895

tutorials/sphinx-tutorials/pretrained_models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@
1313
# in one or the other context. In this tutorial, we will be using R3M (https://arxiv.org/abs/2203.12601),
1414
# but other models (e.g. VIP) will work equally well.
1515
#
16+
17+
# sphinx_gallery_start_ignore
18+
import warnings
19+
20+
warnings.filterwarnings("ignore")
21+
from torch import multiprocessing
22+
23+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
24+
# `__main__` method call, but for the easy of reading the code switch to fork
25+
# which is also a default spawn method in Google's Colaboratory
26+
try:
27+
multiprocessing.set_start_method("fork")
28+
except RuntimeError:
29+
assert multiprocessing.get_start_method() == "fork"
30+
31+
# sphinx_gallery_end_ignore
32+
1633
import torch.cuda
1734
from tensordict.nn import TensorDictSequential
1835
from torch import nn

tutorials/sphinx-tutorials/rb_tutorial.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,23 @@
4646
# replay buffer is a straightforward process, as shown in the following
4747
# example:
4848
#
49+
50+
# sphinx_gallery_start_ignore
51+
import warnings
52+
53+
warnings.filterwarnings("ignore")
54+
from torch import multiprocessing
55+
56+
# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside
57+
# `__main__` method call, but for the easy of reading the code switch to fork
58+
# which is also a default spawn method in Google's Colaboratory
59+
try:
60+
multiprocessing.set_start_method("fork")
61+
except RuntimeError:
62+
assert multiprocessing.get_start_method() == "fork"
63+
64+
# sphinx_gallery_end_ignore
65+
4966
import tempfile
5067

5168
from torchrl.data import ReplayBuffer

0 commit comments

Comments
 (0)