Skip to content

Commit 8dceee8

Browse files
[BugFix]: Fix additive noise (#447)
1 parent e2a518c commit 8dceee8

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

test/test_exploration.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -198,20 +198,21 @@ def test_additivegaussian_wrapper(
     )
     out_noexp = []
     out = []
-    for _ in range(n_steps):
-        tensordict_noexp = policy(tensordict.select("observation"))
-        tensordict = exploratory_policy(tensordict)
-        out.append(tensordict.clone())
-        out_noexp.append(tensordict_noexp.clone())
-        tensordict.set_("observation", torch.randn(batch, d_obs, device=device))
-    out = torch.stack(out, 0)
-    out_noexp = torch.stack(out_noexp, 0)
-    assert (out_noexp.get("action") != out.get("action")).all()
-    if spec_origin is not None:
-        assert (out.get("action") <= 1.0).all(), out.get("action").min()
-        assert (out.get("action") >= -1.0).all(), out.get("action").max()
-    if action_spec is not None:
-        assert action_spec.is_in(out.get("action"))
+    if exploratory_policy.spec is not None:
+        for _ in range(n_steps):
+            tensordict_noexp = policy(tensordict.select("observation"))
+            tensordict = exploratory_policy(tensordict)
+            out.append(tensordict.clone())
+            out_noexp.append(tensordict_noexp.clone())
+            tensordict.set_("observation", torch.randn(batch, d_obs, device=device))
+        out = torch.stack(out, 0)
+        out_noexp = torch.stack(out_noexp, 0)
+        assert (out_noexp.get("action") != out.get("action")).all()
+        if spec_origin is not None:
+            assert (out.get("action") <= 1.0).all(), out.get("action").min()
+            assert (out.get("action") >= -1.0).all(), out.get("action").max()
+        if action_spec is not None:
+            assert action_spec.is_in(out.get("action"))
 
 
 @pytest.mark.parametrize("state_dim", [7])

torchrl/modules/tensordict_module/exploration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = self.td_module.forward(tensordict)
         if exploration_mode() == "random" or exploration_mode() is None:
             out = tensordict.get(self.action_key)
+            out = self._add_noise(out)
             tensordict.set(self.action_key, out)
         return tensordict

0 commit comments

Comments
 (0)