- """:class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as argument to control the output domain.
+ """:class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as an argument to control the output domain.
+
+ `SafeProbabilisticModule` is a non-parametric module embedding a
+ probability distribution constructor. It reads the distribution parameters from an input
+ TensorDict using the specified `in_keys` and outputs a sample (loosely speaking) of the
+ distribution.
+
+ The output "sample" is produced given some rule, specified by the input ``default_interaction_type``
+ argument and the ``interaction_type()`` global function.

- `SafeProbabilisticModule` is a non-parametric module representing a
- probability distribution. It reads the distribution parameters from an input
- TensorDict using the specified `in_keys`. The output is sampled given some rule,
- specified by the input ``default_interaction_type`` argument and the
- ``interaction_type()`` global function.
+ `SafeProbabilisticModule` can be used to construct the distribution
+ (through the :meth:`~.get_dist` method) and/or sampling from this distribution
+ (through a regular :meth:`~.__call__` to the module).

- :obj:`SafeProbabilisticModule` can be used to construct the distribution
- (through the :obj:`get_dist()` method) and/or sampling from this distribution
- (through a regular :obj:`__call__()` to the module).
+ A `SafeProbabilisticModule` instance has two main features:

- A :obj:`SafeProbabilisticModule` instance has two main features:
- - It reads and writes TensorDict objects
+ - It reads and writes from and to TensorDict objects;
  - It uses a real mapping R^n -> R^m to create a distribution in R^d from
- which values can be sampled or computed.
+ which values can be sampled or computed.

- When the :obj:`__call__` / :obj:`forward` method is called, a distribution is
- created, and a value computed (using the 'mean', 'mode', 'median' attribute or
- the 'rsample', 'sample' method). The sampling step is skipped if the supplied
- TensorDict has all of the desired key-value pairs already.
+ When the :meth:`~.__call__` and :meth:`~.forward` method are called, a distribution is
+ created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean',
+ 'dist.mode', 'dist.median' attributes could be used, as well as
+ the 'dist.rsample', 'dist.sample' method). The sampling step is skipped if the supplied
+ TensorDict has all the desired key-value pairs already.

- By default, SafeProbabilisticModule distribution class is a Delta
- distribution, making SafeProbabilisticModule a simple wrapper around
+ By default, `SafeProbabilisticModule` distribution class is a :class:`~torchrl.modules.distributions.Delta`
+ distribution, making `SafeProbabilisticModule` a simple wrapper around
  a deterministic mapping function.
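Review note: the dispatch described in the docstring above (an interaction type selects between the distribution's `mean`, `mode`, `median` attributes or a random draw) can be sketched in plain Python. This is only an illustration under stated assumptions: `ToyInteractionType`, `ToyDelta`, `ToyNormal` and `pick_value` are hypothetical stand-ins written for this note, not the tensordict or torchrl classes, and the real module operates on tensors rather than floats.

```python
import random
from dataclasses import dataclass
from enum import Enum, auto


class ToyInteractionType(Enum):
    """Hypothetical stand-in for the interaction type named in the docstring."""
    MODE = auto()
    MEDIAN = auto()
    MEAN = auto()
    RANDOM = auto()


@dataclass
class ToyDelta:
    """Point-mass distribution: every statistic and every sample is `value`."""
    value: float

    @property
    def mean(self):
        return self.value

    @property
    def mode(self):
        return self.value

    @property
    def median(self):
        return self.value

    def sample(self):
        return self.value


@dataclass
class ToyNormal:
    """Scalar normal distribution with the usual `loc` / `scale` keywords."""
    loc: float
    scale: float

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @property
    def median(self):
        return self.loc

    def sample(self):
        return random.gauss(self.loc, self.scale)


def pick_value(dist, interaction):
    """Return the value selected by `interaction` from `dist`."""
    if interaction is ToyInteractionType.RANDOM:
        return dist.sample()
    # MODE / MEDIAN / MEAN map onto the attribute of the same (lowercase) name.
    return getattr(dist, interaction.name.lower())
```

With a point-mass distribution every interaction type returns the same value, which is why the docstring can describe the default configuration as a wrapper around a deterministic mapping.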

+ This class differs from :class:`tensordict.nn.ProbabilisticTensorDictModule` in that it accepts a :attr:`spec`
+ keyword argument which can be used to control whether samples belong to the distribution or not. The :attr:`safe`
+ keyword argument controls whether the samples values should be checked against the spec.
+
  Args:
- in_keys (NestedKey or list of NestedKey or dict): key(s) that will be read from the
- input TensorDict and used to build the distribution. Importantly, if it's an
- list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
- the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for
- the Normal distribution and similar. If in_keys is a dictionary, the keys
- are the keys of the distribution and the values are the keys in the
+ in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict
+ and used to build the distribution.
+ Importantly, if it's a list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
+ the distribution class of interest, e.g. ``"loc"`` and ``"scale"`` for
+ the :class:`~torch.distributions.Normal` distribution and similar.
+ If in_keys is a dictionary, the keys are the keys of the distribution and the values are the keys in the
  tensordict that will get match to the corresponding distribution keys.
- out_keys (NestedKey or list of NestedKey): keys where the sampled values will be
- written. Importantly, if these keys are found in the input TensorDict, the
- sampling step will be skipped.
+ out_keys (NestedKey | List[NestedKey] | None): key(s) where the sampled values will be written.
+ Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.
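Review note: the `in_keys` / `out_keys` rules documented above can be sketched as follows. This is a minimal sketch, not the library's implementation: a flat `dict` keyed by strings or tuples stands in for a TensorDict, a bare tuple is assumed to be a single nested key, and `resolve_dist_kwargs` and `needs_sampling` are hypothetical helper names introduced for this note.

```python
def resolve_dist_kwargs(in_keys, data):
    """Map each distribution keyword to the value read from `data`.

    `data` is a plain dict standing in for a TensorDict; nested keys are
    represented as tuples of strings (an assumption of this sketch).
    """
    if isinstance(in_keys, dict):
        # Dict form: distribution keyword -> key to read in the input.
        mapping = dict(in_keys)
    else:
        if isinstance(in_keys, (str, tuple)):
            in_keys = [in_keys]  # a single key is treated as a one-item list
        # List form: the leaf (last element) of each key is the keyword.
        mapping = {
            (key if isinstance(key, str) else key[-1]): key for key in in_keys
        }
    return {keyword: data[key] for keyword, key in mapping.items()}


def needs_sampling(out_keys, data):
    """Sampling is skipped when every output key is already present."""
    return not all(key in data for key in out_keys)
```

For example, `resolve_dist_kwargs([("params", "loc"), ("params", "scale")], data)` and `resolve_dist_kwargs({"loc": "mu", "scale": "sigma"}, data)` both produce a `{"loc": ..., "scale": ...}` mapping suitable for a Normal-like constructor.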
  spec (TensorSpec): specs of the first output tensor. Used when calling
  td_module.random() to generate random values in the target space.
+
+ Keyword Args:
  safe (bool, optional): if ``True``, the value of the sample is checked against the
  input spec. Out-of-domain sampling can occur because of exploration policies
  or numerical under/overflow issues. As for the :obj:`spec` argument, this
  check will only occur for the distribution sample, but not the other tensors
  returned by the input module. If the sample is out of bounds, it is
  projected back onto the desired space using the `TensorSpec.project` method.
  Default is ``False``.
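Review note: the `safe` behaviour described above (check the sample against the spec, project it back if it falls outside) might look like the sketch below. `ToyBoundedSpec` and `finalize_sample` are hypothetical names for this note, and projection is assumed to be a simple clamp onto a scalar interval; the docstring only states that `TensorSpec.project` is used, so other spec types may project differently.

```python
from dataclasses import dataclass


@dataclass
class ToyBoundedSpec:
    """Hypothetical scalar stand-in for a bounded spec."""
    low: float
    high: float

    def is_in(self, value):
        return self.low <= value <= self.high

    def project(self, value):
        # Assumption: projecting onto an interval means clamping to its bounds.
        return min(max(value, self.low), self.high)


def finalize_sample(sample, spec, safe=False):
    """Project out-of-domain samples back into the spec only when `safe` is set."""
    if safe and not spec.is_in(sample):
        return spec.project(sample)
    return sample
```

With `safe=False` (the documented default) an out-of-domain sample passes through unchanged, which is the cheaper path when the distribution is already known to respect the bounds.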
- default_interaction_type (tensordict.nn.InteractionType, optional): default method to be used to retrieve
- the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or ``InteractionType.RANDOM``
distribution sample will be written in the tensordict with the key
- `'sample_log_prob'`. Default is ``False``.
- log_prob_key (NestedKey, optional): key where to write the log_prob if return_log_prob = True.
- Defaults to `"action_log_prob"`.
- cache_dist (bool, optional): EXPERIMENTAL: if ``True``, the parameters of the
+ `log_prob_key`. Default is ``False``.
+ log_prob_keys (List[NestedKey], optional): keys where to write the log_prob if ``return_log_prob=True``.
+ Defaults to `'<sample_key_name>_log_prob'`, where `<sample_key_name>` is each of the :attr:`out_keys`.
+
+ .. note:: This is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``False``.
+
+ log_prob_key (NestedKey, optional): key where to write the log_prob if ``return_log_prob=True``.
+ Defaults to `'sample_log_prob'` when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to `True`
+ or `'<sample_key_name>_log_prob'` otherwise.
+
+ .. note:: When there is more than one sample, this is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``True``.
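Review note: the default naming rule for log-probability keys stated in the added lines can be summarised with a small sketch. `default_log_prob_keys` is a hypothetical helper written for this note, it handles string keys only (nested keys are not covered), and the `aggregate` flag is assumed to mirror the value returned by `composite_lp_aggregate`.

```python
def default_log_prob_keys(out_keys, aggregate):
    """Return the default log-prob key(s) implied by the docstring.

    aggregate=True  -> a single 'sample_log_prob' entry
    aggregate=False -> one '<sample_key_name>_log_prob' entry per output key
    """
    if aggregate:
        return ["sample_log_prob"]
    return [f"{key}_log_prob" for key in out_keys]
```

Under this reading, a module with `out_keys=["action"]` writes to `"action_log_prob"` when aggregation is off and to `"sample_log_prob"` when it is on.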