
Commit 31bd542

[Test] Fix error catches (#2982)
1 parent 8dcf987 commit 31bd542

File tree

3 files changed: +17 additions, -12 deletions

  test/test_actors.py
  test/test_exploration.py
  torchrl/modules/tensordict_module/actors.py

test/test_actors.py

Lines changed: 5 additions & 5 deletions
@@ -79,7 +79,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
     if log_prob_key:
         assert td_out[log_prob_key].shape == (5,)
     else:
-        assert td_out["sample_log_prob"].shape == (5,)
+        assert td_out["data", "action_log_prob"].shape == (5,)
 
     policy = ProbabilisticActor(
         module=policy_module,

@@ -99,7 +99,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
     if log_prob_key:
         assert td_out[log_prob_key].shape == (5,)
     else:
-        assert td_out["sample_log_prob"].shape == (5,)
+        assert td_out["data", "action_log_prob"].shape == (5,)
 
 
 @pytest.mark.parametrize(

@@ -144,7 +144,7 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
     if log_prob_key:
         assert td_out[log_prob_key].shape == (5,)
     else:
-        assert td_out["sample_log_prob"].shape == (5,)
+        assert td_out["data", "action_log_prob"].shape == (5,)
 
     policy = ProbabilisticActor(
         module=policy_module,

@@ -164,7 +164,7 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
     if log_prob_key:
         assert td_out[log_prob_key].shape == (5,)
     else:
-        assert td_out["sample_log_prob"].shape == (5,)
+        assert td_out["data", "action_log_prob"].shape == (5,)
 
 
 class TestQValue:

@@ -867,7 +867,7 @@ def test_lmhead_actorvalueoperator(device):
 
     # check actor
     assert aco.module[1].in_keys == ["x"]
-    assert aco.module[1].out_keys == ["logits", "action", "sample_log_prob"]
+    assert aco.module[1].out_keys == ["logits", "action", "action_log_prob"]
     assert aco.module[1][0].module is base_model.lm_head
 
     # check critic
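
The updated assertions read the log-probability from a nested ("data", "action_log_prob") entry rather than the flat "sample_log_prob" key. A minimal sketch of how tensordict addresses such nested entries with key tuples (the construction below is illustrative, not taken from the test file; the shapes are assumptions):

    import torch
    from tensordict import TensorDict

    # Build a TensorDict holding a nested "data" sub-tensordict (illustrative shapes).
    td_out = TensorDict(
        {"data": TensorDict({"action_log_prob": torch.zeros(5)}, batch_size=[5])},
        batch_size=[5],
    )

    # Nested entries are read with a tuple of keys, which is the form the
    # updated assertions rely on.
    assert td_out["data", "action_log_prob"].shape == (5,)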

test/test_exploration.py

Lines changed: 11 additions & 6 deletions
@@ -124,7 +124,9 @@ def test_egreedy_masked(self, module, eps_init, spec_class):
             {"observation": torch.zeros(*batch_size, action_size)},
             batch_size=batch_size,
         )
-        with pytest.raises(RuntimeError, match="Failed while executing module"):
+        with pytest.raises(
+            KeyError, match="Action mask key action_mask not found in TensorDict"
+        ):
             explorative_policy(td)
 
         torch.manual_seed(0)

@@ -163,9 +165,7 @@ def test_egreedy_masked(self, module, eps_init, spec_class):
         assert not (action[~action_mask] == 0).all()
         assert (masked_action[~action_mask] == 0).all()
 
-    def test_no_spec_error(
-        self,
-    ):
+    def test_no_spec_error(self):
         torch.manual_seed(0)
         action_size = 4
         batch_size = (3, 4, 2)

@@ -183,7 +183,10 @@ def test_no_spec_error(
             batch_size=batch_size,
         )
 
-        with pytest.raises(RuntimeError, match="Failed while executing module"):
+        with pytest.raises(
+            RuntimeError,
+            match="Failed while executing module|spec must be provided to the exploration wrapper",
+        ):
             explorative_policy(td)
 
     @pytest.mark.parametrize("module", [True, False])

@@ -200,7 +203,9 @@ def test_wrong_action_shape(self, module):
             policy,
         )
         td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
-        with pytest.raises(RuntimeError, match="Failed while executing module"):
+        with pytest.raises(
+            ValueError, match="Action spec shape does not match the action shape"
+        ):
             explorative_policy(td)
 
 
torchrl/modules/tensordict_module/actors.py

Lines changed: 1 addition & 1 deletion
@@ -2168,7 +2168,7 @@ def __init__(self, base_model):
             TensorDictModule(
                 base_model.transformer,
                 in_keys={"input_ids": "input_ids", "attention_mask": "attention_mask"},
-                out_keys=["x"],
+                out_keys=["x", "_"],
             ),
             TensorDictModule(lambda x: x[:, -1, :], in_keys=["x"], out_keys=["x"]),
         )
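
The wrapped transformer returns more than one output, so the change routes the extra output to the "_" out-key, which tensordict uses, to my understanding, to discard values that should not be written back to the tensordict. A rough sketch with a made-up two-output function, assuming that behaviour of "_":

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule


    def two_outputs(x):
        # Stand-in for a module that returns (hidden_states, extra_state).
        return x + 1, x.sum()


    module = TensorDictModule(two_outputs, in_keys=["x"], out_keys=["x", "_"])
    td = module(TensorDict({"x": torch.zeros(3)}, batch_size=[]))
    print(td)  # "x" holds the first output; the second is assumed to be dropped via "_"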
