Commit 450a380

[BugFix] Avoid collision of "step_count" key from transform and collector (#868)
1 parent b07fd35 commit 450a380

15 files changed: +194, -139 lines
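The commit moves the bookkeeping entries written by data collectors ("step_count", "traj_ids", "mask") into a nested "collector" sub-tensordict, so they can no longer collide with the root-level "step_count" entry written by the StepCounter transform. A minimal sketch of the resulting key layout, with made-up values (depending on the torchrl version, TensorDict may need to be imported from torchrl.data instead of tensordict):

import torch
from tensordict import TensorDict

batch = 4
td = TensorDict(
    {
        # hypothetical values, shapes chosen only for illustration
        "step_count": torch.arange(batch),  # e.g. written by the StepCounter transform
        ("collector", "step_count"): torch.arange(batch),  # written by the collector
        ("collector", "mask"): torch.ones(batch, dtype=torch.bool),
    },
    batch_size=[batch],
)

# tuple keys address entries of the nested "collector" sub-tensordict;
# keys(True) lists nested keys as well, so membership tests keep working
assert ("collector", "mask") in td.keys(True)
assert td.get(("collector", "step_count")).shape == torch.Size([batch])
assert td["step_count"] is not td["collector", "step_count"]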

examples/td3/td3.py

Lines changed: 3 additions & 3 deletions
@@ -273,10 +273,10 @@ def main(cfg: "DictConfig"): # noqa: F821
         pbar.update(tensordict.numel())

         # extend the replay buffer with the new data
-        if "mask" in tensordict.keys():
+        if ("collector", "mask") in tensordict.keys(True):
             # if multi-step, a mask is present to help filter padded values
-            current_frames = tensordict["mask"].sum()
-            tensordict = tensordict[tensordict.get("mask").squeeze(-1)]
+            current_frames = tensordict["collector", "mask"].sum()
+            tensordict = tensordict[tensordict.get(("collector", "mask")).squeeze(-1)]
         else:
             tensordict = tensordict.view(-1)
             current_frames = tensordict.numel()
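In the example above, multi-step collection pads trajectories to a common length, and the nested ("collector", "mask") entry marks the frames that are real. A hedged sketch of the same filtering on a stand-alone tensordict (the data, shapes and the "reward" key are assumptions, not taken from the example; the example additionally squeezes a trailing singleton dimension because its mask has shape [*batch, 1]):

import torch
from tensordict import TensorDict

# assumed batch: 2 trajectories of 5 steps, the second padded after step 3
data = TensorDict(
    {
        "reward": torch.randn(2, 5, 1),
        ("collector", "mask"): torch.tensor(
            [[True, True, True, True, True], [True, True, True, False, False]]
        ),
    },
    batch_size=[2, 5],
)

if ("collector", "mask") in data.keys(True):
    # boolean indexing flattens the batch and drops the padded frames
    current_frames = data["collector", "mask"].sum()
    data = data[data.get(("collector", "mask"))]
else:
    data = data.view(-1)
    current_frames = data.numel()

assert int(current_frames) == 8
assert data.batch_size == torch.Size([8])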

test/test_collector.py

Lines changed: 14 additions & 10 deletions
@@ -316,7 +316,7 @@ def make_env():
     )
     for _data in collector:
         continue
-    steps = _data["step_count"][..., 1:]
+    steps = _data["collector", "step_count"][..., 1:]
     done = _data["done"][..., :-1, :].squeeze(-1)
     # we don't want just one done
     assert done.sum() > 3
@@ -375,7 +375,7 @@ def make_env(seed):
         break

     assert (d["done"].sum(-2) >= 1).all()
-    assert torch.unique(d["traj_ids"], dim=-1).shape[-1] == 1
+    assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1

     del collector

@@ -426,12 +426,15 @@ def make_env(seed):
         break

     assert d.ndimension() == 2
-    assert d["mask"].shape == d.shape
-    assert d["step_count"].shape == d.shape
-    assert d["traj_ids"].shape == d.shape
+    assert d["collector", "mask"].shape == d.shape
+    assert d["collector", "step_count"].shape == d.shape
+    assert d["collector", "traj_ids"].shape == d.shape
     for traj in d.unbind(0):
-        assert traj["traj_ids"].unique().numel() == 1
-        assert (traj["step_count"][1:] - traj["step_count"][:-1] == 1).all()
+        assert traj["collector", "traj_ids"].unique().numel() == 1
+        assert (
+            traj["collector", "step_count"][1:] - traj["collector", "step_count"][:-1]
+            == 1
+        ).all()

     del collector

@@ -986,17 +989,18 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
     keys = {
         "action",
         "done",
+        "collector",
         "hidden1",
         "hidden2",
-        "mask",
+        ("collector", "mask"),
         ("next", "hidden1"),
         ("next", "hidden2"),
         ("next", "observation"),
         "next",
         "observation",
         "reward",
-        "step_count",
-        "traj_ids",
+        ("collector", "step_count"),
+        ("collector", "traj_ids"),
     }
     b = next(iter(collector))
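These assertions read the collector's bookkeeping from the nested keys and check that each split trajectory is internally consistent. A stand-alone sketch of the same per-trajectory checks on hand-built data (values and shapes are illustrative; no collector is constructed here):

import torch
from tensordict import TensorDict

# two fake trajectories, each 4 steps long
d = TensorDict(
    {
        ("collector", "traj_ids"): torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]]),
        ("collector", "step_count"): torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]),
        ("collector", "mask"): torch.ones(2, 4, dtype=torch.bool),
    },
    batch_size=[2, 4],
)

assert d["collector", "mask"].shape == d.shape
for traj in d.unbind(0):
    # each row holds a single trajectory with consecutive step counts
    assert traj["collector", "traj_ids"].unique().numel() == 1
    steps = traj["collector", "step_count"]
    assert (steps[1:] - steps[:-1] == 1).all()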

test/test_cost.py

Lines changed: 13 additions & 7 deletions
@@ -7,6 +7,8 @@
 import re
 from copy import deepcopy

+from packaging import version as pack_version
+
 _has_functorch = True
 try:
     import functorch as ft  # noqa
@@ -273,7 +275,7 @@ def _create_seq_mock_data_dqn(
                     "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
                 },
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
@@ -507,7 +509,7 @@ def _create_seq_mock_data_ddpg(
                     "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
                 },
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
@@ -735,7 +737,7 @@ def _create_seq_mock_data_td3(
                 "observation": obs * mask.to(obs.dtype),
                 "next": {"observation": next_obs * mask.to(obs.dtype)},
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward * mask.to(obs.dtype),
                 "action": action * mask.to(obs.dtype),
             },
@@ -1012,7 +1014,7 @@ def _create_seq_mock_data_sac(
                     "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
                 },
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
@@ -1441,7 +1443,7 @@ def _create_seq_mock_data_redq(
                     "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
                 },
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
@@ -1880,7 +1882,7 @@ def _create_seq_mock_data_ppo(
                     "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
                 },
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "sample_log_prob": (torch.randn_like(action[..., 1]) / 10).masked_fill_(
@@ -2035,6 +2037,8 @@ def test_ppo_shared(self, loss_class, device, advantage):
     @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda"))
     @pytest.mark.parametrize("device", get_available_devices())
     def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
+        if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
+            raise pytest.skip("make_functional_with_buffers needs to be changed")
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_ppo(device=device)

@@ -2153,7 +2157,7 @@ def _create_seq_mock_data_a2c(
                     "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
                 },
                 "done": done,
-                "mask": mask,
+                "collector": {"mask": mask},
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_(
@@ -2245,6 +2249,8 @@ def test_a2c(self, device, gradient_mode, advantage):
     @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda"))
     @pytest.mark.parametrize("device", get_available_devices())
     def test_a2c_diff(self, device, gradient_mode, advantage):
+        if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
+            raise pytest.skip("make_functional_with_buffers needs to be changed")
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_a2c(device=device)
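The mock-data helpers now nest the mask through a plain dict, "collector": {"mask": mask}, which TensorDict converts into a sub-tensordict; the entry it creates is the same one the tuple key ("collector", "mask") addresses elsewhere in this commit. A small sketch under that assumption:

import torch
from tensordict import TensorDict

mask = torch.ones(3, 5, dtype=torch.bool)

# nesting through a plain dict ...
td_a = TensorDict({"collector": {"mask": mask}}, batch_size=[3, 5])
# ... and through a tuple key produce the same nested entry
td_b = TensorDict({("collector", "mask"): mask}, batch_size=[3, 5])

assert (td_a.get(("collector", "mask")) == td_b.get(("collector", "mask"))).all()
assert td_a["collector"].batch_size == torch.Size([3, 5])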

test/test_postprocs.py

Lines changed: 23 additions & 16 deletions
@@ -41,7 +41,7 @@ def test_multistep(n, key, device, T=11):
             "done": done,
             "reward": torch.randn(1, T, 1, device=device).expand(b, T, 1)
             * mask.to(torch.float),
-            "mask": mask,
+            "collector": {"mask": mask},
         },
         batch_size=(b, T),
     ).to(device)
@@ -98,28 +98,28 @@ def create_fake_trajs(
         traj_len=200,
     ):
         traj_ids = torch.arange(num_workers)
-        steps_count = torch.zeros(num_workers)
+        step_count = torch.zeros(num_workers)
         workers = torch.arange(num_workers)

         out = []
         for _ in range(traj_len):
-            done = steps_count == traj_ids  # traj_id 0 has 0 steps, 1 has 1 step etc.
+            done = step_count == traj_ids  # traj_id 0 has 0 steps, 1 has 1 step etc.

             td = TensorDict(
                 source={
-                    "traj_ids": traj_ids,
+                    ("collector", "traj_ids"): traj_ids,
                     "a": traj_ids.clone().unsqueeze(-1),
-                    "steps_count": steps_count,
+                    ("collector", "step_count"): step_count,
                     "workers": workers,
                     "done": done.unsqueeze(-1),
                 },
                 batch_size=[num_workers],
             )
             out.append(td.clone())
-            steps_count += 1
+            step_count += 1

             traj_ids[done] = traj_ids.max() + torch.arange(1, done.sum() + 1)
-            steps_count[done] = 0
+            step_count[done] = 0

         out = torch.stack(out, 1).contiguous()
         return out
@@ -132,22 +132,29 @@ def test_splits(self, num_workers, traj_len):
         assert trajs.shape[0] == num_workers
         assert trajs.shape[1] == traj_len
         split_trajs = split_trajectories(trajs)
-        assert split_trajs.shape[0] == split_trajs.get("traj_ids").max() + 1
-        assert split_trajs.shape[1] == split_trajs.get("steps_count").max() + 1
+        assert (
+            split_trajs.shape[0] == split_trajs.get(("collector", "traj_ids")).max() + 1
+        )
+        assert (
+            split_trajs.shape[1]
+            == split_trajs.get(("collector", "step_count")).max() + 1
+        )

-        assert split_trajs.get("mask").sum() == num_workers * traj_len
+        assert split_trajs.get(("collector", "mask")).sum() == num_workers * traj_len

         assert split_trajs.get("done").sum(1).max() == 1
-        out_mask = split_trajs[split_trajs.get("mask")]
+        out_mask = split_trajs[split_trajs.get(("collector", "mask"))]
         for i in range(split_trajs.shape[0]):
-            traj_id_split = split_trajs[i].get("traj_ids")[split_trajs[i].get("mask")]
+            traj_id_split = split_trajs[i].get(("collector", "traj_ids"))[
+                split_trajs[i].get(("collector", "mask"))
+            ]
             assert 1 == len(traj_id_split.unique())

         for w in range(num_workers):
             assert (out_mask.get("workers") == w).sum() == traj_len
             # Assert that either the chain is not done XOR if it is it must have the desired length (equal to traj id by design)
-            for i in range(split_trajs.get("traj_ids").max()):
-                idx_traj_id = out_mask.get("traj_ids") == i
+            for i in range(split_trajs.get(("collector", "traj_ids")).max()):
+                idx_traj_id = out_mask.get(("collector", "traj_ids")) == i
                 # (!=) == (xor)
                 c1 = (idx_traj_id.sum() - 1 == i) and (
                     out_mask.get("done")[idx_traj_id].sum() == 1
@@ -162,8 +169,8 @@ def test_splits(self, num_workers, traj_len):
             )

         assert (
-            split_trajs.get("traj_ids").unique().numel()
-            == split_trajs.get("traj_ids").max() + 1
+            split_trajs.get(("collector", "traj_ids")).unique().numel()
+            == split_trajs.get(("collector", "traj_ids")).max() + 1
         )

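split_trajectories now reads trajectory ids from ("collector", "traj_ids"), pads each trajectory to the longest one, and writes the validity mask to ("collector", "mask"). A hedged sketch that mirrors the fake data these tests build (the import path and the minimal set of required input keys are assumptions based on the test, and padding behaviour may differ across torchrl versions):

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# 1 worker, 5 consecutive steps covering two trajectories (ids 0 and 1)
td = TensorDict(
    {
        ("collector", "traj_ids"): torch.tensor([[0, 0, 1, 1, 1]]),
        ("collector", "step_count"): torch.tensor([[0, 1, 0, 1, 2]]),
        "done": torch.tensor([[0, 1, 0, 0, 1]], dtype=torch.bool).unsqueeze(-1),
        "observation": torch.randn(1, 5, 3),
    },
    batch_size=[1, 5],
)

split = split_trajectories(td)
# one row per trajectory, padded to the longest trajectory (3 steps here)
assert split.shape[0] == split.get(("collector", "traj_ids")).max() + 1
assert split.get(("collector", "mask")).sum() == 5  # number of real frames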

test/test_trainer.py

Lines changed: 18 additions & 7 deletions
@@ -762,14 +762,14 @@ def test_masking():
     batch = 10
     td = TensorDict(
         {
-            "mask": torch.zeros(batch, dtype=torch.bool).bernoulli_(),
+            ("collector", "mask"): torch.zeros(batch, dtype=torch.bool).bernoulli_(),
             "tensor": torch.randn(batch, 51),
         },
         [batch],
     )
     td_out = trainer._process_batch_hook(td)
-    assert td_out.shape[0] == td.get("mask").sum()
-    assert (td["tensor"][td["mask"]] == td_out["tensor"]).all()
+    assert td_out.shape[0] == td.get(("collector", "mask")).sum()
+    assert (td["tensor"][td[("collector", "mask")]] == td_out["tensor"]).all()


 class TestSubSampler:
@@ -989,10 +989,13 @@ def test_countframes(self):
         count_frames = CountFramesLog(frame_skip=frame_skip)
         count_frames.register(trainer)
         td = TensorDict(
-            {"mask": torch.zeros(batch, dtype=torch.bool).bernoulli_()}, [batch]
+            {("collector", "mask"): torch.zeros(batch, dtype=torch.bool).bernoulli_()},
+            [batch],
         )
         trainer._pre_steps_log_hook(td)
-        assert count_frames.frame_count == td.get("mask").sum() * frame_skip
+        assert (
+            count_frames.frame_count == td.get(("collector", "mask")).sum() * frame_skip
+        )

     @pytest.mark.parametrize(
         "backend",
@@ -1037,13 +1040,21 @@ def _make_countframe_and_trainer(tmpdirname):
         with tempfile.TemporaryDirectory() as tmpdirname, tempfile.TemporaryDirectory() as tmpdirname2:
             trainer, count_frames, file = _make_countframe_and_trainer(tmpdirname)
             td = TensorDict(
-                {"mask": torch.zeros(batch, dtype=torch.bool).bernoulli_()}, [batch]
+                {
+                    ("collector", "mask"): torch.zeros(
+                        batch, dtype=torch.bool
+                    ).bernoulli_()
+                },
+                [batch],
             )
             trainer._pre_steps_log_hook(td)
             trainer.save_trainer(True)
             trainer2, count_frames2, _ = _make_countframe_and_trainer(tmpdirname2)
             trainer2.load_from_file(file)
-            assert count_frames2.frame_count == td.get("mask").sum() * frame_skip
+            assert (
+                count_frames2.frame_count
+                == td.get(("collector", "mask")).sum() * frame_skip
+            )
             assert state_dict_has_been_called[0]
             assert load_state_dict_has_been_called[0]
             CountFramesLog.state_dict = CountFramesLog_state_dict
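test_masking exercises the trainer's batch-processing hook, which drops masked-out rows before they reach the loss. A stand-alone sketch of the masking the hook is expected to perform (the trainer wiring and hook registration are not reproduced here):

import torch
from tensordict import TensorDict

batch = 10
td = TensorDict(
    {
        ("collector", "mask"): torch.zeros(batch, dtype=torch.bool).bernoulli_(),
        "tensor": torch.randn(batch, 51),
    },
    [batch],
)

# what the process-batch hook is expected to do: keep only the valid rows
td_out = td[td.get(("collector", "mask"))]

assert td_out.shape[0] == td.get(("collector", "mask")).sum()
assert (td["tensor"][td[("collector", "mask")]] == td_out["tensor"]).all()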

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -1832,7 +1832,7 @@ def test_step_counter(self, max_steps, device, batch, reset_workers):
         while max_steps is None or i < max_steps:
             step_counter._step(td)
             i += 1
-            assert torch.all(td.get("step_count") == i)
+            assert torch.all(td.get("step_count") == i), (td.get("step_count"), i)
             if max_steps is None:
                 break
         if max_steps is not None:
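With the collector keys moved into the nested "collector" entry, the StepCounter transform keeps sole ownership of the root "step_count" key. A hedged usage sketch (GymEnv and gym's CartPole-v1 are assumptions; any torchrl environment would behave the same way):

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import StepCounter

# wrap an environment so every step increments a per-episode counter
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter(max_steps=100))
rollout = env.rollout(5)

# the transform owns the root "step_count" entry of the rollout,
# while collectors now write their own counter under ("collector", "step_count"),
# so the two no longer overwrite each other
assert "step_count" in rollout.keys()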
