Skip to content

Commit 7051238

Browse files
author
Vincent Moens
committed
[Doc] Adding recurrent policies to export tutorial
ghstack-source-id: 1f1af39 Pull Request resolved: #2559
1 parent c0187a9 commit 7051238

File tree

6 files changed

+328
-47
lines changed

6 files changed

+328
-47
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ jobs:
9393
cd ./docs
9494
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
9595
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
96-
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build
96+
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
9797
cd ..
9898
9999
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"

docs/source/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
version = f"main ({torchrl.__version__})"
5050
release = "main"
5151

52+
os.environ["TORCHRL_CONSOLE_STREAM"] = "stdout"
53+
5254
# The language for content autogenerated by Sphinx. Refer to documentation
5355
# for a list of supported languages.
5456
#
@@ -95,6 +97,7 @@
9597
"abort_on_example_error": False,
9698
"only_warn_on_example_error": True,
9799
"show_memory": True,
100+
"capture_repr": ("_repr_html_", "__repr__"), # capture representations
98101
}
99102

100103
napoleon_use_ivar = True

test/test_tensordictmodules.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import pytest
1010
import torch
11+
12+
import torchrl.modules
1113
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
1214
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
1315
from torch import nn
@@ -743,6 +745,41 @@ def test_set_temporal_mode(self):
743745
lstm_module.parameters()
744746
)
745747

748+
def test_python_cudnn(self):
749+
lstm_module = LSTMModule(
750+
input_size=3,
751+
hidden_size=12,
752+
batch_first=True,
753+
dropout=0,
754+
num_layers=2,
755+
in_keys=["observation", "hidden0", "hidden1"],
756+
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
757+
).set_recurrent_mode(True)
758+
obs = torch.rand(10, 20, 3)
759+
760+
hidden0 = torch.rand(10, 20, 2, 12)
761+
hidden1 = torch.rand(10, 20, 2, 12)
762+
763+
is_init = torch.zeros(10, 20, dtype=torch.bool)
764+
assert isinstance(lstm_module.lstm, nn.LSTM)
765+
outs_ref = lstm_module(
766+
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
767+
)
768+
769+
lstm_module.make_python_based()
770+
assert isinstance(lstm_module.lstm, torchrl.modules.LSTM)
771+
outs_rl = lstm_module(
772+
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
773+
)
774+
torch.testing.assert_close(outs_ref, outs_rl)
775+
776+
lstm_module.make_cudnn_based()
777+
assert isinstance(lstm_module.lstm, nn.LSTM)
778+
outs_cudnn = lstm_module(
779+
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
780+
)
781+
torch.testing.assert_close(outs_ref, outs_cudnn)
782+
746783
def test_noncontiguous(self):
747784
lstm_module = LSTMModule(
748785
input_size=3,
@@ -1088,6 +1125,34 @@ def test_set_temporal_mode(self):
10881125
gru_module.parameters()
10891126
)
10901127

1128+
def test_python_cudnn(self):
1129+
gru_module = GRUModule(
1130+
input_size=3,
1131+
hidden_size=12,
1132+
batch_first=True,
1133+
dropout=0,
1134+
num_layers=2,
1135+
in_keys=["observation", "hidden0"],
1136+
out_keys=["intermediate", ("next", "hidden0")],
1137+
).set_recurrent_mode(True)
1138+
obs = torch.rand(10, 20, 3)
1139+
1140+
hidden0 = torch.rand(10, 20, 2, 12)
1141+
1142+
is_init = torch.zeros(10, 20, dtype=torch.bool)
1143+
assert isinstance(gru_module.gru, nn.GRU)
1144+
outs_ref = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
1145+
1146+
gru_module.make_python_based()
1147+
assert isinstance(gru_module.gru, torchrl.modules.GRU)
1148+
outs_rl = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
1149+
torch.testing.assert_close(outs_ref, outs_rl)
1150+
1151+
gru_module.make_cudnn_based()
1152+
assert isinstance(gru_module.gru, nn.GRU)
1153+
outs_cudnn = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
1154+
torch.testing.assert_close(outs_ref, outs_cudnn)
1155+
10911156
def test_noncontiguous(self):
10921157
gru_module = GRUModule(
10931158
input_size=3,

torchrl/_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,17 @@
4040
# Remove all attached handlers
4141
while logger.hasHandlers():
4242
logger.removeHandler(logger.handlers[0])
43-
console_handler = logging.StreamHandler()
43+
stream_handlers = {
44+
"stdout": sys.stdout,
45+
"stderr": sys.stderr,
46+
}
47+
TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
48+
if TORCHRL_CONSOLE_STREAM:
49+
stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM]
50+
else:
51+
stream_handler = None
52+
console_handler = logging.StreamHandler(stream=stream_handler)
53+
4454
console_handler.setLevel(logging.INFO)
4555
formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
4656
console_handler.setFormatter(formatter)
@@ -86,17 +96,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
8696
val[2] = N
8797

8898
@staticmethod
89-
def print(prefix=None): # noqa: T202
99+
def print(prefix=None) -> str: # noqa: T202
100+
"""Prints the state of the timer.
101+
102+
Returns:
103+
the string printed using the logger.
104+
"""
90105
keys = list(timeit._REG)
91106
keys.sort()
107+
string = []
92108
for name in keys:
93109
strings = []
94110
if prefix:
95111
strings.append(prefix)
96112
strings.append(
97113
f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
98114
)
99-
logger.info(" -- ".join(strings))
115+
string.append(" -- ".join(strings))
116+
logger.info(string[-1])
117+
return "\n".join(string)
100118

