
Commit a38604e

Author: Vincent Moens (committed)
[Deprecation] Remove NormalParamWrapper
ghstack-source-id: 4a70178
Pull Request resolved: #2747
1 parent 12e6bce commit a38604e

File tree

4 files changed: +12 -110 lines

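The change swaps torchrl's NormalParamWrapper (which wrapped an operator and split its output into loc/scale with a positive mapping on the scale) for tensordict.nn.NormalParamExtractor, composed with the network through nn.Sequential. A minimal migration sketch (the layer sizes are illustrative, not taken from the commit):

import torch
from torch import nn
from tensordict.nn import NormalParamExtractor

# Before (removed): module = NormalParamWrapper(nn.Linear(3, 8), scale_mapping="biased_softplus_1.0")
# After: compose the network with NormalParamExtractor instead of wrapping it.
module = nn.Sequential(
    nn.Linear(3, 8),  # 8 outputs -> loc (4) and scale (4)
    NormalParamExtractor(scale_mapping="biased_softplus_1.0"),
)
loc, scale = module(torch.randn(3))
assert loc.shape == scale.shape == torch.Size([4])
assert (scale > 0).all()  # the scale mapping keeps the scale strictly positive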

docs/source/reference/modules.rst

Lines changed: 0 additions & 1 deletion
@@ -439,7 +439,6 @@ Some distributions are typically used in RL scripts.
 
     Delta
     IndependentNormal
-    NormalParamWrapper
     TanhNormal
     TruncatedNormal
     TanhDelta

test/test_distributions.py

Lines changed: 7 additions & 2 deletions
@@ -13,10 +13,10 @@
 import torch.nn.functional as F
 
 from tensordict import TensorDictBase
+from tensordict.nn import NormalParamExtractor
 from torch import autograd, nn
 from torch.utils._pytree import tree_map
 from torchrl.modules import (
-    NormalParamWrapper,
     OneHotCategorical,
     OneHotOrdinal,
     Ordinal,
@@ -310,14 +310,19 @@ def test_normal_mapping(batch_size, device, scale_mapping, action_dim=11, state_
     torch.manual_seed(0)
     for _ in range(100):
         module = nn.LazyLinear(2 * action_dim).to(device)
-        module = NormalParamWrapper(module, scale_mapping=scale_mapping).to(device)
         if scale_mapping != "raise_error":
+            module = nn.Sequential(
+                module, NormalParamExtractor(scale_mapping=scale_mapping)
+            ).to(device)
             loc, scale = module(torch.randn(*batch_size, state_dim, device=device))
             assert (scale > 0).all()
         else:
             with pytest.raises(
                 NotImplementedError, match="Unknown mapping " "raise_error"
             ):
+                module = nn.Sequential(
+                    module, NormalParamExtractor(scale_mapping=scale_mapping)
+                ).to(device)
                 loc, scale = module(torch.randn(*batch_size, state_dim, device=device))
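Run outside pytest, the updated test pattern looks roughly like this (a sketch; the state_dim/action_dim values are made up, and the unknown-mapping failure may surface at construction or at the first forward call, which is why the test guards both):

import torch
from torch import nn
from tensordict.nn import NormalParamExtractor

state_dim, action_dim = 5, 11  # illustrative sizes
backbone = nn.LazyLinear(2 * action_dim)  # produces 2 * action_dim features

# Valid mapping: the extracted scale is strictly positive.
policy = nn.Sequential(backbone, NormalParamExtractor(scale_mapping="biased_softplus_1.0"))
loc, scale = policy(torch.randn(3, state_dim))
assert loc.shape == scale.shape == torch.Size([3, action_dim])
assert (scale > 0).all()

# Unknown mapping: NotImplementedError("Unknown mapping raise_error").
try:
    bad = nn.Sequential(backbone, NormalParamExtractor(scale_mapping="raise_error"))
    bad(torch.randn(3, state_dim))
except NotImplementedError as err:
    print(err)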

torchrl/modules/distributions/continuous.py

Lines changed: 4 additions & 50 deletions
@@ -4,10 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
-import warnings
 import weakref
 from numbers import Number
-from typing import Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Union
 
 import numpy as np
 import torch
@@ -27,7 +26,6 @@
     safeatanh_noeps,
     safetanh_noeps,
 )
-from torchrl.modules.utils import mappings
 
 # speeds up distribution construction
 D.Distribution.set_default_validate_args(False)
@@ -126,60 +124,16 @@ def inv(self):
         return inv
 
 
-class NormalParamWrapper(nn.Module):
-    """A wrapper for normal distribution parameters.
-
-    Args:
-        operator (nn.Module): operator whose output will be transformed_in in location and scale parameters
-        scale_mapping (str, optional): positive mapping function to be used with the std.
-            default = "biased_softplus_1.0" (i.e. softplus map with bias such that fn(0.0) = 1.0)
-            choices: "softplus", "exp", "relu", "biased_softplus_1";
-        scale_lb (Number, optional): The minimum value that the variance can take. Default is 1e-4.
-
-    Examples:
-        >>> from torch import nn
-        >>> import torch
-        >>> module = nn.Linear(3, 4)
-        >>> module_normal = NormalParamWrapper(module)
-        >>> tensor = torch.randn(3)
-        >>> loc, scale = module_normal(tensor)
-        >>> print(loc.shape, scale.shape)
-        torch.Size([2]) torch.Size([2])
-        >>> assert (scale > 0).all()
-        >>> # with modules that return more than one tensor
-        >>> module = nn.LSTM(3, 4)
-        >>> module_normal = NormalParamWrapper(module)
-        >>> tensor = torch.randn(4, 2, 3)
-        >>> loc, scale, others = module_normal(tensor)
-        >>> print(loc.shape, scale.shape)
-        torch.Size([4, 2, 2]) torch.Size([4, 2, 2])
-        >>> assert (scale > 0).all()
-
-    """
-
+class NormalParamWrapper(nn.Module):  # noqa: D101
     def __init__(
         self,
         operator: nn.Module,
         scale_mapping: str = "biased_softplus_1.0",
         scale_lb: Number = 1e-4,
     ) -> None:
-        warnings.warn(
-            "The NormalParamWrapper class will be deprecated in v0.7 in favor of :class:`~tensordict.nn.NormalParamExtractor`.",
-            category=DeprecationWarning,
+        raise RuntimeError(
+            "NormalParamWrapper has been deprecated in favor of `tensordict.nn.NormalParamExtractor`. Use this class instead."
         )
-        super().__init__()
-        self.operator = operator
-        self.scale_mapping = scale_mapping
-        self.scale_lb = scale_lb
-
-    def forward(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor]:
-        net_output = self.operator(*tensors)
-        others = ()
-        if not isinstance(net_output, torch.Tensor):
-            net_output, *others = net_output
-        loc, scale = net_output.chunk(2, -1)
-        scale = mappings(self.scale_mapping)(scale).clamp_min(self.scale_lb)
-        return (loc, scale, *others)
 
 
 class TruncatedNormal(D.Independent):
torchrl/modules/utils/mappings.py

Lines changed: 1 addition & 57 deletions
@@ -3,62 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable
-
-import torch
-from tensordict.nn.utils import biased_softplus, inv_softplus
+from tensordict.nn.utils import biased_softplus, expln, inv_softplus, mappings
 
 __all__ = ["biased_softplus", "expln", "inv_softplus", "mappings"]
-
-
-def expln(x):
-    """A smooth, continuous positive mapping presented in "State-Dependent Exploration for Policy Gradient Methods".
-
-    https://people.idsia.ch/~juergen/ecml2008rueckstiess.pdf
-
-    """
-    out = torch.empty_like(x)
-    idx_neg = x <= 0
-    out[idx_neg] = x[idx_neg].exp()
-    out[~idx_neg] = x[~idx_neg].log1p() + 1
-    return out
-
-
-def mappings(key: str) -> Callable:
-    """Given an input string, returns a surjective function f(x): R -> R^+.
-
-    Args:
-        key (str): one of "softplus", "exp", "relu", "expln",
-            or "biased_softplus". If the key beggins with "biased_softplus",
-            then it needs to take the following form:
-            ```"biased_softplus_{bias}"``` where ```bias``` can be converted to a floating point number that will be used to bias the softplus function.
-            Alternatively, the ```"biased_softplus_{bias}_{min_val}"``` syntax can be used. In that case, the additional ```min_val``` term is a floating point
-            number that will be used to encode the minimum value of the softplus transform.
-            In practice, the equation used is softplus(x + bias) + min_val, where bias and min_val are values computed such that the conditions above are met.
-
-    Returns:
-        a Callable
-
-    """
-    _mappings = {
-        "softplus": torch.nn.functional.softplus,
-        "exp": torch.exp,
-        "relu": torch.relu,
-        "biased_softplus": biased_softplus(1.0),
-        "expln": expln,
-    }
-    if key in _mappings:
-        return _mappings[key]
-    elif key.startswith("biased_softplus"):
-        stripped_key = key.split("_")
-        if len(stripped_key) == 3:
-            return biased_softplus(float(stripped_key[-1]))
-        elif len(stripped_key) == 4:
-            return biased_softplus(
-                float(stripped_key[-2]), min_val=float(stripped_key[-1])
-            )
-        else:
-            raise ValueError(f"Invalid number of args in {key}")
-
-    else:
-        raise NotImplementedError(f"Unknown mapping {key}")
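mappings.py now simply re-exports the tensordict.nn.utils implementations and keeps the same __all__, so existing imports from torchrl.modules.utils.mappings keep working. A quick sketch, assuming the tensordict versions behave like the removed ones (biased softplus shifted so that fn(0.) == 1.0, expln as in the removed docstring, NotImplementedError for unknown keys):

import torch
from torchrl.modules.utils.mappings import expln, mappings

x = torch.tensor([-2.0, 0.0, 2.0])

fn = mappings("biased_softplus_1.0")  # softplus shifted so that fn(0.) == 1.0
print(fn(x))     # strictly positive, ~1.0 at x = 0

print(expln(x))  # exp(x) for x <= 0, log1p(x) + 1 for x > 0

try:
    mappings("not_a_mapping")
except NotImplementedError as err:
    print(err)   # Unknown mapping not_a_mapping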
