Skip to content

Commit 31a6db1

Browse files
authored
[Refactor] Better init for CatFrames buffers + removing default init values (#874)
1 parent 2fd4632 commit 31a6db1

File tree

8 files changed

+164
-111
lines changed

8 files changed

+164
-111
lines changed

examples/dreamer/dreamer_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def make_env_transforms(
9090
if cfg.grayscale:
9191
env.append_transform(GrayScale())
9292
env.append_transform(FlattenObservation(0, -3))
93-
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"]))
93+
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
9494
if stats is None:
9595
obs_stats = {
9696
"loc": torch.zeros(env.observation_spec["pixels"].shape),

test/test_helpers.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,6 @@ def dreamer_constructor_fixture():
8989
sys.path.pop()
9090

9191

92-
def _assert_keys_match(td, expeceted_keys):
93-
td_keys = list(td.keys())
94-
d = set(td_keys) - set(expeceted_keys)
95-
assert len(d) == 0, f"{d} is in tensordict but unexpected: {td.keys()}"
96-
d = set(expeceted_keys) - set(td_keys)
97-
assert len(d) == 0, f"{d} is expected but not in tensordict: {td.keys()}"
98-
assert len(td_keys) == len(expeceted_keys)
99-
100-
10192
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
10293
@pytest.mark.skipif(not _has_tv, reason="No torchvision library found")
10394
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@@ -152,16 +143,20 @@ def test_dqn_maker(
152143
else:
153144
actor(td)
154145

155-
expected_keys = ["done", "action", "action_value"]
146+
expected_keys = [
147+
"done",
148+
"action",
149+
"action_value",
150+
]
156151
if from_pixels:
157-
expected_keys += ["pixels", "pixels_orig"]
152+
expected_keys += ["pixels", "pixels_orig", "_reset"]
158153
else:
159154
expected_keys += ["observation_orig", "observation_vector"]
160155

161156
if not distributional:
162157
expected_keys += ["chosen_action_value"]
163158
try:
164-
_assert_keys_match(td, expected_keys)
159+
assert set(td.keys()) == set(expected_keys)
165160
except AssertionError:
166161
proof_environment.close()
167162
raise
@@ -217,15 +212,15 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
217212
actor(td)
218213
expected_keys = ["done", "action", "param"]
219214
if from_pixels:
220-
expected_keys += ["pixels", "hidden", "pixels_orig"]
215+
expected_keys += ["pixels", "hidden", "pixels_orig", "_reset"]
221216
else:
222217
expected_keys += ["observation_vector", "observation_orig"]
223218

224219
if cfg.gSDE:
225220
expected_keys += ["scale", "loc", "_eps_gSDE"]
226221

227222
try:
228-
_assert_keys_match(td, expected_keys)
223+
assert set(td.keys()) == set(expected_keys)
229224
except AssertionError:
230225
proof_environment.close()
231226
raise
@@ -245,7 +240,7 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
245240
value(td)
246241
expected_keys += ["state_action_value"]
247242
try:
248-
_assert_keys_match(td, expected_keys)
243+
assert set(td.keys()) == set(expected_keys)
249244
except AssertionError:
250245
proof_environment.close()
251246
raise
@@ -359,8 +354,12 @@ def test_ppo_maker(
359354
else:
360355
actor(td_clone)
361356

357+
if from_pixels:
358+
# for CatFrames
359+
expected_keys += ["_reset"]
360+
362361
try:
363-
_assert_keys_match(td_clone, expected_keys)
362+
assert set(td_clone.keys()) == set(expected_keys)
364363
except AssertionError:
365364
proof_environment.close()
366365
raise
@@ -386,6 +385,9 @@ def test_ppo_maker(
386385
"pixels_orig" if len(from_pixels) else "observation_orig",
387386
"state_value",
388387
]
388+
if from_pixels:
389+
# for CatFrames
390+
expected_keys += ["_reset"]
389391
if shared_mapping:
390392
expected_keys += ["hidden"]
391393
if len(gsde):
@@ -398,7 +400,7 @@ def test_ppo_maker(
398400
else:
399401
value(td_clone)
400402
try:
401-
_assert_keys_match(td_clone, expected_keys)
403+
assert set(td_clone.keys()) == set(expected_keys)
402404
except AssertionError:
403405
proof_environment.close()
404406
raise
@@ -495,6 +497,9 @@ def test_a2c_maker(
495497
"action",
496498
"sample_log_prob",
497499
]
500+
if from_pixels:
501+
# for CatFrames
502+
expected_keys += ["_reset"]
498503
if action_space == "continuous":
499504
expected_keys += ["loc", "scale"]
500505
else:
@@ -514,7 +519,7 @@ def test_a2c_maker(
514519
actor(td_clone)
515520

516521
try:
517-
_assert_keys_match(td_clone, expected_keys)
522+
assert set(td_clone.keys()) == set(expected_keys)
518523
except AssertionError:
519524
proof_environment.close()
520525
raise
@@ -540,6 +545,9 @@ def test_a2c_maker(
540545
"pixels_orig" if len(from_pixels) else "observation_orig",
541546
"state_value",
542547
]
548+
if from_pixels:
549+
# for CatFrames
550+
expected_keys += ["_reset"]
543551
if shared_mapping:
544552
expected_keys += ["hidden"]
545553
if len(gsde):
@@ -552,7 +560,7 @@ def test_a2c_maker(
552560
else:
553561
value(td_clone)
554562
try:
555-
_assert_keys_match(td_clone, expected_keys)
563+
assert set(td_clone.keys()) == set(expected_keys)
556564
except AssertionError:
557565
proof_environment.close()
558566
raise
@@ -631,6 +639,9 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
631639
"loc",
632640
"scale",
633641
]
642+
if from_pixels:
643+
# for CatFrames
644+
expected_keys += ["_reset"]
634645
if len(gsde):
635646
expected_keys += ["_eps_gSDE"]
636647

@@ -643,7 +654,7 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
643654
torch.testing.assert_close(td_clone.get("action"), tsf_loc)
644655

645656
try:
646-
_assert_keys_match(td_clone, expected_keys)
657+
assert set(td_clone.keys()) == set(expected_keys)
647658
except AssertionError:
648659
proof_environment.close()
649660
raise
@@ -667,7 +678,7 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
667678
expected_keys += ["_eps_gSDE"]
668679

669680
try:
670-
_assert_keys_match(td_clone, expected_keys)
681+
assert set(td_clone.keys()) == set(expected_keys)
671682
except AssertionError:
672683
proof_environment.close()
673684
raise
@@ -687,7 +698,7 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
687698
expected_keys += ["_eps_gSDE"]
688699

689700
try:
690-
_assert_keys_match(td, expected_keys)
701+
assert set(td.keys()) == set(expected_keys)
691702
except AssertionError:
692703
proof_environment.close()
693704
raise
@@ -756,12 +767,12 @@ def test_redq_make(device, from_pixels, gsde, exploration):
756767
if len(gsde):
757768
expected_keys += ["_eps_gSDE"]
758769
if from_pixels:
759-
expected_keys += ["hidden", "pixels", "pixels_orig"]
770+
expected_keys += ["hidden", "pixels", "pixels_orig", "_reset"]
760771
else:
761772
expected_keys += ["observation_vector", "observation_orig"]
762773

763774
try:
764-
_assert_keys_match(td, expected_keys)
775+
assert set(td.keys()) == set(expected_keys)
765776
except AssertionError:
766777
proof_environment.close()
767778
raise
@@ -786,11 +797,11 @@ def test_redq_make(device, from_pixels, gsde, exploration):
786797
if len(gsde):
787798
expected_keys += ["_eps_gSDE"]
788799
if from_pixels:
789-
expected_keys += ["hidden", "pixels", "pixels_orig"]
800+
expected_keys += ["hidden", "pixels", "pixels_orig", "_reset"]
790801
else:
791802
expected_keys += ["observation_vector", "observation_orig"]
792803
try:
793-
_assert_keys_match(td, expected_keys)
804+
assert set(td.keys()) == set(expected_keys)
794805
except AssertionError:
795806
proof_environment.close()
796807
raise
@@ -861,6 +872,7 @@ def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture
861872
"state",
862873
("next", "reco_pixels"),
863874
"next",
875+
"_reset",
864876
}
865877
assert set(out.keys(True)) == expected_keys
866878

test/test_rb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def test_insert_transform():
798798
pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
799799
GrayScale,
800800
pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"),
801-
CatFrames,
801+
pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"),
802802
pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"),
803803
DoubleToFloat,
804804
VecNorm,

test/test_transforms.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -878,8 +878,6 @@ def test_time_max_pool(self, T, seq_len, device):
878878
tensor_list.append(torch.rand(batch, nodes).to(device))
879879
max_vals, _ = torch.max(torch.stack(tensor_list[-T:]), dim=0)
880880

881-
print(f"max vals: {max_vals}")
882-
883881
for i in range(seq_len):
884882
env_td = TensorDict(
885883
{
@@ -946,7 +944,11 @@ def test_totensorimage(self, keys, batch, device):
946944
@pytest.mark.parametrize("device", get_available_devices())
947945
def test_compose(self, keys, batch, device, nchannels=1, N=4):
948946
torch.manual_seed(0)
949-
t1 = CatFrames(in_keys=keys, N=4)
947+
t1 = CatFrames(
948+
in_keys=keys,
949+
N=4,
950+
dim=-3,
951+
)
950952
t2 = FiniteTensorDictCheck()
951953
compose = Compose(t1, t2)
952954
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
@@ -1287,7 +1289,11 @@ def test_catframes_transform_observation_spec(self):
12871289
key1 = "first key"
12881290
key2 = "second key"
12891291
keys = [key1, key2]
1290-
cat_frames = CatFrames(N=N, in_keys=keys)
1292+
cat_frames = CatFrames(
1293+
N=N,
1294+
in_keys=keys,
1295+
dim=-3,
1296+
)
12911297
mins = [0, 0.5]
12921298
maxes = [0.5, 1]
12931299
observation_spec = CompositeSpec(
@@ -1321,31 +1327,50 @@ def test_catframes_transform_observation_spec(self):
13211327
)
13221328

13231329
@pytest.mark.parametrize("device", get_available_devices())
1330+
@pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)])
13241331
@pytest.mark.parametrize("d", range(1, 4))
1325-
def test_catframes_buffer_check_latest_frame(self, device, d):
1332+
@pytest.mark.parametrize("dim", [-3, -2, 1])
1333+
@pytest.mark.parametrize("N", [2, 4])
1334+
def test_catframes_buffer_check_latest_frame(self, device, d, batch_size, dim, N):
13261335
key1 = "first key"
13271336
key2 = "second key"
1328-
N = 4
13291337
keys = [key1, key2]
1330-
key1_tensor = torch.ones(1, d, 3, 3, device=device) * 2
1331-
key2_tensor = torch.ones(1, d, 3, 3, device=device)
1338+
extra_d = (3,) * (-dim - 1)
1339+
key1_tensor = torch.ones(*batch_size, d, *extra_d, device=device) * 2
1340+
key2_tensor = torch.ones(*batch_size, d, *extra_d, device=device)
13321341
key_tensors = [key1_tensor, key2_tensor]
1333-
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
1334-
cat_frames = CatFrames(N=N, in_keys=keys)
1342+
td = TensorDict(dict(zip(keys, key_tensors)), batch_size, device=device)
1343+
if dim > 0:
1344+
with pytest.raises(
1345+
ValueError, match="dim must be > 0 to accomodate for tensordict"
1346+
):
1347+
cat_frames = CatFrames(N=N, in_keys=keys, dim=dim)
1348+
return
1349+
cat_frames = CatFrames(N=N, in_keys=keys, dim=dim)
13351350

13361351
tdclone = cat_frames(td.clone())
13371352
latest_frame = tdclone.get(key2)
13381353

1339-
assert latest_frame.shape[1] == N * d
1340-
assert (latest_frame[0, :-d] == 0).all()
1341-
assert (latest_frame[0, -d:] == 1).all()
1354+
assert latest_frame.shape[dim] == N * d
1355+
slices = (slice(None),) * (-dim - 1)
1356+
index1 = (Ellipsis, slice(None, -d), *slices)
1357+
index2 = (Ellipsis, slice(-d, None), *slices)
1358+
assert (latest_frame[index1] == 0).all()
1359+
assert (latest_frame[index2] == 1).all()
1360+
v1 = latest_frame[index1]
13421361

13431362
tdclone = cat_frames(td.clone())
13441363
latest_frame = tdclone.get(key2)
13451364

1346-
assert latest_frame.shape[1] == N * d
1347-
assert (latest_frame[0, : -2 * d] == 0).all()
1348-
assert (latest_frame[0, -2 * d :] == 1).all()
1365+
assert latest_frame.shape[dim] == N * d
1366+
index1 = (Ellipsis, slice(None, -2 * d), *slices)
1367+
index2 = (Ellipsis, slice(-2 * d, None), *slices)
1368+
assert (latest_frame[index1] == 0).all()
1369+
assert (latest_frame[index2] == 1).all()
1370+
v2 = latest_frame[index1]
1371+
1372+
# we don't want the same tensor to be returned twice, but they're all copies of the same buffer
1373+
assert v1 is not v2
13491374

13501375
@pytest.mark.parametrize("device", get_available_devices())
13511376
def test_catframes_reset(self, device):
@@ -1357,19 +1382,20 @@ def test_catframes_reset(self, device):
13571382
key2_tensor = torch.randn(1, 1, 3, 3, device=device)
13581383
key_tensors = [key1_tensor, key2_tensor]
13591384
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
1360-
cat_frames = CatFrames(N=N, in_keys=keys)
1385+
cat_frames = CatFrames(N=N, in_keys=keys, dim=-3)
13611386

1362-
cat_frames(td)
1387+
cat_frames(td.clone())
13631388
buffer = getattr(cat_frames, f"_cat_buffers_{key1}")
13641389

1365-
passed_back_td = cat_frames.reset(td)
1390+
tdc = td.clone()
1391+
passed_back_td = cat_frames.reset(tdc)
1392+
assert "_reset" in tdc.keys()
13661393

1367-
assert td is passed_back_td
1368-
assert (0 == buffer).all()
1394+
assert tdc is passed_back_td
1395+
assert (buffer == 0).all()
13691396

1370-
_ = cat_frames._call(td)
1371-
assert (0 == buffer[..., :-1, :, :]).all()
1372-
assert (0 != buffer[..., -1:, :, :]).all()
1397+
_ = cat_frames._call(tdc)
1398+
assert (buffer != 0).all()
13731399

13741400
@pytest.mark.parametrize("device", get_available_devices())
13751401
def test_finitetensordictcheck(self, device):
@@ -1691,7 +1717,7 @@ def test_append(self):
16911717
(key,) = itertools.islice(obs_spec.keys(), 1)
16921718

16931719
env = TransformedEnv(env)
1694-
env.append_transform(CatFrames(N=4, cat_dim=-1, in_keys=[key]))
1720+
env.append_transform(CatFrames(N=4, dim=-1, in_keys=[key]))
16951721
assert isinstance(env.transform, Compose)
16961722
assert len(env.transform) == 1
16971723
obs_spec = env.observation_spec
@@ -1715,7 +1741,7 @@ def test_insert(self):
17151741
assert env._observation_spec is not None
17161742
assert env._reward_spec is not None
17171743

1718-
env.insert_transform(0, CatFrames(N=4, cat_dim=-1, in_keys=[key]))
1744+
env.insert_transform(0, CatFrames(N=4, dim=-1, in_keys=[key]))
17191745

17201746
# transformed envs do not have spec after insert -- they need to be computed
17211747
assert env._input_spec is None
@@ -1762,7 +1788,7 @@ def test_insert(self):
17621788
assert env._observation_spec is None
17631789
assert env._reward_spec is None
17641790

1765-
env.insert_transform(-5, CatFrames(N=4, cat_dim=-1, in_keys=[key]))
1791+
env.insert_transform(-5, CatFrames(N=4, dim=-1, in_keys=[key]))
17661792
assert isinstance(env.transform, Compose)
17671793
assert len(env.transform) == 6
17681794

@@ -2441,7 +2467,7 @@ def test_select(self):
24412467
pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
24422468
GrayScale,
24432469
ObservationNorm,
2444-
CatFrames,
2470+
pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"),
24452471
pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"),
24462472
FiniteTensorDictCheck,
24472473
DoubleToFloat,

0 commit comments

Comments
 (0)