Skip to content

Commit 2f8c118

Browse files
author
Vincent Moens
committed
[BugFix] Fix safe probabilistic backward by removing in-place modif
ghstack-source-id: 574eb1f Pull Request resolved: #2755
1 parent ee4006a commit 2f8c118

File tree

4 files changed

+161
-83
lines changed

4 files changed

+161
-83
lines changed

torchrl/data/tensor_specs.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@
4444
unravel_key,
4545
)
4646
from tensordict.base import NO_DEFAULT
47-
from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
47+
from tensordict.utils import (
48+
_getitem_batch_size,
49+
expand_as_right,
50+
is_non_tensor,
51+
NestedKey,
52+
)
4853
from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for
4954

5055
try:
@@ -1848,9 +1853,8 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
18481853
gathered = mask_expand & val
18491854
oob = ~gathered.any(-1)
18501855
new_val = torch.multinomial(mask_expand[oob].float(), 1)
1851-
val = val.clone()
1852-
val[oob] = 0
1853-
val[oob] = torch.scatter(val[oob], -1, new_val, 1)
1856+
new_val = torch.scatter(torch.zeros_like(val[oob]), -1, new_val, 1)
1857+
val = val.masked_scatter(expand_as_right(oob, val), new_val)
18541858
return val
18551859

18561860
def is_in(self, val: torch.Tensor) -> bool:
@@ -2300,18 +2304,9 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
23002304
if self.device != val.device:
23012305
low = low.to(val.device)
23022306
high = high.to(val.device)
2303-
try:
2304-
val = torch.maximum(torch.minimum(val, high), low)
2305-
except ValueError:
2306-
low = low.expand_as(val)
2307-
high = high.expand_as(val)
2308-
val[val < low] = low[val < low]
2309-
val[val > high] = high[val > high]
2310-
except RuntimeError:
2311-
low = low.expand_as(val)
2312-
high = high.expand_as(val)
2313-
val[val < low] = low[val < low]
2314-
val[val > high] = high[val > high]
2307+
low = low.expand_as(val)
2308+
high = high.expand_as(val)
2309+
val = torch.clamp(val, low, high)
23152310
return val
23162311

23172312
def is_in(self, val: torch.Tensor) -> bool:

