
Commit 2de55cb

[Refactor] Defaults split_trajs to False (#947)
1 parent eb9a37d commit 2de55cb

File tree: 10 files changed, +343 -206 lines
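
What the new default means in practice: collectors no longer split and pad trajectories unless asked, so batches keep a flat layout and the padding mask only appears when split_trajs=True is passed. A minimal sketch, assuming a torchrl install with a gym backend; the environment name, frame counts, and the use of a None policy are illustrative assumptions, not code from this commit:

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

# split_trajs=False is now the default: no per-trajectory padding and
# no ("collector", "mask") entry in the collected tensordicts.
collector = SyncDataCollector(
    lambda: GymEnv("CartPole-v1"),
    policy=None,  # assumption: a None policy falls back to a random policy
    frames_per_batch=50,
    total_frames=200,
    split_trajs=False,
)
for batch in collector:
    print(batch.batch_size)  # time is the trailing batch dimension
    break
collector.shutdown()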

test/test_collector.py

Lines changed: 19 additions & 14 deletions
@@ -635,9 +635,9 @@ def env_fn(seed):
     with pytest.raises(AssertionError):
         assert_allclose_td(b1, b2)

-    if num_env == 1:
-        # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension
-        rollout1a = rollout1a.unsqueeze(0)
+    # if num_env == 1:
+    #     # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension
+    #     rollout1a = rollout1a.unsqueeze(0)
     assert (
         rollout1a.batch_size == b1.batch_size
     ), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}"
@@ -690,12 +690,12 @@ def make_frames_per_batch(frames_per_batch):
     data1 = []
     for d in collector1:
         data1.append(d)
-        count += d.shape[1]
+        count += d.shape[-1]
         if count > max_frames_per_traj:
             break

-    data1 = torch.cat(data1, 1)
-    data1 = data1[:, :max_frames_per_traj]
+    data1 = torch.cat(data1, d.ndim - 1)
+    data1 = data1[..., :max_frames_per_traj]

     collector1.shutdown()
     del collector1
@@ -715,12 +715,12 @@ def make_frames_per_batch(frames_per_batch):
     data10 = []
     for d in collector10:
         data10.append(d)
-        count += d.shape[1]
+        count += d.shape[-1]
         if count > max_frames_per_traj:
             break

-    data10 = torch.cat(data10, 1)
-    data10 = data10[:, :max_frames_per_traj]
+    data10 = torch.cat(data10, data1.ndim - 1)
+    data10 = data10[..., :max_frames_per_traj]

     collector10.shutdown()
     del collector10
@@ -740,14 +740,14 @@ def make_frames_per_batch(frames_per_batch):
     data20 = []
     for d in collector20:
         data20.append(d)
-        count += d.shape[1]
+        count += d.shape[-1]
         if count > max_frames_per_traj:
             break

     collector20.shutdown()
     del collector20
-    data20 = torch.cat(data20, 1)
-    data20 = data20[:, :max_frames_per_traj]
+    data20 = torch.cat(data20, data1.ndim - 1)
+    data20 = data20[..., :max_frames_per_traj]

     assert_allclose_td(data1, data20)
     assert_allclose_td(data10, data20)
@@ -932,7 +932,10 @@ def make_env():
 )
 @pytest.mark.parametrize("init_random_frames", [0, 50])
 @pytest.mark.parametrize("explicit_spec", [True, False])
-def test_collector_output_keys(collector_class, init_random_frames, explicit_spec):
+@pytest.mark.parametrize("split_trajs", [True, False])
+def test_collector_output_keys(
+    collector_class, init_random_frames, explicit_spec, split_trajs
+):
     from torchrl.envs.libs.gym import GymEnv

     out_features = 1
@@ -979,6 +982,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spec):
         "total_frames": total_frames,
         "frames_per_batch": frames_per_batch,
         "init_random_frames": init_random_frames,
+        "split_trajs": split_trajs,
     }

     if collector_class is not SyncDataCollector:
@@ -995,7 +999,6 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spec):
         "collector",
         "hidden1",
         "hidden2",
-        ("collector", "mask"),
         ("next", "hidden1"),
         ("next", "hidden2"),
         ("next", "observation"),
@@ -1005,6 +1008,8 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spec):
         "observation",
         ("collector", "traj_ids"),
     }
+    if split_trajs:
+        keys.add(("collector", "mask"))
     b = next(iter(collector))

     assert set(b.keys(True)) == keys
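
The indexing changes above treat time as the trailing batch dimension instead of a hard-coded dim 1. A toy, self-contained illustration of that convention with plain tensors (not actual collector output):

import torch

# two collected chunks with layout [n_envs, time]
chunks = [torch.zeros(1, 4), torch.ones(1, 4)]
max_frames_per_traj = 6

# concatenate along the last (time) dimension, then truncate it,
# mirroring torch.cat(data, d.ndim - 1) and data[..., :max_frames_per_traj]
data = torch.cat(chunks, dim=-1)
data = data[..., :max_frames_per_traj]
print(data.shape)  # torch.Size([1, 6])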

test/test_cost.py

Lines changed: 6 additions & 6 deletions
@@ -341,7 +341,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
             actor, gamma=gamma, loss_function="l2", delay_value=delay_value
         )

-        ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
+        ms = MultiStep(gamma=gamma, n_steps=n).to(device)
         ms_td = ms(td.clone())

         with _check_td_steady(ms_td):
@@ -351,7 +351,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
         with torch.no_grad():
             loss = loss_fn(td)
         if n == 0:
-            assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
+            assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
         _loss = sum([item for _, item in loss.items()])
         _loss_ms = sum([item for _, item in loss_ms.items()])
         assert (
@@ -635,7 +635,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
             delay_value=delay_value,
         )

-        ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
+        ms = MultiStep(gamma=gamma, n_steps=n).to(device)
         ms_td = ms(td.clone())
         with _check_td_steady(ms_td):
             loss_ms = loss_fn(ms_td)
@@ -853,7 +853,7 @@ def test_td3_batcher(
             delay_actor=delay_actor,
         )

-        ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
+        ms = MultiStep(gamma=gamma, n_steps=n).to(device)

         td_clone = td.clone()
         ms_td = ms(td_clone)
@@ -1226,7 +1226,7 @@ def test_sac_batcher(
             **kwargs,
         )

-        ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
+        ms = MultiStep(gamma=gamma, n_steps=n).to(device)

         td_clone = td.clone()
         ms_td = ms(td_clone)
@@ -1717,7 +1717,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
             delay_qvalue=delay_qvalue,
        )

-        ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
+        ms = MultiStep(gamma=gamma, n_steps=n).to(device)

         td_clone = td.clone()
         ms_td = ms(td_clone)
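
Every loss test above switches MultiStep to its renamed keyword argument. A minimal sketch of the new call, assuming MultiStep is importable from torchrl.data.postprocs (gamma and n_steps values are illustrative):

from torchrl.data.postprocs import MultiStep

# n_steps replaces the old n_steps_max keyword
ms = MultiStep(gamma=0.9, n_steps=3)
# the module is then applied to a rollout tensordict, e.g. ms_td = ms(td.clone())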

test/test_postprocs.py

Lines changed: 153 additions & 15 deletions
@@ -63,20 +63,24 @@ def test_multistep(n, key, device, T=11):
     )

     # assert that done at last step is similar to unterminated traj
-    assert (ms_tensordict.get("gamma")[4] == ms_tensordict.get("gamma")[0]).all()
-    assert (
-        ms_tensordict.get(("next", key))[4] == ms_tensordict.get(("next", key))[0]
-    ).all()
-    assert (
-        ms_tensordict.get("steps_to_next_obs")[4]
-        == ms_tensordict.get("steps_to_next_obs")[0]
-    ).all()
+    torch.testing.assert_close(
+        ms_tensordict.get("gamma")[4], ms_tensordict.get("gamma")[0]
+    )
+    torch.testing.assert_close(
+        ms_tensordict.get(("next", key))[4], ms_tensordict.get(("next", key))[0]
+    )
+    torch.testing.assert_close(
+        ms_tensordict.get("steps_to_next_obs")[4],
+        ms_tensordict.get("steps_to_next_obs")[0],
+    )

     # check that next obs is properly replaced, or that it is terminated
-    next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps_max) :]
-    true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps_max)]
+    next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps) :]
+    true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps)]
     terminated = ~ms_tensordict.get("nonterminal")
-    assert ((next_obs == true_next_obs) | terminated[:, (1 + ms.n_steps_max) :]).all()
+    assert (
+        (next_obs == true_next_obs).all(-1) | terminated[:, (1 + ms.n_steps) :]
+    ).all()

     # test gamma computation
     torch.testing.assert_close(
@@ -90,10 +94,144 @@ def test_multistep(n, key, device, T=11):
             != ms_tensordict.get(("next", "original_reward"))
         ).any()
     else:
-        assert (
-            ms_tensordict.get(("next", "reward"))
-            == ms_tensordict.get(("next", "original_reward"))
-        ).all()
+        torch.testing.assert_close(
+            ms_tensordict.get(("next", "reward")),
+            ms_tensordict.get(("next", "original_reward")),
+        )
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize(
+    "batch_size",
+    [
+        [
+            4,
+        ],
+        [],
+        [
+            1,
+        ],
+        [2, 3],
+    ],
+)
+@pytest.mark.parametrize(
+    "T",
+    [
+        10,
+        1,
+        2,
+    ],
+)
+@pytest.mark.parametrize(
+    "obs_dim",
+    [
+        [
+            1,
+        ],
+        [],
+    ],
+)
+@pytest.mark.parametrize("unsq_reward", [True, False])
+@pytest.mark.parametrize("last_done", [True, False])
+@pytest.mark.parametrize("n_steps", [3, 1, 0])
+def test_mutistep_cattrajs(
+    batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps
+):
+    # tests multi-step in the presence of consecutive trajectories.
+    obs = torch.randn(*batch_size, T + 1, *obs_dim)
+    reward = torch.rand(*batch_size, T)
+    action = torch.rand(*batch_size, T)
+    done = torch.zeros(*batch_size, T + 1, dtype=torch.bool)
+    done[..., T // 2] = 1
+    if last_done:
+        done[..., -1] = 1
+    if unsq_reward:
+        reward = reward.unsqueeze(-1)
+        done = done.unsqueeze(-1)
+
+    td = TensorDict(
+        {
+            "obs": obs[..., :-1] if not obs_dim else obs[..., :-1, :],
+            "action": action,
+            "done": done[..., :-1] if not unsq_reward else done[..., :-1, :],
+            "next": {
+                "obs": obs[..., 1:] if not obs_dim else obs[..., 1:, :],
+                "done": done[..., 1:] if not unsq_reward else done[..., 1:, :],
+                "reward": reward,
+            },
+        },
+        batch_size=[*batch_size, T],
+        device=device,
+    )
+    ms = MultiStep(0.98, n_steps)
+    tdm = ms(td)
+    if n_steps == 0:
+        # n_steps = 0 has no effect
+        for k in td["next"].keys():
+            assert (tdm["next", k] == td["next", k]).all()
+    else:
+        next_obs = []
+        obs = td["next", "obs"]
+        done = td["next", "done"]
+        if obs_dim:
+            obs = obs.squeeze(-1)
+        if unsq_reward:
+            done = done.squeeze(-1)
+        for t in range(T):
+            idx = t + n_steps
+            while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1:
+                idx = idx - 1
+            next_obs.append(obs[..., idx])
+        true_next_obs = tdm.get(("next", "obs"))
+        if obs_dim:
+            true_next_obs = true_next_obs.squeeze(-1)
+        next_obs = torch.stack(next_obs, -1)
+        assert (next_obs == true_next_obs).all()
+
+
+@pytest.mark.parametrize("unsq_reward", [True, False])
+def test_unusual_done(unsq_reward):
+    batch_size = [10, 3]
+    T = 10
+    obs_dim = [
+        1,
+    ]
+    last_done = True
+    device = torch.device("cpu")
+    n_steps = 3
+
+    obs = torch.randn(*batch_size, T + 1, 5, *obs_dim)
+    reward = torch.rand(*batch_size, T, 5)
+    action = torch.rand(*batch_size, T, 5)
+    done = torch.zeros(*batch_size, T + 1, 5, dtype=torch.bool)
+    done[..., T // 2, :] = 1
+    if last_done:
+        done[..., -1, :] = 1
+    if unsq_reward:
+        reward = reward.unsqueeze(-1)
+        done = done.unsqueeze(-1)
+
+    td = TensorDict(
+        {
+            "obs": obs[..., :-1, :] if not obs_dim else obs[..., :-1, :, :],
+            "action": action,
+            "done": done[..., :-1, :] if not unsq_reward else done[..., :-1, :, :],
+            "next": {
+                "obs": obs[..., 1:, :] if not obs_dim else obs[..., 1:, :, :],
+                "done": done[..., 1:, :] if not unsq_reward else done[..., 1:, :, :],
+                "reward": reward,
+            },
+        },
+        batch_size=[*batch_size, T],
+        device=device,
+    )
+    ms = MultiStep(0.98, n_steps)
+    if unsq_reward:
+        with pytest.raises(RuntimeError, match="tensordict shape must be compatible"):
+            _ = ms(td)
+    else:
+        # we just check that it runs
+        _ = ms(td)


 class TestSplits:
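
The assertions above also move from exact (a == b).all() checks to torch.testing.assert_close, which applies dtype-appropriate tolerances and prints a detailed mismatch report on failure. A toy comparison of the two styles:

import torch

a = torch.ones(3)
b = torch.ones(3)

# old style: boolean reduction with an uninformative failure message
assert (a == b).all()

# new style: tolerance-aware comparison that reports where and by how much tensors differ
torch.testing.assert_close(a, b)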
