
Commit f8788b1

Author: Vincent Moens
[BugFix] Fix tutos (#1648)
1 parent e7630f1 commit f8788b1

5 files changed: +42, -14 lines


torchrl/collectors/collectors.py

Lines changed: 20 additions & 8 deletions
@@ -567,7 +567,9 @@ def __init__(
         self.policy_weights = TensorDict({}, [])

         self.env: EnvBase = self.env.to(self.device)
-        self.max_frames_per_traj = max_frames_per_traj
+        self.max_frames_per_traj = (
+            int(max_frames_per_traj) if max_frames_per_traj is not None else 0
+        )
         if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0:
             # let's check that there is no StepCounter yet
             for key in self.env.output_spec.keys(True, True):
@@ -595,9 +597,13 @@ def __init__(
                 f"This means {frames_per_batch - remainder} additional frames will be collected."
                 "To silence this message, set the environment variable RL_WARNINGS to False."
             )
-        self.total_frames = total_frames
+        self.total_frames = (
+            int(total_frames) if total_frames != float("inf") else total_frames
+        )
         self.reset_at_each_iter = reset_at_each_iter
-        self.init_random_frames = init_random_frames
+        self.init_random_frames = (
+            int(init_random_frames) if init_random_frames is not None else 0
+        )
         if (
             init_random_frames is not None
             and init_random_frames % frames_per_batch != 0
@@ -620,7 +626,7 @@ def __init__(
                 f" ({-(-frames_per_batch // self.n_env) * self.n_env})."
                 "To silence this message, set the environment variable RL_WARNINGS to False."
             )
-        self.requested_frames_per_batch = frames_per_batch
+        self.requested_frames_per_batch = int(frames_per_batch)
         self.frames_per_batch = -(-frames_per_batch // self.n_env)
         self.exploration_type = (
             exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
@@ -1234,11 +1240,15 @@ def device_err_msg(device_name, devices_list):
                 f"This means {frames_per_batch - remainder} additional frames will be collected."
                 "To silence this message, set the environment variable RL_WARNINGS to False."
             )
-        self.total_frames = total_frames
+        self.total_frames = (
+            int(total_frames) if total_frames != float("inf") else total_frames
+        )
         self.reset_at_each_iter = reset_at_each_iter
         self.postprocs = postproc
-        self.max_frames_per_traj = max_frames_per_traj
-        self.requested_frames_per_batch = frames_per_batch
+        self.max_frames_per_traj = (
+            int(max_frames_per_traj) if max_frames_per_traj is not None else 0
+        )
+        self.requested_frames_per_batch = int(frames_per_batch)
         self.reset_when_done = reset_when_done
         if split_trajs is None:
             split_trajs = False
@@ -1247,7 +1257,9 @@ def device_err_msg(device_name, devices_list):
                 "Cannot split trajectories when reset_when_done is False."
             )
         self.split_trajs = split_trajs
-        self.init_random_frames = init_random_frames
+        self.init_random_frames = (
+            int(init_random_frames) if init_random_frames is not None else 0
+        )
         self.update_at_each_batch = update_at_each_batch
         self.exploration_type = exploration_type
         self.frames_per_worker = np.inf
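The coercions above keep the frame-count attributes as plain ints (or 0 when a feature is disabled), so float arguments such as total_frames=1e6 no longer propagate into downstream integer arithmetic. A minimal standalone sketch of the same pattern; the helper name is hypothetical and not part of torchrl:

# Standalone sketch of the coercion pattern introduced above (hypothetical helper).
def normalize_frame_counts(
    total_frames, frames_per_batch, init_random_frames=None, max_frames_per_traj=None
):
    # An infinite budget stays as-is; any finite value is forced to int (1e6 -> 1_000_000).
    total_frames = int(total_frames) if total_frames != float("inf") else total_frames
    # None means "feature disabled" and is mapped to 0, as in the diff.
    init_random_frames = int(init_random_frames) if init_random_frames is not None else 0
    max_frames_per_traj = int(max_frames_per_traj) if max_frames_per_traj is not None else 0
    requested_frames_per_batch = int(frames_per_batch)
    return total_frames, requested_frames_per_batch, init_random_frames, max_frames_per_traj


print(normalize_frame_counts(1e6, 200.0, init_random_frames=1e3))
# -> (1000000, 200, 1000, 0)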

torchrl/envs/transforms/r3m.py

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ def _init(self):
         transforms.append(resize)

         # R3M
-        if out_keys is None:
+        if out_keys in (None, []):
             if stack_images:
                 out_keys = ["r3m_vec"]
             else:

torchrl/envs/transforms/vip.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
 from typing import List, Optional, Union

 import torch
@@ -277,7 +276,7 @@ def _init(self):
         transforms.append(resize)

         # VIP
-        if out_keys is None:
+        if out_keys in (None, []):
             if stack_images:
                 out_keys = ["vip_vec"]
             else:
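Both the R3M and VIP transforms previously filled in default output keys only when out_keys was None, so an explicitly empty list slipped through and produced a transform with no outputs. A minimal sketch of the new guard, using a hypothetical helper rather than the transforms' actual _init logic:

# Hypothetical helper illustrating the out_keys guard; not the torchrl API.
def default_out_keys(out_keys, default):
    # An empty list now triggers the default too, not only None.
    return list(default) if out_keys in (None, []) else out_keys


print(default_out_keys(None, ["r3m_vec"]))  # ['r3m_vec']
print(default_out_keys([], ["vip_vec"]))    # ['vip_vec']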

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 8 additions & 3 deletions
@@ -390,7 +390,7 @@ def get_replay_buffer(buffer_size, n_optim, batch_size):


 def get_collector(
-    obs_norm_sd,
+    stats,
     num_collectors,
     actor_explore,
     frames_per_batch,
@@ -399,7 +399,7 @@ def get_collector(
 ):
     data_collector = MultiaSyncDataCollector(
         [
-            make_env(parallel=True, obs_norm_sd=obs_norm_sd),
+            make_env(parallel=True, obs_norm_sd=stats),
         ]
         * num_collectors,
         policy=actor_explore,
@@ -566,7 +566,12 @@ def get_loss_module(actor, gamma):
 loss_module, target_net_updater = get_loss_module(actor, gamma)

 collector = get_collector(
-    stats, num_collectors, actor_explore, frames_per_batch, total_frames, device
+    stats=stats,
+    num_collectors=num_collectors,
+    actor_explore=actor_explore,
+    frames_per_batch=frames_per_batch,
+    total_frames=total_frames,
+    device=device,
 )
 optimizer = torch.optim.Adam(
     loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
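With the parameter renamed to stats, the tutorial's call site now also passes everything by keyword, so a later change to the parameter order cannot silently shift positional values. A rough sketch of the calling convention, with placeholder values standing in for the tutorial's real stats, policy, and devices:

# Sketch only: the stub below just echoes its inputs; the real get_collector
# in the tutorial builds a MultiaSyncDataCollector.
def get_collector(stats, num_collectors, actor_explore,
                  frames_per_batch, total_frames, device):
    return dict(stats=stats, num_collectors=num_collectors,
                frames_per_batch=frames_per_batch, total_frames=total_frames,
                device=device)


collector = get_collector(
    stats={"loc": 0.0, "scale": 1.0},  # placeholder normalization constants
    num_collectors=2,
    actor_explore=None,                # placeholder exploration policy
    frames_per_batch=128,
    total_frames=5_000,
    device="cpu",
)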

tutorials/sphinx-tutorials/pendulum.py

Lines changed: 12 additions & 0 deletions
@@ -652,6 +652,12 @@ class SinTransform(Transform):
     def _apply_transform(self, obs: torch.Tensor) -> None:
         return obs.sin()

+    # The transform must also modify the data at reset time
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        return self._call(tensordict_reset)
+
     # _apply_to_composite will execute the observation spec transform across all
     # in_keys/out_keys pairs and write the result in the observation_spec which
     # is of type ``Composite``
@@ -670,6 +676,12 @@ class CosTransform(Transform):
     def _apply_transform(self, obs: torch.Tensor) -> None:
         return obs.cos()

+    # The transform must also modify the data at reset time
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        return self._call(tensordict_reset)
+
     # _apply_to_composite will execute the observation spec transform across all
     # in_keys/out_keys pairs and write the result in the observation_spec which
     # is of type ``Composite``
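The added _reset hook is what makes the transformed entry appear in the tensordict returned by env.reset(), not only after env.step(). A condensed sketch of the pattern for a custom observation transform; the spec-transform methods are elided here, and the in_keys/out_keys in the commented usage are assumptions about the tutorial's observation names:

import torch
from tensordict import TensorDictBase
from torchrl.envs.transforms import Transform


class SinTransform(Transform):
    # Applied to each in_key/out_key pair on every step() output.
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs.sin()

    # Without this hook, reset() would return a tensordict missing the transformed key.
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)


# Assumed usage, mirroring the tutorial:
# t_sin = SinTransform(in_keys=["th"], out_keys=["sin"])
# env = TransformedEnv(base_env, t_sin)
# env.reset() now contains the "sin" entry as well.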

0 commit comments