
Commit 0111a87

Author: Vincent Moens (committed)
[Deprecation] Remove OrnsteinUhlenbeckProcessWrapper
ghstack-source-id: 401fdfa
Pull Request resolved: #2749
1 parent: 6c7f4fb

File tree: 3 files changed, +9 -215 lines
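The replacement path is visible in the test changes below: the policy is now composed with OrnsteinUhlenbeckProcessModule inside a TensorDictSequential instead of being wrapped. A minimal migration sketch (the Actor/Bounded setup mirrors the example from the removed docstring; treat this as an illustration of the pattern, not an official recipe):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torchrl.data import Bounded
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

spec = Bounded(-1, 1, torch.Size([4]))
policy = Actor(module=torch.nn.Linear(4, 4, bias=False), spec=spec)

# Before this commit: exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
ou = OrnsteinUhlenbeckProcessModule(spec=spec)
exploratory_policy = TensorDictSequential(policy, ou)

td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
td = exploratory_policy(td)  # writes a noisy "action" plus the OU state entries
ou.step()  # anneal eps between batches, as the wrapper's step() used to do
# In a real environment, append an InitTracker() transform so an "is_init" entry
# is present to reset the noise at trajectory starts (otherwise a warning is emitted).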

docs/source/reference/modules.rst

Lines changed: 0 additions & 2 deletions
@@ -80,9 +80,7 @@ other cases, the action written in the tensordict is simply the network output.
     AdditiveGaussianModule
     ConsistentDropoutModule
     EGreedyModule
-    EGreedyWrapper
     OrnsteinUhlenbeckProcessModule
-    OrnsteinUhlenbeckProcessWrapper

 Probabilistic actors
 ~~~~~~~~~~~~~~~~~~~~

test/test_exploration.py

Lines changed: 6 additions & 13 deletions
@@ -38,7 +38,6 @@
     EGreedyModule,
     EGreedyWrapper,
     OrnsteinUhlenbeckProcessModule,
-    OrnsteinUhlenbeckProcessWrapper,
 )

 if os.getenv("PYTORCH_TEST_FBCODE"):
@@ -235,7 +234,7 @@ def test_ou_process(self, device, seed=0):
         assert pval_acc > 0.05
         assert pval_reg < 0.1

-    @pytest.mark.parametrize("interface", ["module", "wrapper"])
+    @pytest.mark.parametrize("interface", ["module"])
     def test_ou(
         self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
     ):
@@ -257,8 +256,7 @@ def test_ou(
             ou = OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
             exploratory_policy = TensorDictSequential(policy, ou)
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
-            ou = exploratory_policy
+            raise NotImplementedError

         tensordict = TensorDict(
             batch_size=[batch],
@@ -299,7 +297,7 @@ def test_ou(

     @pytest.mark.parametrize("parallel_spec", [True, False])
     @pytest.mark.parametrize("probabilistic", [True, False])
-    @pytest.mark.parametrize("interface", ["module", "wrapper"])
+    @pytest.mark.parametrize("interface", ["module"])
     def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0):
         torch.manual_seed(seed)
         env = SerialEnv(
@@ -340,7 +338,7 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0
                 policy, OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
             )
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
+            raise NotImplementedError
         exploratory_policy(env.reset())
         collector = SyncDataCollector(
             create_env_fn=env,
@@ -357,7 +355,7 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0
     @pytest.mark.parametrize("nested_obs_action", [True, False])
     @pytest.mark.parametrize("nested_done", [True, False])
     @pytest.mark.parametrize("is_init_key", ["some"])
-    @pytest.mark.parametrize("interface", ["module", "wrapper"])
+    @pytest.mark.parametrize("interface", ["module"])
     def test_nested(
         self,
         device,
@@ -401,12 +399,7 @@ def test_nested(
                 ).to(device),
             )
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
-                policy,
-                spec=action_spec,
-                action_key=env.action_key,
-                is_init_key=is_init_key,
-            )
+            raise NotImplementedError
         collector = SyncDataCollector(
             create_env_fn=env,
             policy=exploratory_policy,
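test_collector now exercises only the module interface. A hedged sketch of that collector pattern outside the test suite (GymEnv("Pendulum-v1"), InitTracker, and the frame counts are stand-ins chosen for illustration; the test itself uses mock environments and its own settings):

import torch
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())  # InitTracker provides "is_init"
action_spec = env.action_spec
policy = Actor(module=torch.nn.LazyLinear(action_spec.shape[-1]), spec=action_spec)
exploratory_policy = TensorDictSequential(
    policy, OrnsteinUhlenbeckProcessModule(spec=action_spec)
)
exploratory_policy(env.reset())  # initialize lazy modules, as the test does

collector = SyncDataCollector(
    create_env_fn=env,
    policy=exploratory_policy,
    frames_per_batch=100,
    total_frames=1000,
)
for batch in collector:
    break  # one batch is enough for this sketch
collector.shutdown()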

torchrl/modules/tensordict_module/exploration.py

Lines changed: 3 additions & 200 deletions
@@ -370,94 +370,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:


 class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
-    r"""Ornstein-Uhlenbeck exploration policy wrapper.
-
-    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf.
-
-    The OU exploration is to be used with continuous control policies and introduces a auto-correlated exploration
-    noise. This enables a sort of 'structured' exploration.
-
-    Noise equation:
-
-    .. math::
-        noise_t = noise_{t-1} + \theta * (mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W
-
-    Sigma equation:
-
-    .. math::
-        \sigma_t = max(\sigma^{min, (-(\sigma_{t-1} - \sigma^{min}) / (n^{\text{steps annealing}}) * n^{\text{steps}} + \sigma))
-
-    To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys
-    will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset,
-    indicating that a new trajectory is being collected. If not, and is the same tensordict is used for consecutive
-    trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of
-    zeroing the tensordict at reset time.
-
-    .. note::
-        Once an environment has been wrapped in :class:`OrnsteinUhlenbeckProcessWrapper`, it is
-        crucial to incorporate a call to :meth:`~.step` in the training loop
-        to update the exploration factor.
-        Since it is not easy to capture this omission no warning or exception
-        will be raised if this is ommitted!
-
-    Args:
-        policy (TensorDictModule): a policy
-
-    Keyword Args:
-        eps_init (scalar): initial epsilon value, determining the amount of noise to be added.
-            default: 1.0
-        eps_end (scalar): final epsilon value, determining the amount of noise to be added.
-            default: 0.1
-        annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value.
-            default: 1000
-        theta (scalar): theta factor in the noise equation
-            default: 0.15
-        mu (scalar): OU average (mu in the noise equation).
-            default: 0.0
-        sigma (scalar): sigma value in the sigma equation.
-            default: 0.2
-        dt (scalar): dt in the noise equation.
-            default: 0.01
-        x0 (Tensor, ndarray, optional): initial value of the process.
-            default: 0.0
-        sigma_min (number, optional): sigma_min in the sigma equation.
-            default: None
-        n_steps_annealing (int): number of steps for the sigma annealing.
-            default: 1000
-        action_key (NestedKey, optional): key of the action to be modified.
-            default: "action"
-        is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
-            default: "is_init"
-        spec (TensorSpec, optional): if provided, the sampled action will be
-            projected onto the valid action space once explored. If not provided,
-            the exploration wrapper will attempt to recover it from the policy.
-        safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
-            given the :obj:`TensorSpec.project` heuristic.
-            default: True
-        device (torch.device, optional): the device where the buffers have to be stored.
-
-    Examples:
-        >>> import torch
-        >>> from tensordict import TensorDict
-        >>> from torchrl.data import Bounded
-        >>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor
-        >>> torch.manual_seed(0)
-        >>> spec = Bounded(-1, 1, torch.Size([4]))
-        >>> module = torch.nn.Linear(4, 4, bias=False)
-        >>> policy = Actor(module=module, spec=spec)
-        >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy)
-        >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
-        >>> print(explorative_policy(td))
-        TensorDict(
-            fields={
-                _ou_prev_noise: Tensor(torch.Size([10, 4]), dtype=torch.float32),
-                _ou_steps: Tensor(torch.Size([10, 1]), dtype=torch.int64),
-                action: Tensor(torch.Size([10, 4]), dtype=torch.float32),
-                observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
-            batch_size=torch.Size([10]),
-            device=None,
-            is_shared=False)
-    """
+    """[Deprecated] Ornstein-Uhlenbeck exploration policy wrapper."""

     def __init__(
         self,
@@ -480,119 +393,9 @@ def __init__(
         key: Optional[NestedKey] = None,
         device: torch.device | None = None,
     ):
-        warnings.warn(
-            "OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed "
-            "in v0.7. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule "
-            "instead.",
-            category=DeprecationWarning,
-        )
-        if device is None and hasattr(policy, "parameters"):
-            for p in policy.parameters():
-                device = p.device
-                break
-        if key is not None:
-            action_key = key
-            warnings.warn(
-                f"the 'key' keyword argument of {type(self)} has been renamed 'action_key'. The 'key' entry will be deprecated soon."
-            )
-        super().__init__(policy)
-        self.ou = _OrnsteinUhlenbeckProcess(
-            theta=theta,
-            mu=mu,
-            sigma=sigma,
-            dt=dt,
-            x0=x0,
-            sigma_min=sigma_min,
-            n_steps_annealing=n_steps_annealing,
-            key=action_key,
-            device=device,
-        )
-        self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
-        self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
-        if self.eps_end > self.eps_init:
-            raise ValueError(
-                "eps should decrease over time or be constant, "
-                f"got eps_init={eps_init} and eps_end={eps_end}"
-            )
-        self.annealing_num_steps = annealing_num_steps
-        self.register_buffer(
-            "eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
+        raise RuntimeError(
+            "OrnsteinUhlenbeckProcessWrapper has been removed. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule instead."
         )
-        self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
-        self.is_init_key = is_init_key
-        noise_key = self.ou.noise_key
-        steps_key = self.ou.steps_key
-
-        if spec is not None:
-            if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
-                spec = Composite({action_key: spec}, shape=spec.shape[:-1])
-            self._spec = spec
-        elif hasattr(self.td_module, "_spec"):
-            self._spec = self.td_module._spec.clone()
-            if action_key not in self._spec.keys(True, True):
-                self._spec[action_key] = None
-        elif hasattr(self.td_module, "spec"):
-            self._spec = self.td_module.spec.clone()
-            if action_key not in self._spec.keys(True, True):
-                self._spec[action_key] = None
-        else:
-            self._spec = Composite({key: None for key in policy.out_keys})
-        ou_specs = {
-            noise_key: None,
-            steps_key: None,
-        }
-        self._spec.update(ou_specs)
-        if len(set(self.out_keys)) != len(self.out_keys):
-            raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
-        self.safe = safe
-        if self.safe:
-            self.register_forward_hook(_forward_hook_safe_action)
-
-    @property
-    def spec(self):
-        return self._spec
-
-    def step(self, frames: int = 1) -> None:
-        """Updates the eps noise factor.
-
-        Args:
-            frames (int): number of frames of the current batch (corresponding to the number of updates to be made).
-
-        """
-        for _ in range(frames):
-            if self.annealing_num_steps > 0:
-                self.eps.data.copy_(
-                    torch.maximum(
-                        self.eps_end,
-                        (
-                            self.eps
-                            - (self.eps_init - self.eps_end) / self.annealing_num_steps
-                        ),
-                    )
-                )
-            else:
-                raise ValueError(
-                    f"{self.__class__.__name__}.step() called when "
-                    f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive "
-                    f"number of frames."
-                )
-
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict = super().forward(tensordict)
-        if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
-            is_init = tensordict.get(self.is_init_key, None)
-            if is_init is None:
-                warnings.warn(
-                    f"The tensordict passed to {self.__class__.__name__} appears to be "
-                    f"missing the '{self.is_init_key}' entry. This entry is used to "
-                    f"reset the noise at the beginning of a trajectory, without it "
-                    f"the behavior of this exploration method is undefined. "
-                    f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
-                    f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
-                    f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
-                )
-            tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init)
-        return tensordict


 class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
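The docstring removed above documents the OU noise recursion and a sigma anneal (the sigma equation is rendered with unbalanced braces in the original source; judging from the sigma_min and n_steps_annealing arguments, the intent is a linear decay from sigma toward sigma_min). A standalone numeric sketch of that recursion, with parameter names and defaults taken from the docstring and illustrative values of my own choosing for sigma_min and the step count:

import torch

# Parameters named as in the removed docstring; sigma_min = 0.05 and the
# 100-step loop are illustrative choices, not torchrl defaults.
theta, mu, sigma, dt = 0.15, 0.0, 0.2, 0.01
sigma_min, n_steps_annealing = 0.05, 1000

noise = torch.zeros(4)  # x0 = 0.0
for n_steps in range(100):
    # Linear anneal: sigma_t = max(sigma_min, sigma - (sigma - sigma_min) * n_steps / n_steps_annealing)
    sigma_t = max(sigma_min, sigma - (sigma - sigma_min) * n_steps / n_steps_annealing)
    W = torch.randn(4)  # Wiener increment
    # noise_t = noise_{t-1} + theta * (mu - noise_{t-1}) * dt + sigma_t * sqrt(dt) * W
    noise = noise + theta * (mu - noise) * dt + sigma_t * (dt ** 0.5) * W
# `noise` is what the exploration module adds to the policy's action at each step.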
