
Commit 714d645

[BugFix] Set exploration mode to MODE in all losses by default (#1123)
1 parent 257f152 commit 714d645

File tree

4 files changed: +20 −19 lines changed


torchrl/objectives/common.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
 from torch.nn import Parameter

 from torchrl._utils import RL_WARNINGS
+from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.utils import Buffer
 from torchrl.objectives.utils import ValueEstimators
 from torchrl.objectives.value import ValueEstimatorBase
@@ -53,11 +54,18 @@ class LossModule(nn.Module):
     pointer. This class attribute indicates which value estimator will be
     used if none other is specified.
     The value estimator can be changed using the :meth:`~.make_value_estimator` method.
+
+    By default, the forward method is always decorated with a
+    gh :class:`torchrl.envs.ExplorationType.MODE`
     """

     default_value_estimator: ValueEstimators = None
     SEP = "_sep_"

+    def __new__(cls, *args, **kwargs):
+        cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
+        return super().__new__(cls)
+
     def __init__(self):
         super().__init__()
         self._param_maps = {}
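
The diff above is easiest to read next to a toy reproduction of the pattern it introduces. Below is a minimal, standalone sketch in plain Python (no torchrl required) of how a __new__ hook can wrap whichever forward a concrete subclass defines, which is what LossModule.__new__ now does with set_exploration_type(ExplorationType.MODE). All names here (fake_mode, decorate_with_mode, BaseLoss, ToyLoss) are hypothetical and only illustrate the mechanism, not torchrl's actual classes:

import contextlib
import functools


@contextlib.contextmanager
def fake_mode():
    # Stand-in for set_exploration_type(ExplorationType.MODE); it only
    # prints so the wrapping is visible when the sketch is run.
    print("entering deterministic (MODE) exploration")
    yield
    print("leaving deterministic (MODE) exploration")


def decorate_with_mode(fn):
    # Analogue of set_exploration_type(ExplorationType.MODE)(fn): run
    # the wrapped function inside the exploration context.
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        with fake_mode():
            return fn(*args, **kwargs)
    return wrapped


class BaseLoss:
    def __new__(cls, *args, **kwargs):
        # Wrap the *subclass's* forward at instantiation time, mirroring
        # what LossModule.__new__ does in the diff above.
        cls.forward = decorate_with_mode(cls.forward)
        return super().__new__(cls)


class ToyLoss(BaseLoss):
    def forward(self, x):
        # Helpers called from here (the analogue of _loss_value or
        # _loss_actor) also execute inside the context opened by the
        # decorator.
        return 2 * x


print(ToyLoss().forward(3))  # prints the enter/leave messages, then 6

Wrapping forward once in the base class is what lets the three loss files below drop their explicit with set_exploration_type(ExplorationType.MODE): blocks.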

torchrl/objectives/ddpg.py

Lines changed: 3 additions & 6 deletions
@@ -14,8 +14,6 @@
 from tensordict.nn import make_functional, repopulate_module, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase

-from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -162,10 +160,9 @@ def _loss_value(
             batch_size=self.target_actor_network_params.batch_size,
             device=self.target_actor_network_params.device,
         )
-        with set_exploration_type(ExplorationType.MODE):
-            target_value = self.value_estimator.value_estimate(
-                tensordict, target_params=target_params
-            ).squeeze(-1)
+        target_value = self.value_estimator.value_estimate(
+            tensordict, target_params=target_params
+        ).squeeze(-1)

         # td_error = pred_val - target_value
         loss_value = distance_loss(

torchrl/objectives/iql.py

Lines changed: 4 additions & 7 deletions
@@ -10,8 +10,6 @@
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import Tensor

-from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -170,11 +168,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

     def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
         # KL loss
-        with set_exploration_type(ExplorationType.MODE):
-            dist = self.actor_network.get_dist(
-                tensordict,
-                params=self.actor_network_params,
-            )
+        dist = self.actor_network.get_dist(
+            tensordict,
+            params=self.actor_network_params,
+        )

         log_prob = dist.log_prob(tensordict["action"])
torchrl/objectives/td3.py

Lines changed: 5 additions & 6 deletions
@@ -9,7 +9,7 @@

 from tensordict.tensordict import TensorDict, TensorDictBase

-from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+from torchrl.envs.utils import step_mdp
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_WARNING,
@@ -127,11 +127,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
         tensordict_actor = tensordict_actor.contiguous()

-        with set_exploration_type(ExplorationType.MODE):
-            actor_output_td = vmap(self.actor_network)(
-                tensordict_actor,
-                actor_params,
-            )
+        actor_output_td = vmap(self.actor_network)(
+            tensordict_actor,
+            actor_params,
+        )
         # add noise to target policy
         noise = torch.normal(
             mean=torch.zeros(actor_output_td[1]["action"].shape),
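
Taken together, the ddpg.py, iql.py and td3.py edits are the flip side of the common.py change: set_exploration_type works both as a context manager (the form removed from the three losses) and as a decorator (the form LossModule.__new__ now applies once to forward), so the per-loss with-blocks became redundant. A short sketch under that assumption, with my_forward a hypothetical function standing in for a loss module's forward:

from torchrl.envs.utils import ExplorationType, set_exploration_type

# Context-manager form -- what DDPGLoss, IQLLoss and TD3Loss used to
# open explicitly around their value_estimate / get_dist / vmap calls:
with set_exploration_type(ExplorationType.MODE):
    pass  # the wrapped computation would go here

# Decorator form -- what LossModule.__new__ now applies to forward:
@set_exploration_type(ExplorationType.MODE)
def my_forward(tensordict):
    # Everything executed here, including nested helpers such as
    # _loss_value or _loss_actor, runs under ExplorationType.MODE.
    return tensordict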
