
Commit 50f0db0

[BugFix, Doc] Fix tutos (#1107)
1 parent 25370f7 commit 50f0db0

14 files changed: +86 −25 lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 2 deletions

@@ -72,8 +72,7 @@ jobs:
         id: build_doc
         run: |
           cd ./docs
-          #timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
-          PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build
+          timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
           cd ..
       - name: Install rsync 📚
         run: |

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ Intermediate

    tutorials/torch_envs
    tutorials/pretrained_models
+   tutorials/dqn_with_rnn.py

 Advanced
 --------

torchrl/collectors/collectors.py

Lines changed: 6 additions & 2 deletions

@@ -496,7 +496,9 @@ def __init__(
     ):
         self.closed = True

-        exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
+        exploration_type = _convert_exploration_type(
+            exploration_mode=exploration_mode, exploration_type=exploration_type
+        )
         if create_env_kwargs is None:
             create_env_kwargs = {}
         if not isinstance(create_env_fn, EnvBase):
@@ -1049,7 +1051,9 @@ def __init__(
         devices=None,
         storing_devices=None,
     ):
-        exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
+        exploration_type = _convert_exploration_type(
+            exploration_mode=exploration_mode, exploration_type=exploration_type
+        )
         self.closed = True
         self.create_env_fn = create_env_fn
         self.num_workers = len(create_env_fn)
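Note: both constructors above accept the `ExplorationType` enum end-to-end, with the string-based `exploration_mode` kept only for backward compatibility. A minimal sketch of building a collector with the enum (the env id is illustrative; `policy=None` should fall back to a random policy):

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import ExplorationType

    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy=None,  # falls back to a random policy
        frames_per_batch=64,
        total_frames=128,
        exploration_type=ExplorationType.RANDOM,  # was: exploration_mode="random"
    )
    for batch in collector:
        break  # each `batch` is a TensorDict of collected frames
    collector.shutdown()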

torchrl/collectors/distributed/generic.py

Lines changed: 3 additions & 1 deletion

@@ -385,7 +385,9 @@ def __init__(
         launcher="submitit",
         tcp_port=None,
     ):
-        exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
+        exploration_type = _convert_exploration_type(
+            exploration_mode=exploration_mode, exploration_type=exploration_type
+        )

         if collector_class == "async":
             collector_class = MultiaSyncDataCollector

torchrl/collectors/distributed/rpc.py

Lines changed: 3 additions & 1 deletion

@@ -238,7 +238,9 @@ def __init__(
         visible_devices=None,
         tensorpipe_options=None,
     ):
-        exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
+        exploration_type = _convert_exploration_type(
+            exploration_mode=exploration_mode, exploration_type=exploration_type
+        )
         if collector_class == "async":
             collector_class = MultiaSyncDataCollector
         elif collector_class == "sync":

torchrl/collectors/distributed/sync.py

Lines changed: 3 additions & 1 deletion

@@ -242,7 +242,9 @@ def __init__(
         launcher="submitit",
         tcp_port=None,
     ):
-        exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
+        exploration_type = _convert_exploration_type(
+            exploration_mode=exploration_mode, exploration_type=exploration_type
+        )

         if collector_class == "async":
             collector_class = MultiaSyncDataCollector

torchrl/envs/libs/gym.py

Lines changed: 26 additions & 1 deletion

@@ -337,12 +337,37 @@ class GymWrapper(GymLikeEnv):
     git_url = "https://github.com/openai/gym"
     libname = "gym"

+    @staticmethod
+    def get_library_name(env):
+        # try gym
+        try:
+            import gym
+
+            if isinstance(env.action_space, gym.spaces.space.Space):
+                return gym
+        except ImportError:
+            pass
+        try:
+            import gymnasium
+
+            if isinstance(env.action_space, gymnasium.spaces.space.Space):
+                return gymnasium
+        except ImportError:
+            pass
+        raise RuntimeError(
+            f"Could not find the library of env {env}. Please file an issue on torchrl github repo."
+        )
+
     def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
         if env is not None:
             kwargs["env"] = env
         self._seed_calls_reset = None
         self._categorical_action_encoding = categorical_action_encoding
-        super().__init__(**kwargs)
+        if "env" in kwargs:
+            with set_gym_backend(self.get_library_name(kwargs["env"])):
+                super().__init__(**kwargs)
+            else:
+                super().__init__(**kwargs)

     def _check_kwargs(self, kwargs: Dict):
         if "env" not in kwargs:

torchrl/envs/utils.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
 AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set}


-def _convert_exploration_type(exploration_mode, exploration_type):
+def _convert_exploration_type(*, exploration_mode, exploration_type):
     if exploration_mode is not None:
         return ExplorationType.from_str(exploration_mode)
     return exploration_type
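The bare `*` makes both parameters keyword-only, which is why every call site above now spells them out: a positional call can no longer silently swap the mode string and the enum. A small sketch of the contract (importing the private helper purely for illustration):

    from torchrl.envs.utils import ExplorationType, _convert_exploration_type

    out = _convert_exploration_type(
        exploration_mode=None, exploration_type=ExplorationType.RANDOM
    )
    assert out is ExplorationType.RANDOM

    try:
        _convert_exploration_type(None, ExplorationType.RANDOM)  # positional: rejected
    except TypeError as err:
        print(err)  # takes 0 positional arguments but 2 were given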

tutorials/sphinx-tutorials/coding_ddpg.py

Lines changed: 7 additions & 5 deletions

@@ -268,10 +268,8 @@ def _loss_actor(
         tensordict,
     ) -> torch.Tensor:
         td_copy = tensordict.select(*self.actor_in_keys)
-        # Get an action from the actor network
-        td_copy = self.actor_network(
-            td_copy,
-        )
+        # Get an action from the actor network: since we made it functional, we need to pass the params
+        td_copy = self.actor_network(td_copy, params=self.actor_network_params)
         # get the value associated with that action
         td_copy = self.value_network(
             td_copy,
@@ -482,6 +480,7 @@ def make_env(from_pixels=False):
     CatTensors,
     DoubleToFloat,
     EnvCreator,
+    InitTracker,
     ObservationNorm,
     ParallelEnv,
     RewardScaling,
@@ -536,6 +535,9 @@ def make_transformed_env(

     env.append_transform(StepCounter(max_frames_per_traj))

+    # We need a marker for the start of trajectories for our OU exploration:
+    env.append_transform(InitTracker())
+
     return env


@@ -889,7 +891,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
         record_frames=1000,
         policy_exploration=actor_model_explore,
         environment=environment,
-        exploration_type="mode",
+        exploration_type=ExplorationType.MEAN,
         record_interval=record_interval,
     )
     return recorder_obj
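The `InitTracker` transform is what gives the tutorial's `OrnsteinUhlenbeckProcessWrapper` its trajectory-start signal: it writes a boolean `"is_init"` entry that is `True` on reset steps, which the wrapper reads to re-initialize its noise process. A minimal sketch of the flag in isolation (assuming `Pendulum-v1` is available):

    from torchrl.envs import InitTracker, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv

    env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
    td = env.reset()
    print(td["is_init"])  # True: first step of a trajectory
    rollout = env.rollout(3)
    print(rollout["is_init"])  # True on the first step only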

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 3 additions & 2 deletions

@@ -438,7 +438,8 @@ def get_collector(


 def get_loss_module(actor, gamma):
-    loss_module = DQNLoss(actor, gamma=gamma, delay_value=True)
+    loss_module = DQNLoss(actor, delay_value=True)
+    loss_module.make_value_estimator(gamma=gamma)
     target_updater = SoftUpdate(loss_module)
     return loss_module, target_updater

@@ -617,7 +618,7 @@ def get_loss_module(actor, gamma):
     frame_skip=1,
     policy_exploration=actor_explore,
     environment=test_env,
-    exploration_type="mode",
+    exploration_type=ExplorationType.MODE,
     log_keys=[("next", "reward")],
     out_keys={("next", "reward"): "rewards"},
     log_pbar=True,
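Here `gamma` moves out of the `DQNLoss` constructor and onto the value estimator via `make_value_estimator`, and the recorder switches from the `"mode"` string to the `ExplorationType.MODE` enum. A minimal sketch of the updated loss setup with a toy Q-network (sizes and hyperparameters are illustrative):

    import torch
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss, SoftUpdate

    num_obs, num_actions = 4, 2
    actor = QValueActor(torch.nn.Linear(num_obs, num_actions), in_keys=["observation"])
    loss_module = DQNLoss(actor, delay_value=True)
    loss_module.make_value_estimator(gamma=0.99)  # was: DQNLoss(actor, gamma=0.99, delay_value=True)
    target_updater = SoftUpdate(loss_module)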
