Commit 09f71b1

[Feature] CatFrames for offline data (#1122)
1 parent 0452133 · commit 09f71b1

File tree

12 files changed: +371 −149 lines
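What the title means in practice: CatFrames, which stacks the last N observations into one tensor, can now run as a ReplayBuffer transform over stored data rather than only inside a live environment. Below is a minimal sketch of that usage, mirroring the smoke test in test/test_rb.py further down; the N=4 and dim=-1 arguments and the tensor shape are illustrative assumptions, not values taken from this commit.

import torch
from tensordict import TensorDict
from torchrl.data import ReplayBuffer
from torchrl.envs.transforms import CatFrames

# CatFrames stacks the last N observations along `dim`. Attached to a
# ReplayBuffer, it is applied to sampled (offline) data instead of being
# appended to the environment. N, dim and shapes here are illustrative.
rb = ReplayBuffer(
    transform=CatFrames(N=4, dim=-1, in_keys=["observation"]),
    batch_size=1,
)
td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
rb.add(td)
# Before this commit, sampling through a CatFrames buffer transform raised
# NotImplementedError (see the branch removed from test/test_rb.py below).
sample = rb.sample()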

.circleci/unittest/linux_examples/scripts/run_test.sh

Lines changed: 11 additions & 11 deletions
@@ -29,6 +29,17 @@
 python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20

 # With batched environments
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
+  env.num_envs=1 \
+  collector.total_frames=48 \
+  collector.frames_per_batch=16 \
+  collector.collector_device=cuda:0 \
+  optim.device=cuda:0 \
+  loss.mini_batch_size=10 \
+  loss.ppo_epochs=1 \
+  logger.backend= \
+  logger.log_interval=4 \
+  optim.lr_scheduler=False
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
   total_frames=48 \
   init_random_frames=10 \
@@ -86,17 +97,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
   record_video=True \
   record_frames=4 \
   buffer_size=120
-python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
-  env.num_envs=1 \
-  collector.total_frames=48 \
-  collector.frames_per_batch=16 \
-  collector.collector_device=cuda:0 \
-  optim.device=cuda:0 \
-  loss.mini_batch_size=10 \
-  loss.ppo_epochs=1 \
-  logger.backend= \
-  logger.log_interval=4 \
-  optim.lr_scheduler=False
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
   total_frames=200 \
   init_random_frames=10 \
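The PPO smoke run is moved verbatim from its old spot after the SAC run into the "# With batched environments" section: the 11 lines added at the top are exactly the 11 removed below.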

.github/workflows/nightly_build.yml

Lines changed: 6 additions & 0 deletions
@@ -119,6 +119,9 @@ jobs:
       - name: Install test dependencies
         run: |
           python3 -mpip install numpy pytest --no-cache-dir
+      - name: Install tensordict
+        run: |
+          python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
       - name: Download built wheels
         uses: actions/download-artifact@v2
         with:
@@ -324,6 +327,9 @@
         shell: bash
         run: |
           python3 -mpip install numpy pytest --no-cache-dir
+      - name: Install tensordict
+        run: |
+          python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
       - name: Download built wheels
         uses: actions/download-artifact@v2
         with:
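Both nightly-build jobs gain an explicit "Install tensordict" step ahead of wheel testing, installing tensordict straight from the pytorch-labs GitHub repository.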

examples/a2c/utils.py

Lines changed: 0 additions & 9 deletions
@@ -107,13 +107,7 @@ def make_transformed_env_pixels(base_env, env_cfg):
         double_to_float_list += [
             "reward",
         ]
-        double_to_float_list += [
-            "action",
-        ]
         double_to_float_inv_list += ["action"]  # DMControl requires double-precision
-        double_to_float_list += ["observation_vector"]
-    else:
-        double_to_float_list += ["observation_vector"]
     env.append_transform(
         DoubleToFloat(
             in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
@@ -152,9 +146,6 @@ def make_transformed_env_states(base_env, env_cfg):
         double_to_float_list += [
             "reward",
         ]
-        double_to_float_list += [
-            "action",
-        ]
         double_to_float_inv_list += ["action"]  # DMControl requires double-precision
         double_to_float_list += ["observation_vector"]
     else:
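"action" no longer goes through the forward double-to-float cast; only the inverse-cast entry remains, which converts actions back to double on their way into DMControl. The pixels builder also stops registering "observation_vector", a key the pixel-based pipeline does not use.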

examples/dreamer/dreamer_utils.py

Lines changed: 0 additions & 1 deletion
@@ -119,7 +119,6 @@ def make_env_transforms(
     if env_library is DMControlEnv:
         double_to_float_list += [
             "reward",
-            "action",
         ]
         float_to_double_list += ["action"]  # DMControl requires double-precision
     env.append_transform(
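Same fix as in the A2C utils: "action" leaves the forward cast list, while float_to_double_list keeps casting actions to double for DMControl.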

examples/ppo/ppo.py

Lines changed: 6 additions & 3 deletions
@@ -72,10 +72,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Main loop
     r0 = None
     l0 = None
+    frame_skip = cfg.env.frame_skip
+    mini_batch_size = cfg.loss.mini_batch_size
+    ppo_epochs = cfg.loss.ppo_epochs
     for data in collector:

         frames_in_batch = data.numel()
-        collected_frames += frames_in_batch * cfg.env.frame_skip
+        collected_frames += frames_in_batch * frame_skip
         pbar.update(data.numel())
         data_view = data.reshape(-1)

@@ -93,8 +96,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 "reward_training", episode_rewards.mean().item(), collected_frames
             )

-        for _ in range(cfg.loss.ppo_epochs):
-            for _ in range(frames_in_batch // cfg.loss.mini_batch_size):
+        for _ in range(ppo_epochs):
+            for _ in range(frames_in_batch // mini_batch_size):

                 # Get a data batch
                 batch = data_buffer.sample().to(model_device)
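Config values used inside the training loop (frame_skip, mini_batch_size, ppo_epochs) are read from the Hydra config once, before collection starts, instead of being looked up on the DictConfig in every iteration of the hot loop.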

examples/ppo/utils.py

Lines changed: 0 additions & 9 deletions
@@ -108,13 +108,7 @@ def make_transformed_env_pixels(base_env, env_cfg):
         double_to_float_list += [
             "reward",
         ]
-        double_to_float_list += [
-            "action",
-        ]
         double_to_float_inv_list += ["action"]  # DMControl requires double-precision
-        double_to_float_list += ["observation_vector"]
-    else:
-        double_to_float_list += ["observation_vector"]
     env.append_transform(
         DoubleToFloat(
             in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
@@ -153,9 +147,6 @@ def make_transformed_env_states(base_env, env_cfg):
         double_to_float_list += [
             "reward",
         ]
-        double_to_float_list += [
-            "action",
-        ]
         double_to_float_inv_list += ["action"]  # DMControl requires double-precision
         double_to_float_list += ["observation_vector"]
     else:
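The same cast-list cleanup as in examples/a2c/utils.py above, applied to the PPO example.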

test/_utils_internal.py

Lines changed: 39 additions & 20 deletions
@@ -156,21 +156,27 @@ def create_env_fn():
             return GymEnv(env_name, frame_skip=frame_skip, device=device)

     else:
-        if env_name == "ALE/Pong-v5":
+        if env_name == PONG_VERSIONED:

             def create_env_fn():
+                base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
+                in_keys = list(base_env.observation_spec.keys(True, True))[:1]
                 return TransformedEnv(
-                    GymEnv(env_name, frame_skip=frame_skip, device=device),
-                    Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]),
+                    base_env,
+                    Compose(*[ToTensorImage(in_keys=in_keys), RewardClipping(0, 0.1)]),
                 )

         else:

             def create_env_fn():
+
+                base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
+                in_keys = list(base_env.observation_spec.keys(True, True))[:1]
+
                 return TransformedEnv(
-                    GymEnv(env_name, frame_skip=frame_skip, device=device),
+                    base_env,
                     Compose(
-                        ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1),
+                        ObservationNorm(in_keys=in_keys, loc=0.5, scale=1.1),
                         RewardClipping(0, 0.1),
                     ),
                 )
@@ -179,8 +185,14 @@ def create_env_fn():
     env_parallel = ParallelEnv(N, create_env_fn, create_env_kwargs=kwargs)
     env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)

+    for key in env0.observation_spec.keys(True, True):
+        obs_key = key
+        break
+    else:
+        obs_key = None
+
     if transformed_out:
-        t_out = get_transform_out(env_name, transformed_in)
+        t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key)

         env0 = TransformedEnv(
             env0,
@@ -223,7 +235,7 @@ def _make_multithreaded_env(

     torch.manual_seed(0)
     multithreaded_kwargs = (
-        {"frame_skip": frame_skip} if env_name == "ALE/Pong-v5" else {}
+        {"frame_skip": frame_skip} if env_name == PONG_VERSIONED else {}
     )
     env_multithread = MultiThreadedEnv(
         N,
@@ -233,46 +245,53 @@ def _make_multithreaded_env(
     )

     if transformed_out:
+        for key in env_multithread.observation_spec.keys(True, True):
+            obs_key = key
+            break
+        else:
+            obs_key = None
         env_multithread = TransformedEnv(
             env_multithread,
-            get_transform_out(env_name, transformed_in=False)(),
+            get_transform_out(env_name, transformed_in=False, obs_key=obs_key)(),
         )
     return env_multithread


-def get_transform_out(env_name, transformed_in):
+def get_transform_out(env_name, transformed_in, obs_key=None):

-    if env_name == "ALE/Pong-v5":
+    if env_name == PONG_VERSIONED:
+        if obs_key is None:
+            obs_key = "pixels"

         def t_out():
             return (
-                Compose(*[ToTensorImage(), RewardClipping(0, 0.1)])
+                Compose(*[ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1)])
                 if not transformed_in
-                else Compose(*[ObservationNorm(in_keys=["pixels"], loc=0, scale=1)])
+                else Compose(*[ObservationNorm(in_keys=[obs_key], loc=0, scale=1)])
            )

-    elif env_name == "CheetahRun-v1":
+    elif env_name == HALFCHEETAH_VERSIONED:
+        if obs_key is None:
+            obs_key = ("observation", "velocity")

         def t_out():
             return Compose(
-                ObservationNorm(
-                    in_keys=[("observation", "velocity")], loc=0.5, scale=1.1
-                ),
+                ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
                 RewardClipping(0, 0.1),
             )

     else:
+        if obs_key is None:
+            obs_key = "observation"

         def t_out():
             return (
                 Compose(
-                    ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1),
+                    ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
                     RewardClipping(0, 0.1),
                 )
                 if not transformed_in
-                else Compose(
-                    ObservationNorm(in_keys=["observation"], loc=1.0, scale=1.0)
-                )
+                else Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0))
             )

     return t_out
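Two threads run through this file: hard-coded env ids ("ALE/Pong-v5", "CheetahRun-v1") become the PONG_VERSIONED and HALFCHEETAH_VERSIONED constants, and the test transforms now target the first key found in the env's observation_spec instead of a hard-coded "pixels" or "observation" key, with per-env defaults as a fallback. The for/else idiom does the lookup: the else branch runs only if the spec yields no keys at all, leaving obs_key as None.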

test/test_libs.py

Lines changed: 3 additions & 2 deletions
@@ -518,7 +518,7 @@ def test_jumanji_consistency(self, envname, batch_size):
     "Acrobot-v1",
     CARTPOLE_VERSIONED,
 ]
-ENVPOOL_ATARI_ENVS = [PONG_VERSIONED]
+ENVPOOL_ATARI_ENVS = []  # PONG_VERSIONED]
 ENVPOOL_GYM_ENVS = ENVPOOL_CLASSIC_CONTROL_ENVS + ENVPOOL_ATARI_ENVS
 ENVPOOL_DM_ENVS = ["CheetahRun-v1"]
 ENVPOOL_ALL_ENVS = ENVPOOL_GYM_ENVS + ENVPOOL_DM_ENVS
@@ -558,6 +558,7 @@ def test_specs(self, env_name, frame_skip, transformed_out, T=10, N=3):
     def test_env_basic_operation(
         self, env_name, frame_skip, transformed_out, T=10, N=3
     ):
+        torch.manual_seed(0)
         env_multithreaded = _make_multithreaded_env(
             env_name,
             frame_skip,
@@ -737,7 +738,7 @@ def test_multithreaded_env_seed(

         # Check that results are different if seed is different
         # Skip Pong, since there different actions can lead to the same result
-        if env_name != "ALE/Pong-v5":
+        if env_name != PONG_VERSIONED:
             env.set_seed(
                 seed=seed + 10,
             )
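Pong is dropped from the envpool Atari list (temporarily, judging by the commented-out entry), test_env_basic_operation now seeds torch for reproducibility, and the remaining "ALE/Pong-v5" literal becomes PONG_VERSIONED.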

test/test_rb.py

Lines changed: 7 additions & 9 deletions
@@ -840,14 +840,10 @@ def test_insert_transform():
 def test_smoke_replay_buffer_transform(transform):
     rb = ReplayBuffer(transform=transform(in_keys="observation"), batch_size=1)

+    # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, [])
     td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
     rb.add(td)
-    if not isinstance(rb._transform[0], (CatFrames,)):
-        rb.sample()
-    else:
-        with pytest.raises(NotImplementedError):
-            rb.sample()
-        return
+    rb.sample()

     rb._transform = mock.MagicMock()
     rb._transform.__len__ = lambda *args: 3
@@ -856,7 +852,7 @@ def test_smoke_replay_buffer_transform(transform):


 transforms = [
-    partial(DiscreteActionProjection, num_actions_effective=1, max_actions=1),
+    partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3),
     FiniteTensorDictCheck,
     gSDENoise,
     PinMemoryTransform,
@@ -865,13 +861,15 @@ def test_smoke_replay_buffer_transform(transform):

 @pytest.mark.parametrize("transform", transforms)
 def test_smoke_replay_buffer_transform_no_inkeys(transform):
-    if PinMemoryTransform is PinMemoryTransform and not torch.cuda.is_available():
+    if transform == PinMemoryTransform and not torch.cuda.is_available():
        raise pytest.skip("No CUDA device detected, skipping PinMemory")
     rb = ReplayBuffer(
         collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1
     )

-    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
+    action = torch.zeros(3)
+    action[..., 0] = 1
+    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, [])
     rb.add(td)
     rb.sample()
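Three fixes here. With CatFrames now supported on buffered data, the NotImplementedError special case collapses to a plain rb.sample(). DiscreteActionProjection is parametrized with max_actions=3 so there is genuinely something to project down to num_actions_effective=1, and the no-inkeys test now stores a one-hot action for it to operate on. Finally, the skip condition drops the tautology "PinMemoryTransform is PinMemoryTransform", which was always true and so skipped every transform on CPU-only machines, in favor of "transform == PinMemoryTransform".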

test/test_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@
 except ImportError:
     _has_tb = False

+from _utils_internal import PONG_VERSIONED
 from tensordict import TensorDict
 from torchrl.data import (
     LazyMemmapStorage,
@@ -836,7 +837,7 @@ def test_subsampler_state_dict(self):
 class TestRecorder:
     def _get_args(self):
         args = Namespace()
-        args.env_name = "ALE/Pong-v5"
+        args.env_name = PONG_VERSIONED
         args.env_task = ""
         args.grayscale = True
         args.env_library = "gym"
@@ -894,7 +895,7 @@ def test_recorder(self, N=8):
             },
         )
         ea.Reload()
-        img = ea.Images("tmp_ALE/Pong-v5_video")
+        img = ea.Images(f"tmp_{PONG_VERSIONED}_video")
         try:
             assert len(img) == N // args.record_interval
             break
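The recorder test imports PONG_VERSIONED and uses it both for args.env_name and for the tag of the logged video, instead of hard-coding "ALE/Pong-v5".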