torchrl/envs/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9035,7 +9035,7 @@ def _reset(
90359035

90369036

90379037
class BatchSizeTransform(Transform):
9038-
"""A transform to modify the batch-size of an environmt.
9038+
"""A transform to modify the batch-size of an environment.
90399039
90409040
This transform has two distinct usages: it can be used to set the
90419041
batch-size for non-batch-locked (e.g. stateless) environments to

torchrl/modules/tensordict_module/common.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,17 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
7575
for _spec, _key in zip(values, keys):
7676
if _spec is None:
7777
continue
78-
item = tensordict_out.get(_key, None)
78+
item = tensordict_out.get(_key)
7979
if item is None:
8080
# this will happen when an exploration (e.g. OU) writes a key only
8181
# during exploration, but is missing otherwise.
8282
# it's fine since what we want here it to make sure that a key
8383
# is within bounds if it is present
8484
continue
85-
if not _spec.is_in(item):
86-
try:
87-
tensordict_out.set_(
88-
_key,
89-
_spec.project(tensordict_out.get(_key)),
90-
)
91-
except RuntimeError:
92-
tensordict_out.set(
93-
_key,
94-
_spec.project(tensordict_out.get(_key)),
95-
)
85+
tensordict_out.set(
86+
_key,
87+
_spec.project(item),
88+
)
9689
except RuntimeError as err:
9790
if re.search(
9891
"attempting to use a Tensor in some data-dependent control flow", str(err)

torchrl/modules/tensordict_module/probabilistic.py

Lines changed: 144 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from __future__ import annotations
66

77
import warnings
8-
from typing import Dict, List, Optional, Type, Union
8+
from typing import Dict, List, Optional, Union
9+
10+
import torch
911

1012
from tensordict import TensorDictBase, unravel_key_list
1113

@@ -23,95 +25,181 @@
2325

2426

2527
class SafeProbabilisticModule(ProbabilisticTensorDictModule):
26-
""":class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as argument to control the output domain.
28+
""":class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as an argument to control the output domain.
29+
30+
`SafeProbabilisticModule` is a non-parametric module embedding a
31+
probability distribution constructor. It reads the distribution parameters from an input
32+
TensorDict using the specified `in_keys` and outputs a sample (loosely speaking) of the
33+
distribution.
34+
35+
The output "sample" is produced given some rule, specified by the input ``default_interaction_type``
36+
argument and the ``interaction_type()`` global function.
2737
28-
`SafeProbabilisticModule` is a non-parametric module representing a
29-
probability distribution. It reads the distribution parameters from an input
30-
TensorDict using the specified `in_keys`. The output is sampled given some rule,
31-
specified by the input ``default_interaction_type`` argument and the
32-
``interaction_type()`` global function.
38+
`SafeProbabilisticModule` can be used to construct the distribution
39+
(through the :meth:`~.get_dist` method) and/or sampling from this distribution
40+
(through a regular :meth:`~.__call__` to the module).
3341
34-
:obj:`SafeProbabilisticModule` can be used to construct the distribution
35-
(through the :obj:`get_dist()` method) and/or sampling from this distribution
36-
(through a regular :obj:`__call__()` to the module).
42+
A `SafeProbabilisticModule` instance has two main features:
3743
38-
A :obj:`SafeProbabilisticModule` instance has two main features:
39-
- It reads and writes TensorDict objects
44+
- It reads and writes from and to TensorDict objects;
4045
- It uses a real mapping R^n -> R^m to create a distribution in R^d from
41-
which values can be sampled or computed.
46+
which values can be sampled or computed.
4247
43-
When the :obj:`__call__` / :obj:`forward` method is called, a distribution is
44-
created, and a value computed (using the 'mean', 'mode', 'median' attribute or
45-
the 'rsample', 'sample' method). The sampling step is skipped if the supplied
46-
TensorDict has all of the desired key-value pairs already.
48+
When the :meth:`~.__call__` and :meth:`~.forward` method are called, a distribution is
49+
created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean',
50+
'dist.mode', 'dist.median' attributes could be used, as well as
51+
the 'dist.rsample', 'dist.sample' method). The sampling step is skipped if the supplied
52+
TensorDict has all the desired key-value pairs already.
4753
48-
By default, SafeProbabilisticModule distribution class is a Delta
49-
distribution, making SafeProbabilisticModule a simple wrapper around
54+
By default, `SafeProbabilisticModule` distribution class is a :class:`~torchrl.modules.distributions.Delta`
55+
distribution, making `SafeProbabilisticModule` a simple wrapper around
5056
a deterministic mapping function.
5157
58+
This class differs from :class:`tensordict.nn.ProbabilisticTensorDictModule` in that it accepts a :attr:`spec`
59+
keyword argument which can be used to control whether samples belong to the distribution or not. The :attr:`safe`
60+
keyword argument controls whether the samples values should be checked against the spec.
61+
5262
Args:
53-
in_keys (NestedKey or list of NestedKey or dict): key(s) that will be read from the
54-
input TensorDict and used to build the distribution. Importantly, if it's an
55-
list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
56-
the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for
57-
the Normal distribution and similar. If in_keys is a dictionary, the keys
58-
are the keys of the distribution and the values are the keys in the
63+
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict
64+
and used to build the distribution.
65+
Importantly, if it's a list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
66+
the distribution class of interest, e.g. ``"loc"`` and ``"scale"`` for
67+
the :class:`~torch.distributions.Normal` distribution and similar.
68+
If in_keys is a dictionary, the keys are the keys of the distribution and the values are the keys in the
5969
tensordict that will get match to the corresponding distribution keys.
60-
out_keys (NestedKey or list of NestedKey): keys where the sampled values will be
61-
written. Importantly, if these keys are found in the input TensorDict, the
62-
sampling step will be skipped.
70+
out_keys (NestedKey | List[NestedKey] | None): key(s) where the sampled values will be written.
71+
Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.
6372
spec (TensorSpec): specs of the first output tensor. Used when calling
6473
td_module.random() to generate random values in the target space.
74+
75+
Keyword Args:
6576
safe (bool, optional): if ``True``, the value of the sample is checked against the
6677
input spec. Out-of-domain sampling can occur because of exploration policies
6778
or numerical under/overflow issues. As for the :obj:`spec` argument, this
6879
check will only occur for the distribution sample, but not the other tensors
6980
returned by the input module. If the sample is out of bounds, it is
7081
projected back onto the desired space using the `TensorSpec.project` method.
7182
Default is ``False``.
72-
default_interaction_type (tensordict.nn.InteractionType, optional): default method to be used to retrieve
73-
the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or ``InteractionType.RANDOM``
83+
default_interaction_type (InteractionType, optional): keyword-only argument.
84+
Default method to be used to retrieve
85+
the output value. Should be one of InteractionType: MODE, MEDIAN, MEAN or RANDOM
7486
(in which case the value is sampled randomly from the distribution). Default
75-
is ``InteractionType.MODE``.
76-
Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will
77-
fist look for the interaction mode dictated by the `interaction_type()`
78-
global function. If this returns `None` (its default value), then the
79-
`default_interaction_type` of the :class:`~.ProbabilisticTDModule`
80-
instance will be used. Note that DataCollector instances will use
81-
:func:`tensordict.nn.set_interaction_type` to
82-
:class:`tensordict.nn.InteractionType.RANDOM` by default.
83-
distribution_class (Type, optional): a torch.distributions.Distribution class to
84-
be used for sampling. Default is Delta.
85-
distribution_kwargs (dict, optional): kwargs to be passed to the distribution.
86-
return_log_prob (bool, optional): if ``True``, the log-probability of the
87+
is MODE.
88+
89+
.. note:: When a sample is drawn, the
90+
:class:`ProbabilisticTensorDictModule` instance will
91+
first look for the interaction mode dictated by the
92+
:func:`~tensordict.nn.probabilistic.interaction_type`
93+
global function. If this returns `None` (its default value), then the
94+
`default_interaction_type` of the `ProbabilisticTDModule`
95+
instance will be used. Note that
96+
:class:`~torchrl.collectors.collectors.DataCollectorBase`
97+
instances will use `set_interaction_type` to
98+
:class:`tensordict.nn.InteractionType.RANDOM` by default.
99+
100+
.. note::
101+
In some cases, the mode, median or mean value may not be
102+
readily available through the corresponding attribute.
103+
To paliate this, :class:`~ProbabilisticTensorDictModule` will first attempt
104+
to get the value through a call to ``get_mode()``, ``get_median()`` or ``get_mean()``
105+
if the method exists.
106+
107+
distribution_class (Type or Callable[[Any], Distribution], optional): keyword-only argument.
108+
A :class:`torch.distributions.Distribution` class to
109+
be used for sampling.
110+
Default is :class:`~tensordict.nn.distributions.Delta`.
111+
112+
.. note::
113+
If the distribution class is of type
114+
:class:`~tensordict.nn.distributions.CompositeDistribution`, the ``out_keys``
115+
can be inferred directly form the ``"distribution_map"`` or ``"name_map"``
116+
keywork arguments provided through this class' ``distribution_kwargs``
117+
keyword argument, making the ``out_keys`` optional in such cases.
118+
119+
distribution_kwargs (dict, optional): keyword-only argument.
120+
Keyword-argument pairs to be passed to the distribution.
121+
122+
.. note:: if your kwargs contain tensors that you would like to transfer to device with the module, or
123+
tensors that should see their dtype modified when calling `module.to(dtype)`, you can wrap the kwargs
124+
in a :class:`~tensordict.nn.TensorDictParams` to do this automatically.
125+
126+
return_log_prob (bool, optional): keyword-only argument.
127+
If ``True``, the log-probability of the
87128
distribution sample will be written in the tensordict with the key
88-
`'sample_log_prob'`. Default is ``False``.
89-
log_prob_key (NestedKey, optional): key where to write the log_prob if return_log_prob = True.
90-
Defaults to `"action_log_prob"`.
91-
cache_dist (bool, optional): EXPERIMENTAL: if ``True``, the parameters of the
129+
`log_prob_key`. Default is ``False``.
130+
log_prob_keys (List[NestedKey], optional): keys where to write the log_prob if ``return_log_prob=True``.
131+
Defaults to `'<sample_key_name>_log_prob'`, where `<sample_key_name>` is each of the :attr:`out_keys`.
132+
133+
.. note:: This is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``False``.
134+
135+
log_prob_key (NestedKey, optional): key where to write the log_prob if ``return_log_prob=True``.
136+
Defaults to `'sample_log_prob'` when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to `True`
137+
or `'<sample_key_name>_log_prob'` otherwise.
138+
139+
.. note:: When there is more than one sample, this is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``True``.
140+
141+
cache_dist (bool, optional): keyword-only argument.
142+
EXPERIMENTAL: if ``True``, the parameters of the
92143
distribution (i.e. the output of the module) will be written to the
93144
tensordict along with the sample. Those parameters can be used to re-compute
94145
the original distribution later on (e.g. to compute the divergence between
95146
the distribution used to sample the action and the updated distribution in
96147
PPO). Default is ``False``.
97-
n_empirical_estimate (int, optional): number of samples to compute the empirical
98-
mean when it is not available. Default is 1000
148+
n_empirical_estimate (int, optional): keyword-only argument.
149+
Number of samples to compute the empirical
150+
mean when it is not available. Defaults to 1000.
151+
152+
.. warning:: Running checks takes time! Using `safe=True` will guarantee that the samples are within the spec bounds
153+
given some heuristic coded in :meth:`~torchrl.data.TensorSpec.project`, but that requires checking whether the
154+
values are within the spec space, which will induce some overhead.
155+
156+
.. seealso:: :class`The composite distribution in tensordict <~tensordict.nn.CompositeDistribution>` can be used
157+
to create multi-head policies.
99158
159+
Example:
160+
>>> from torchrl.modules import SafeProbabilisticModule
161+
>>> from torchrl.data import Bounded
162+
>>> import torch
163+
>>> from tensordict import TensorDict
164+
>>> from tensordict.nn import InteractionType
165+
>>> mod = SafeProbabilisticModule(
166+
... in_keys=["loc", "scale"],
167+
... out_keys=["action"],
168+
... distribution_class=torch.distributions.Normal,
169+
... safe=True,
170+
... spec=Bounded(low=-1, high=1, shape=()),
171+
... default_interaction_type=InteractionType.RANDOM
172+
... )
173+
>>> _ = torch.manual_seed(0)
174+
>>> data = TensorDict(
175+
... loc=torch.zeros(10, requires_grad=True),
176+
... scale=torch.full((10,), 10.0),
177+
... batch_size=(10,))
178+
>>> data = mod(data)
179+
>>> print(data["action"]) # All actions are within bound
180+
tensor([ 1., -1., -1., 1., -1., -1., 1., 1., -1., -1.],
181+
grad_fn=<ClampBackward0>)
182+
>>> data["action"].mean().backward()
183+
>>> print(data["loc"].grad) # clamp anihilates gradients
184+
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
100185
"""
101186

102187
def __init__(
103188
self,
104-
in_keys: Union[NestedKey, List[NestedKey], Dict[str, NestedKey]],
105-
out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
189+
in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
190+
out_keys: NestedKey | List[NestedKey] | None = None,
106191
spec: Optional[TensorSpec] = None,
192+
*,
107193
safe: bool = False,
108-
default_interaction_type: str = InteractionType.DETERMINISTIC,
109-
distribution_class: Type = Delta,
110-
distribution_kwargs: Optional[dict] = None,
194+
default_interaction_type: InteractionType = InteractionType.DETERMINISTIC,
195+
distribution_class: type = Delta,
196+
distribution_kwargs: dict | None = None,
111197
return_log_prob: bool = False,
198+
log_prob_keys: List[NestedKey] | None = None,
112199
log_prob_key: NestedKey | None = None,
113200
cache_dist: bool = False,
114201
n_empirical_estimate: int = 1000,
202+
num_samples: int | torch.Size | None = None,
115203
):
116204
super().__init__(
117205
in_keys=in_keys,
@@ -120,9 +208,11 @@ def __init__(
120208
distribution_class=distribution_class,
121209
distribution_kwargs=distribution_kwargs,
122210
return_log_prob=return_log_prob,
123-
log_prob_key=log_prob_key,
124211
cache_dist=cache_dist,
125212
n_empirical_estimate=n_empirical_estimate,
213+
log_prob_keys=log_prob_keys,
214+
log_prob_key=log_prob_key,
215+
num_samples=num_samples,
126216
)
127217
if spec is not None:
128218
spec = spec.clone()

0 commit comments

Comments
 (0)