4 files changed, +20 −19 lines changed
@@ -18,6 +18,7 @@
 from torch.nn import Parameter
 
 from torchrl._utils import RL_WARNINGS
+from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.utils import Buffer
 from torchrl.objectives.utils import ValueEstimators
 from torchrl.objectives.value import ValueEstimatorBase
@@ -53,11 +54,18 @@ class LossModule(nn.Module):
     pointer. This class attribute indicates which value estimator will be
     used if none other is specified.
     The value estimator can be changed using the :meth:`~.make_value_estimator` method.
+
+    By default, the forward method is always decorated with a
+    :class:`torchrl.envs.ExplorationType.MODE` exploration context.
     """
 
     default_value_estimator: ValueEstimators = None
     SEP = "_sep_"
 
+    def __new__(cls, *args, **kwargs):
+        cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
+        return super().__new__(cls)
+
     def __init__(self):
         super().__init__()
         self._param_maps = {}
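To show concretely what the new `__new__` hook above does, here is a minimal, self-contained sketch of the same pattern: the base class re-binds the subclass's `forward` at instantiation time so that every call runs inside a fixed context. `ExplorationMode` and `Loss` are toy stand-ins for illustration only; the actual pieces in the diff are `set_exploration_type`, `ExplorationType.MODE` and `LossModule`.

```python
# Minimal sketch (toy names, not TorchRL classes): a base class whose __new__
# wraps the subclass's forward so it always executes inside a context, mirroring
# cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward) above.
from contextlib import ContextDecorator


class ExplorationMode(ContextDecorator):
    """Toy stand-in for set_exploration_type(ExplorationType.MODE)."""

    active = False  # tracks whether we are currently inside the context

    def __enter__(self):
        ExplorationMode.active = True
        return self

    def __exit__(self, *exc):
        ExplorationMode.active = False
        return False


class Loss:
    def __new__(cls, *args, **kwargs):
        # Re-bind the subclass's forward to a decorated version, so callers
        # never need an explicit `with` block around it.
        cls.forward = ExplorationMode()(cls.forward)
        return super().__new__(cls)

    def forward(self):
        # True when called, because the decorator opened the context first.
        return ExplorationMode.active


print(Loss().forward())  # True
```

This is why the three files below can drop their local `with set_exploration_type(ExplorationType.MODE):` blocks: the base class now establishes that context around every `forward` call.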
@@ -14,8 +14,6 @@
 from tensordict.nn import make_functional, repopulate_module, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
 
-from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -162,10 +160,9 @@ def _loss_value(
             batch_size=self.target_actor_network_params.batch_size,
             device=self.target_actor_network_params.device,
         )
-        with set_exploration_type(ExplorationType.MODE):
-            target_value = self.value_estimator.value_estimate(
-                tensordict, target_params=target_params
-            ).squeeze(-1)
+        target_value = self.value_estimator.value_estimate(
+            tensordict, target_params=target_params
+        ).squeeze(-1)
 
         # td_error = pred_val - target_value
         loss_value = distance_loss(
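This hunk, and the matching ones in the next two files, remove the same local `with set_exploration_type(ExplorationType.MODE):` block. The change relies on the decorator installed by `LossModule.__new__` in the first file: a context opened around the outer `forward` call is already active in the helpers it invokes, so re-opening it locally is redundant. A rough sketch of that point, using toy names (`ExploreMode`, `compute_value`, `forward`) that are not TorchRL APIs:

```python
# Rough sketch of why the local `with` blocks become redundant: a
# context-manager decorator on the outer method already covers every call
# made inside it.
from contextlib import ContextDecorator


class ExploreMode(ContextDecorator):
    depth = 0  # counts nested activations of the context

    def __enter__(self):
        ExploreMode.depth += 1
        return self

    def __exit__(self, *exc):
        ExploreMode.depth -= 1
        return False


def compute_value():
    # Stands in for value_estimator.value_estimate / get_dist / the vmapped
    # actor call: it only needs the context to be active, not to open it.
    assert ExploreMode.depth > 0, "must run inside the context"
    return 1.0


@ExploreMode()  # analogous to what LossModule.__new__ applies to forward
def forward():
    # No local `with ExploreMode():` needed; the decorator opened it already.
    return compute_value()


print(forward())  # 1.0
```

The `assert` stands in for behavior that only holds while the exploration context is active, such as deterministic (mode) sampling from a probabilistic actor.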
@@ -10,8 +10,6 @@
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import Tensor
 
-from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -170,11 +168,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
         # KL loss
-        with set_exploration_type(ExplorationType.MODE):
-            dist = self.actor_network.get_dist(
-                tensordict,
-                params=self.actor_network_params,
-            )
+        dist = self.actor_network.get_dist(
+            tensordict,
+            params=self.actor_network_params,
+        )
 
         log_prob = dist.log_prob(tensordict["action"])
 
@@ -9,7 +9,7 @@
 
 from tensordict.tensordict import TensorDict, TensorDictBase
 
-from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+from torchrl.envs.utils import step_mdp
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_WARNING,
@@ -127,11 +127,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
         tensordict_actor = tensordict_actor.contiguous()
 
-        with set_exploration_type(ExplorationType.MODE):
-            actor_output_td = vmap(self.actor_network)(
-                tensordict_actor,
-                actor_params,
-            )
+        actor_output_td = vmap(self.actor_network)(
+            tensordict_actor,
+            actor_params,
+        )
         # add noise to target policy
         noise = torch.normal(
             mean=torch.zeros(actor_output_td[1]["action"].shape),