Skip to content

Commit 6c7f4fb

Browse files
author
Vincent Moens
committed
[Deprecation] Remove AdditiveGaussianWrapper
ghstack-source-id: 78f248e Pull Request resolved: #2748
1 parent a38604e commit 6c7f4fb

File tree

4 files changed

+11
-147
lines changed

4 files changed

+11
-147
lines changed

docs/source/reference/modules.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ other cases, the action written in the tensordict is simply the network output.
7878
:template: rl_template_noinherit.rst
7979

8080
AdditiveGaussianModule
81-
AdditiveGaussianWrapper
8281
ConsistentDropoutModule
8382
EGreedyModule
8483
EGreedyWrapper

test/test_exploration.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from torchrl.modules.tensordict_module.exploration import (
3636
_OrnsteinUhlenbeckProcess,
3737
AdditiveGaussianModule,
38-
AdditiveGaussianWrapper,
3938
EGreedyModule,
4039
EGreedyWrapper,
4140
OrnsteinUhlenbeckProcessModule,
@@ -433,7 +432,7 @@ def test_no_spec_error(self, device):
433432
@pytest.mark.parametrize("device", get_default_devices())
434433
class TestAdditiveGaussian:
435434
@pytest.mark.parametrize("spec_origin", ["spec", "policy", None])
436-
@pytest.mark.parametrize("interface", ["module", "wrapper"])
435+
@pytest.mark.parametrize("interface", ["module"])
437436
def test_additivegaussian_sd(
438437
self,
439438
device,
@@ -475,8 +474,8 @@ def test_additivegaussian_sd(
475474
default_interaction_type=InteractionType.RANDOM,
476475
)
477476
given_spec = action_spec if spec_origin == "spec" else None
478-
exploratory_policy = AdditiveGaussianWrapper(
479-
policy, spec=given_spec, device=device
477+
exploratory_policy = TensorDictModule(
478+
policy, AdditiveGaussianModule(spec=given_spec, device=device)
480479
)
481480
if spec_origin is not None:
482481
sigma_init = (
@@ -524,7 +523,7 @@ def test_additivegaussian_sd(
524523
assert abs(noisy_action.std() - sigma_end) < 1e-1
525524

526525
@pytest.mark.parametrize("spec_origin", ["spec", "policy", None])
527-
@pytest.mark.parametrize("interface", ["module", "wrapper"])
526+
@pytest.mark.parametrize("interface", ["module"])
528527
def test_additivegaussian(
529528
self,
530529
device,
@@ -563,9 +562,7 @@ def test_additivegaussian(
563562
policy, AdditiveGaussianModule(spec=given_spec).to(device)
564563
)
565564
else:
566-
exploratory_policy = AdditiveGaussianWrapper(
567-
policy, spec=given_spec, safe=False
568-
).to(device)
565+
raise NotImplementedError
569566

570567
tensordict = TensorDict(
571568
batch_size=[batch],
@@ -590,7 +587,7 @@ def test_additivegaussian(
590587
assert action_spec.is_in(out.get("action"))
591588

592589
@pytest.mark.parametrize("parallel_spec", [True, False])
593-
@pytest.mark.parametrize("interface", ["module", "wrapper"])
590+
@pytest.mark.parametrize("interface", ["module"])
594591
def test_collector(self, device, parallel_spec, interface, seed=0):
595592
torch.manual_seed(seed)
596593
env = SerialEnv(
@@ -622,7 +619,7 @@ def test_collector(self, device, parallel_spec, interface, seed=0):
622619
policy, AdditiveGaussianModule(spec=action_spec).to(device)
623620
)
624621
else:
625-
exploratory_policy = AdditiveGaussianWrapper(policy, safe=False)
622+
raise NotImplementedError
626623
exploratory_policy(env.reset())
627624
collector = SyncDataCollector(
628625
create_env_fn=env,

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def cudagraph_mark_step_begin():
8181
INSTANTIATE_TIMEOUT = 20
8282
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
8383
# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can time out while waiting on its queue.
84-
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", float("inf")))
84+
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))
8585

8686
DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM
8787

torchrl/modules/tensordict_module/exploration.py

Lines changed: 3 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
"EGreedyWrapper",
2828
"EGreedyModule",
2929
"AdditiveGaussianModule",
30-
"AdditiveGaussianWrapper",
3130
"OrnsteinUhlenbeckProcessModule",
3231
"OrnsteinUhlenbeckProcessWrapper",
3332
]
@@ -220,42 +219,7 @@ def __init__(
220219

221220

222221
class AdditiveGaussianWrapper(TensorDictModuleWrapper):
223-
"""Additive Gaussian PO wrapper.
224-
225-
Args:
226-
policy (TensorDictModule): a policy.
227-
228-
Keyword Args:
229-
sigma_init (scalar, optional): initial sigma value.
230-
default: 1.0
231-
sigma_end (scalar, optional): final sigma value.
232-
default: 0.1
233-
annealing_num_steps (int, optional): number of steps it will take for
234-
sigma to reach the :obj:`sigma_end` value.
235-
mean (:obj:`float`, optional): mean of each output element’s normal distribution.
236-
std (:obj:`float`, optional): standard deviation of each output element’s normal distribution.
237-
action_key (NestedKey, optional): if the policy module has more than one output key,
238-
its output spec will be of type Composite. One needs to know where to
239-
find the action spec.
240-
Default is "action".
241-
spec (TensorSpec, optional): if provided, the sampled action will be
242-
projected onto the valid action space once explored. If not provided,
243-
the exploration wrapper will attempt to recover it from the policy.
244-
safe (boolean, optional): if False, the TensorSpec can be None. If it
245-
is set to False but the spec is passed, the projection will still
246-
happen.
247-
Default is True.
248-
device (torch.device, optional): the device where the buffers have to be stored.
249-
250-
.. note::
251-
Once a policy has been wrapped in :class:`AdditiveGaussianWrapper`, it is
252-
crucial to incorporate a call to :meth:`~.step` in the training loop
253-
to update the exploration factor.
254-
Since it is not easy to capture this omission no warning or exception
255-
will be raised if this is omitted!
256-
257-
258-
"""
222+
"""[Deprecated] Additive Gaussian PO wrapper."""
259223

260224
def __init__(
261225
self,
@@ -271,105 +235,9 @@ def __init__(
271235
safe: Optional[bool] = True,
272236
device: torch.device | None = None,
273237
):
274-
warnings.warn(
275-
"AdditiveGaussianWrapper is deprecated and will be removed "
276-
"in v0.7. Please use torchrl.modules.AdditiveGaussianModule "
277-
"instead.",
278-
category=DeprecationWarning,
279-
)
280-
if device is None and hasattr(policy, "parameters"):
281-
for p in policy.parameters():
282-
device = p.device
283-
break
284-
285-
super().__init__(policy)
286-
if sigma_end > sigma_init:
287-
raise RuntimeError("sigma should decrease over time or be constant")
288-
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
289-
self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device))
290-
self.annealing_num_steps = annealing_num_steps
291-
self.register_buffer("mean", torch.tensor(mean, device=device))
292-
self.register_buffer("std", torch.tensor(std, device=device))
293-
self.register_buffer(
294-
"sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device)
238+
raise RuntimeError(
239+
"This module has been removed from TorchRL. Please use torchrl.modules.AdditiveGaussianModule instead."
295240
)
296-
self.action_key = action_key
297-
self.out_keys = list(self.td_module.out_keys)
298-
if action_key not in self.out_keys:
299-
raise RuntimeError(
300-
f"The action key {action_key} was not found in the td_module out_keys {self.td_module.out_keys}."
301-
)
302-
if spec is not None:
303-
if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
304-
spec = Composite({action_key: spec}, shape=spec.shape[:-1])
305-
self._spec = spec
306-
elif hasattr(self.td_module, "_spec"):
307-
self._spec = self.td_module._spec.clone()
308-
if action_key not in self._spec.keys(True, True):
309-
self._spec[action_key] = None
310-
elif hasattr(self.td_module, "spec"):
311-
self._spec = self.td_module.spec.clone()
312-
if action_key not in self._spec.keys(True, True):
313-
self._spec[action_key] = None
314-
else:
315-
self._spec = Composite({key: None for key in policy.out_keys})
316-
317-
self.safe = safe
318-
if self.safe:
319-
self.register_forward_hook(_forward_hook_safe_action)
320-
321-
@property
322-
def spec(self):
323-
return self._spec
324-
325-
def step(self, frames: int = 1) -> None:
326-
"""A step of sigma decay.
327-
328-
After self.annealing_num_steps, this function is a no-op.
329-
330-
Args:
331-
frames (int): number of frames since last step.
332-
333-
"""
334-
for _ in range(frames):
335-
self.sigma.data.copy_(
336-
torch.maximum(
337-
self.sigma_end,
338-
self.sigma
339-
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps,
340-
),
341-
)
342-
343-
def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
344-
sigma = self.sigma
345-
mean = self.mean.expand(action.shape)
346-
std = self.std.expand(action.shape)
347-
if not mean.dtype.is_floating_point:
348-
mean = mean.to(torch.get_default_dtype())
349-
if not std.dtype.is_floating_point:
350-
std = std.to(torch.get_default_dtype())
351-
noise = torch.normal(mean=mean, std=std)
352-
if noise.device != action.device:
353-
noise = noise.to(action.device)
354-
action = action + noise * sigma
355-
spec = self.spec
356-
spec = spec[self.action_key]
357-
if spec is not None:
358-
action = spec.project(action)
359-
elif self.safe:
360-
raise RuntimeError(
361-
"the action spec must be provided to AdditiveGaussianWrapper unless "
362-
"the `safe` keyword argument is turned off at initialization."
363-
)
364-
return action
365-
366-
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
367-
tensordict = self.td_module.forward(tensordict)
368-
if exploration_type() is ExplorationType.RANDOM or exploration_type() is None:
369-
out = tensordict.get(self.action_key)
370-
out = self._add_noise(out)
371-
tensordict.set(self.action_key, out)
372-
return tensordict
373241

374242

375243
class AdditiveGaussianModule(TensorDictModuleBase):

0 commit comments

Comments
 (0)