
Commit 802f0e4

Vincent Moens, skandermoalla, and matteobettini authored
[Feature] Gym compatibility: Terminal and truncated (#1539)
Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
1 parent 18b33fe commit 802f0e4


53 files changed: +3583, -1218 lines

.github/unittest/linux/scripts/run_all.sh

Lines changed: 3 additions & 3 deletions
@@ -178,16 +178,16 @@ python -m torch.utils.collect_env
 #export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
 export MKL_THREADING_LAYER=GNU
 export CKPT_BACKEND=torch
-
+export MAX_IDLE_COUNT=100
 
 pytest test/smoke_test.py -v --durations 200
 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
 if [ "${CU_VERSION:-}" != cpu ] ; then
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
---instafail --durations 200 --ignore test/test_rlhf.py
+--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py
 else
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
---instafail --durations 200 --ignore test/test_rlhf.py --ignore test/test_distributed.py
+--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py
 fi
 
 coverage combine

.github/unittest/linux_examples/scripts/run_local.sh

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 #!/bin/bash
 
 set -e
+set -v
 
 # Read script from line 29
 filename=".github/unittest/linux_examples/scripts/run_test.sh"
@@ -12,7 +13,7 @@ script="set -e"$'\n'"$script"
 script="${script//cuda:0/cpu}"
 
 # Remove any instances of ".github/unittest/helpers/coverage_run_parallel.py"
-script="${script//.circleci\/unittest\/helpers\/coverage_run_parallel.py}"
+script="${script//.github\/unittest\/helpers\/coverage_run_parallel.py}"
 script="${script//coverage combine}"
 script="${script//coverage xml -i}"

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 6 additions & 4 deletions
@@ -53,16 +53,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco
 collector.total_frames=40 \
 collector.frames_per_batch=20 \
 loss.mini_batch_size=10 \
-loss.ppo_epochs=1 \
+loss.ppo_epochs=2 \
 logger.backend= \
-logger.test_interval=40
+logger.test_interval=10
 python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \
 collector.total_frames=80 \
 collector.frames_per_batch=20 \
 loss.mini_batch_size=20 \
-loss.ppo_epochs=1 \
+loss.ppo_epochs=2 \
 logger.backend= \
-logger.test_interval=40
+logger.test_interval=10
 python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
 collector.total_frames=48 \
 collector.init_random_frames=10 \
@@ -126,6 +126,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
 optimization.utd_ratio=1 \
 replay_buffer.size=120 \
 env.name=Pendulum-v1 \
+network.device=cuda:0 \
 logger.backend=
 # logger.record_video=True \
 # logger.record_frames=4 \
@@ -225,6 +226,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
 collector.num_workers=2 \
 collector.env_per_collector=1 \
 collector.collector_device=cuda:0 \
+network.device=cuda:0 \
 optimization.batch_size=10 \
 optimization.utd_ratio=1 \
 replay_buffer.size=120 \

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,8 @@ do
 conda activate ./cloned_env
 
 echo "Testing gym version: ${GYM_VERSION}"
+# handling https://github.com/openai/gym/issues/3202
+pip3 install wheel==0.38.4
 pip3 install gym==$GYM_VERSION
 $DIR/run_test.sh
 
@@ -67,6 +69,7 @@ do
 conda activate ./cloned_env
 
 echo "Testing gym version: ${GYM_VERSION}"
+pip3 install wheel==0.38.4
 pip3 install 'gym[atari]'==$GYM_VERSION
 pip3 install ale-py==0.7
 $DIR/run_test.sh

docs/source/reference/envs.rst

Lines changed: 20 additions & 15 deletions
@@ -36,7 +36,7 @@ Each env will have the following attributes:
 - :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
 the reward spec.
 - :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
-the done-flag spec.
+the done-flag spec. See the section on trajectory termination below.
 - :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
 all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
 It is locked and should not be modified directly.
@@ -79,22 +79,25 @@ The following figure summarizes how a rollout is executed in torchrl.
 
 In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method,
 then populated with an action by the policy before being passed to the
-:meth:`~.EnvBase.step` method which writes the observations, done flag and
+:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and
 reward under the ``"next"`` entry. The result of this call is stored for
 delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp`
 function.
 
 .. note::
-
-The Gym(nasium) API recently shifted to a splitting of the ``"done"`` state
-into a ``terminated`` (the env is done and results should not be trusted)
-and ``truncated`` (the maximum number of steps is reached) flags.
-In TorchRL, ``"done"`` usually refers to ``"terminated"``. Truncation is
-achieved via the :class:`~.StepCounter` transform class, and the output
-key will be ``"truncated"`` if not chosen to be something else (e.g.
-``StepCounter(max_steps=100, truncated_key="done")``).
-TorchRL's collectors and rollout methods will be looking for one of these
-keys when assessing if the env should be reset.
+In general, all TorchRL environments have a ``"done"`` and ``"terminated"``
+entry in their output tensordict. If they are not present by design,
+the :class:`~.EnvBase` metaclass will ensure that every done or terminated
+is flanked with its dual.
+In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory
+signals and should be interpreted as "the last step of a trajectory" or
+equivalently "a signal indicating the need to reset".
+If the environment provides it (e.g., Gymnasium), the truncation entry is also
+written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry.
+If the environment carries a single value, it will be interpreted as a ``"terminated"``
+signal by default.
+By default, TorchRL's collectors and rollout methods will be looking for the ``"done"``
+entry to assess if the environment should be reset.
 
 .. note::
 
@@ -172,12 +175,13 @@ It is also possible to reset some but not all of the environments:
 :caption: Parallel environment reset
 
 >>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4])
->>> env.reset(tensordict)
+>>> env.reset(tensordict) # eliminates the "_reset" entry
 TensorDict(
 fields={
+terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
 done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
 pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
-_reset: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
+truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
 batch_size=torch.Size([4]),
 device=None,
 is_shared=True)
@@ -238,7 +242,7 @@ Some of the main differences between these paradigms include:
 
 - **observation** can be per-agent and also have some shared components
 - **reward** can be per-agent or shared
-- **done** can be per-agent or shared
+- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared.
 
 TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier.
 In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict.
@@ -586,6 +590,7 @@ Helpers
 exploration_type
 check_env_specs
 make_composite_from_td
+terminated_or_truncated
 
 Domain-specific
 ---------------
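
The updated note above describes the done/terminated/truncated convention this commit introduces. Below is a minimal sketch of how it surfaces at the API level; it is not part of the commit, it assumes torchrl with a Gym/Gymnasium backend installed, and the CartPole-v1 choice and step budgets are arbitrary.

import torch
from torchrl.envs import StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

# StepCounter(max_steps=...) writes a "truncated" signal (and hence "done") when the step budget is hit.
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter(max_steps=5))
rollout = env.rollout(max_steps=100)  # rollout stops as soon as "done" is raised

done = rollout["next", "done"]              # union of all end-of-trajectory signals
terminated = rollout["next", "terminated"]  # the env reached a terminal state
truncated = rollout["next", "truncated"]    # the trajectory was cut short (here by StepCounter)

# Per the updated docs, "done" is the union of the other two flags.
print(torch.equal(done, terminated | truncated))  # expected: True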

examples/a2c/a2c_atari.py

Lines changed: 0 additions & 4 deletions
@@ -124,10 +124,6 @@ def main(cfg: "DictConfig"): # noqa: F821
 }
 )
 
-# Apply episodic end of life
-data["done"].copy_(data["end_of_life"])
-data["next", "done"].copy_(data["next", "end_of_life"])
-
 losses = TensorDict({}, batch_size=[num_mini_batches])
 training_start = time.time()

examples/decision_transformer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
 batch_size=rb_cfg.batch_size,
 sampler=RandomSampler(), # SamplerWithoutReplacement(drop_last=False),
 transform=transforms,
-use_timeout_as_done=True,
+use_truncated_as_done=True,
 )
 full_data = data._get_dataset_from_env(rb_cfg.dataset, {})
 loc = full_data["observation"].mean(axis=0).float()

examples/ppo/ppo_atari.py

Lines changed: 27 additions & 17 deletions
@@ -78,6 +78,9 @@ def main(cfg: "DictConfig"): # noqa: F821
 normalize_advantage=True,
 )
 
+# use end-of-life as done key
+loss_module.set_keys(done="eol")
+
 # Create optimizer
 optim = torch.optim.Adam(
 loss_module.parameters(),
@@ -109,6 +112,18 @@ def main(cfg: "DictConfig"): # noqa: F821
 )
 
 sampling_start = time.time()
+
+# extract cfg variables
+cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
+cfg_optim_anneal_lr = cfg.optim.anneal_lr
+cfg_optim_lr = cfg.optim.lr
+cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
+cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
+cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+cfg_optim_max_grad_norm = cfg.optim.max_grad_norm
+cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
+losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
+
 for i, data in enumerate(collector):
 
 log_info = {}
@@ -120,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 # Get training rewards and episode lengths
 episode_rewards = data["next", "episode_reward"][data["next", "done"]]
 if len(episode_rewards) > 0:
-episode_length = data["next", "step_count"][data["next", "done"]]
+episode_length = data["next", "step_count"][data["next", "stop"]]
 log_info.update(
 {
 "train/reward": episode_rewards.mean().item(),
@@ -129,13 +144,8 @@ def main(cfg: "DictConfig"): # noqa: F821
 }
 )
 
-# Apply episodic end of life
-data["done"].copy_(data["end_of_life"])
-data["next", "done"].copy_(data["next", "end_of_life"])
-
-losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
 training_start = time.time()
-for j in range(cfg.loss.ppo_epochs):
+for j in range(cfg_loss_ppo_epochs):
 
 # Compute GAE
 with torch.no_grad():
@@ -149,12 +159,12 @@ def main(cfg: "DictConfig"): # noqa: F821
 
 # Linearly decrease the learning rate and clip epsilon
 alpha = 1.0
-if cfg.optim.anneal_lr:
+if cfg_optim_anneal_lr:
 alpha = 1 - (num_network_updates / total_network_updates)
 for group in optim.param_groups:
-group["lr"] = cfg.optim.lr * alpha
-if cfg.loss.anneal_clip_epsilon:
-loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
+group["lr"] = cfg_optim_lr * alpha
+if cfg_loss_anneal_clip_eps:
+loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
 num_network_updates += 1
 
 # Get a data batch
@@ -172,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 # Backward pass
 loss_sum.backward()
 torch.nn.utils.clip_grad_norm_(
-list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
+list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
 )
 
 # Update the networks
@@ -181,15 +191,15 @@ def main(cfg: "DictConfig"): # noqa: F821
 
 # Get training losses and times
 training_time = time.time() - training_start
-losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
-for key, value in losses.items():
+losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
+for key, value in losses_mean.items():
 log_info.update({f"train/{key}": value.item()})
 log_info.update(
 {
-"train/lr": alpha * cfg.optim.lr,
+"train/lr": alpha * cfg_optim_lr,
 "train/sampling_time": sampling_time,
 "train/training_time": training_time,
-"train/clip_epsilon": alpha * cfg.loss.clip_epsilon,
+"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
 }
 )
 
@@ -201,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 actor.eval()
 eval_start = time.time()
 test_rewards = eval_model(
-actor, test_env, num_episodes=cfg.logger.num_test_episodes
+actor, test_env, num_episodes=cfg_logger_num_test_episodes
 )
 eval_time = time.time() - eval_start
 log_info.update(
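
Instead of overwriting data["done"] with the end-of-life flag inside the training loop (the removed lines above), the example now tells the loss which done entry to read. A small sketch of that pattern, not taken from the commit: the actor/critic construction below is illustrative, and the "eol" key is assumed to be written by the example's Atari env wrappers as an end-of-life flag.

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import OneHotCategorical
from torchrl.objectives import ClipPPOLoss

n_obs, n_act = 4, 2  # dummy sizes for the sketch
policy_net = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
actor = ProbabilisticActor(policy_net, in_keys=["logits"], distribution_class=OneHotCategorical, return_log_prob=True)
critic = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

loss_module = ClipPPOLoss(actor, critic)
# Point the loss at the end-of-life entry; no need to copy it into data["done"] by hand.
loss_module.set_keys(done="eol")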

examples/ppo/ppo_mujoco.py

Lines changed: 24 additions & 14 deletions
@@ -100,6 +100,17 @@ def main(cfg: "DictConfig"): # noqa: F821
 pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
 sampling_start = time.time()
+
+# extract cfg variables
+cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
+cfg_optim_anneal_lr = cfg.optim.anneal_lr
+cfg_optim_lr = cfg.optim.lr
+cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
+cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
+cfg_logger_test_interval = cfg.logger.test_interval
+cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
+
 for i, data in enumerate(collector):
 
 log_info = {}
@@ -120,9 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821
 }
 )
 
-losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
 training_start = time.time()
-for j in range(cfg.loss.ppo_epochs):
+for j in range(cfg_loss_ppo_epochs):
 
 # Compute GAE
 with torch.no_grad():
@@ -136,14 +146,14 @@ def main(cfg: "DictConfig"): # noqa: F821
 
 # Linearly decrease the learning rate and clip epsilon
 alpha = 1.0
-if cfg.optim.anneal_lr:
+if cfg_optim_anneal_lr:
 alpha = 1 - (num_network_updates / total_network_updates)
 for group in actor_optim.param_groups:
-group["lr"] = cfg.optim.lr * alpha
+group["lr"] = cfg_optim_lr * alpha
 for group in critic_optim.param_groups:
-group["lr"] = cfg.optim.lr * alpha
-if cfg.loss.anneal_clip_epsilon:
-loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
+group["lr"] = cfg_optim_lr * alpha
+if cfg_loss_anneal_clip_eps:
+loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
 num_network_updates += 1
 
 # Forward pass PPO loss
@@ -166,27 +176,27 @@ def main(cfg: "DictConfig"): # noqa: F821
 
 # Get training losses and times
 training_time = time.time() - training_start
-losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
-for key, value in losses.items():
+losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
+for key, value in losses_mean.items():
 log_info.update({f"train/{key}": value.item()})
 log_info.update(
 {
-"train/lr": alpha * cfg.optim.lr,
+"train/lr": alpha * cfg_optim_lr,
 "train/sampling_time": sampling_time,
 "train/training_time": training_time,
-"train/clip_epsilon": alpha * cfg.loss.clip_epsilon,
+"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
 }
 )
 
 # Get test rewards
 with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
-if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
+if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
 i * frames_in_batch
-) // cfg.logger.test_interval:
+) // cfg_logger_test_interval:
 actor.eval()
 eval_start = time.time()
 test_rewards = eval_model(
-actor, test_env, num_episodes=cfg.logger.num_test_episodes
+actor, test_env, num_episodes=cfg_logger_num_test_episodes
 )
 eval_time = time.time() - eval_start
 log_info.update(