101119
@classmethod
102120
def todict(cls, percall=True):

torchrl/modules/tensordict_module/rnn.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from tensordict.base import NO_DEFAULT
1414

15-
from tensordict.nn import TensorDictModuleBase as ModuleBase
15+
from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase
1616
from tensordict.utils import expand_as_right, prod, set_lazy_legacy
1717

1818
from torch import nn, Tensor
@@ -467,6 +467,8 @@ def __init__(
467467
raise ValueError("The input lstm must have batch_first=True.")
468468
if bidirectional:
469469
raise ValueError("The input lstm cannot be bidirectional.")
470+
if not hidden_size:
471+
raise ValueError("hidden_size must be passed.")
470472
if python_based:
471473
lstm = LSTM(
472474
input_size=input_size,
@@ -524,6 +526,58 @@ def __init__(
524526
self.out_keys = out_keys
525527
self._recurrent_mode = False
526528

529+
def make_python_based(self) -> LSTMModule:
530+
"""Transforms the LSTM layer in its python-based version.
531+
532+
Returns:
533+
self
534+
535+
"""
536+
if isinstance(self.lstm, LSTM):
537+
return self
538+
lstm = LSTM(
539+
input_size=self.lstm.input_size,
540+
hidden_size=self.lstm.hidden_size,
541+
num_layers=self.lstm.num_layers,
542+
bias=self.lstm.bias,
543+
dropout=self.lstm.dropout,
544+
proj_size=self.lstm.proj_size,
545+
device="meta",
546+
batch_first=self.lstm.batch_first,
547+
bidirectional=self.lstm.bidirectional,
548+
)
549+
from tensordict import from_module
550+
551+
from_module(self.lstm).to_module(lstm)
552+
self.lstm = lstm
553+
return self
554+
555+
def make_cudnn_based(self) -> LSTMModule:
556+
"""Transforms the LSTM layer in its CuDNN-based version.
557+
558+
Returns:
559+
self
560+
561+
"""
562+
if isinstance(self.lstm, nn.LSTM):
563+
return self
564+
lstm = nn.LSTM(
565+
input_size=self.lstm.input_size,
566+
hidden_size=self.lstm.hidden_size,
567+
num_layers=self.lstm.num_layers,
568+
bias=self.lstm.bias,
569+
dropout=self.lstm.dropout,
570+
proj_size=self.lstm.proj_size,
571+
device="meta",
572+
batch_first=self.lstm.batch_first,
573+
bidirectional=self.lstm.bidirectional,
574+
)
575+
from tensordict import from_module
576+
577+
from_module(self.lstm).to_module(lstm)
578+
self.lstm = lstm
579+
return self
580+
527581
def make_tensordict_primer(self):
528582
"""Makes a tensordict primer for the environment.
529583
@@ -644,6 +698,7 @@ def set_recurrent_mode(self, mode: bool = True):
644698
out._recurrent_mode = mode
645699
return out
646700

701+
@dispatch
647702
def forward(self, tensordict: TensorDictBase):
648703
# we want to get an error if the value input is missing, but not the hidden states
649704
defaults = [NO_DEFAULT, None, None]
@@ -1273,6 +1328,56 @@ def __init__(
12731328
self.out_keys = out_keys
12741329
self._recurrent_mode = False
12751330

1331+
def make_python_based(self) -> GRUModule:
1332+
"""Transforms the GRU layer in its python-based version.
1333+
1334+
Returns:
1335+
self
1336+
1337+
"""
1338+
if isinstance(self.gru, GRU):
1339+
return self
1340+
gru = GRU(
1341+
input_size=self.gru.input_size,
1342+
hidden_size=self.gru.hidden_size,
1343+
num_layers=self.gru.num_layers,
1344+
bias=self.gru.bias,
1345+
dropout=self.gru.dropout,
1346+
device="meta",
1347+
batch_first=self.gru.batch_first,
1348+
bidirectional=self.gru.bidirectional,
1349+
)
1350+
from tensordict import from_module
1351+
1352+
from_module(self.gru).to_module(gru)
1353+
self.gru = gru
1354+
return self
1355+
1356+
def make_cudnn_based(self) -> GRUModule:
1357+
"""Transforms the GRU layer in its CuDNN-based version.
1358+
1359+
Returns:
1360+
self
1361+
1362+
"""
1363+
if isinstance(self.gru, nn.GRU):
1364+
return self
1365+
gru = nn.GRU(
1366+
input_size=self.gru.input_size,
1367+
hidden_size=self.gru.hidden_size,
1368+
num_layers=self.gru.num_layers,
1369+
bias=self.gru.bias,
1370+
dropout=self.gru.dropout,
1371+
device="meta",
1372+
batch_first=self.gru.batch_first,
1373+
bidirectional=self.gru.bidirectional,
1374+
)
1375+
from tensordict import from_module
1376+
1377+
from_module(self.gru).to_module(gru)
1378+
self.gru = gru
1379+
return self
1380+
12761381
def make_tensordict_primer(self):
12771382
"""Makes a tensordict primer for the environment.
12781383
@@ -1389,6 +1494,7 @@ def set_recurrent_mode(self, mode: bool = True):
13891494
out._recurrent_mode = mode
13901495
return out
13911496

1497+
@dispatch
13921498
@set_lazy_legacy(False)
13931499
def forward(self, tensordict: TensorDictBase):
13941500
# we want to get an error if the value input is missing, but not the hidden states

0 commit comments

Comments
 (0)