Skip to content

Commit 840b6b6

Browse files
authored
[BugFix] Fix AdditiveGaussian exploration tests (#450)
1 parent 85fdbab commit 840b6b6

File tree

1 file changed

+17 -16 lines changed

1 file changed

+17 -16 lines changed

test/test_exploration.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,9 @@ def test_additivegaussian_wrapper(
189 189
default_interaction_mode="random",
190 190
).to(device)
191 191
given_spec = action_spec if spec_origin == "spec" else None
192-
exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(device)
192+
exploratory_policy = AdditiveGaussianWrapper(
193+
policy, spec=given_spec, safe=False
194+
).to(device)
193 195

194 196
tensordict = TensorDict(
195 197
batch_size=[batch],
@@ -198,21 +200,20 @@ def test_additivegaussian_wrapper(
198 200
)
199 201
out_noexp = []
200 202
out = []
201-
if exploratory_policy.spec is not None:
202-
for _ in range(n_steps):
203-
tensordict_noexp = policy(tensordict.select("observation"))
204-
tensordict = exploratory_policy(tensordict)
205-
out.append(tensordict.clone())
206-
out_noexp.append(tensordict_noexp.clone())
207-
tensordict.set_("observation", torch.randn(batch, d_obs, device=device))
208-
out = torch.stack(out, 0)
209-
out_noexp = torch.stack(out_noexp, 0)
210-
assert (out_noexp.get("action") != out.get("action")).all()
211-
if spec_origin is not None:
212-
assert (out.get("action") <= 1.0).all(), out.get("action").min()
213-
assert (out.get("action") >= -1.0).all(), out.get("action").max()
214-
if action_spec is not None:
215-
assert action_spec.is_in(out.get("action"))
203+
for _ in range(n_steps):
204+
tensordict_noexp = policy(tensordict.select("observation"))
205+
tensordict = exploratory_policy(tensordict)
206+
out.append(tensordict.clone())
207+
out_noexp.append(tensordict_noexp.clone())
208+
tensordict.set_("observation", torch.randn(batch, d_obs, device=device))
209+
out = torch.stack(out, 0)
210+
out_noexp = torch.stack(out_noexp, 0)
211+
assert (out_noexp.get("action") != out.get("action")).all()
212+
if spec_origin is not None:
213+
assert (out.get("action") <= 1.0).all(), out.get("action").min()
214+
assert (out.get("action") >= -1.0).all(), out.get("action").max()
215+
if action_spec is not None:
216+
assert action_spec.is_in(out.get("action"))
216 217

217 218

218 219
@pytest.mark.parametrize("state_dim", [7])

0 commit comments

Comments (0)