Skip to content

Commit c0187a9

Browse files
author
Vincent Moens
committed
[Doc] Tutorial on exporting TorchRL models
ghstack-source-id: b93146e Pull Request resolved: #2557
1 parent 165163a commit c0187a9

File tree

5 files changed

+432
-1
lines changed

5 files changed

+432
-1
lines changed

docs/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ memory_profiler
2525
pyrender
2626
pytest
2727
vmas
28+
onnxscript
29+
onnxruntime
30+
onnx

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Intermediate
104104
tutorials/pretrained_models
105105
tutorials/dqn_with_rnn
106106
tutorials/rb_tutorial
107+
tutorials/export
107108

108109
Advanced
109110
--------

torchrl/modules/models/models.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ def __init__(
201201
if not isinstance(out_features, Number):
202202
_out_features_num = prod(out_features)
203203
self.out_features = out_features
204+
self._reshape_out = not isinstance(
205+
self.out_features, (int, torch.SymInt, Number)
206+
)
204207
self._out_features_num = _out_features_num
205208
self.activation_class = activation_class
206209
self.norm_class = norm_class
@@ -302,7 +305,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
302305
inputs = (torch.cat([*inputs], -1),)
303306

304307
out = super().forward(*inputs)
305-
if not isinstance(self.out_features, Number):
308+
if self._reshape_out:
306309
out = out.view(*out.shape[:-1], *self.out_features)
307310
return out
308311

@@ -549,6 +552,27 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
549552
out = out.unflatten(0, batch)
550553
return out
551554

555+
@classmethod
556+
def default_atari_dqn(cls, num_actions: int):
557+
"""Returns the default DQN as presented in the seminal DQN paper.
558+
559+
Args:
560+
num_actions (int): the action space of the atari game.
561+
562+
"""
563+
cnn = ConvNet(
564+
activation_class=torch.nn.ReLU,
565+
num_cells=[32, 64, 64],
566+
kernel_sizes=[8, 4, 3],
567+
strides=[4, 2, 1],
568+
)
569+
mlp = MLP(
570+
activation_class=torch.nn.ReLU,
571+
out_features=num_actions,
572+
num_cells=[512],
573+
)
574+
return nn.Sequential(cnn, mlp)
575+
552576

553577
Conv2dNet = ConvNet
554578

torchrl/modules/tensordict_module/exploration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
152152
out = action_tensordict.get(action_key)
153153
eps = self.eps.item()
154154
cond = torch.rand(action_tensordict.shape, device=out.device) < eps
155+
# cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
155156
cond = expand_as_right(cond, out)
156157
spec = self.spec
157158
if spec is not None:

0 commit comments

Comments
 (0)