
Commit 25dca6e

[BugFix]: gradient propagation in advantage estimates (#322)
1 parent 7c43aef · commit 25dca6e

3 files changed: +33 −16 lines

test/test_cost.py

Lines changed: 3 additions & 7 deletions
@@ -1573,13 +1573,9 @@ def test_hold_out():
     assert y.requires_grad
 
     # exception
-    with pytest.raises(
-        RuntimeError,
-        match="hold_out_net requires the network parameter set to be non-empty.",
-    ):
-        net = torch.nn.Sequential()
-        with hold_out_net(net):
-            pass
+    net = torch.nn.Sequential()
+    with hold_out_net(net):
+        pass
 
 
 @pytest.mark.parametrize("mode", ["hard", "soft"])
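The updated test encodes the new contract: wrapping a parameter-free module in hold_out_net is now a no-op instead of raising. A minimal sketch of that usage, assuming hold_out_net is importable from torchrl.objectives.costs.utils (the module patched below):

import torch
from torchrl.objectives.costs.utils import hold_out_net

net = torch.nn.Sequential()  # module without parameters
with hold_out_net(net):      # used to raise RuntimeError, now simply does nothing
    pass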

torchrl/objectives/costs/utils.py

Lines changed: 1 addition & 3 deletions
@@ -263,9 +263,7 @@ def __init__(self, network: nn.Module) -> None:
         try:
             self.p_example = next(network.parameters())
         except StopIteration:
-            raise RuntimeError(
-                "hold_out_net requires the network parameter set to be " "non-empty."
-            )
+            self.p_example = torch.tensor([])
         self._prev_state = []
 
     def __enter__(self) -> None:
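With the exception removed, p_example falls back to an empty tensor, giving the context manager a valid placeholder even when the network has no parameters. For orientation, here is an illustrative sketch, not the torchrl implementation, of a hold-out style context manager that freezes a network's parameters on entry and restores their requires_grad flags on exit; with an empty parameter set both steps are no-ops:

import torch
from torch import nn

class HoldOutSketch:
    """Illustrative only: freeze parameters on enter, restore them on exit."""

    def __init__(self, network: nn.Module) -> None:
        self.network = network
        self._prev_state = []

    def __enter__(self) -> None:
        self._prev_state = [p.requires_grad for p in self.network.parameters()]
        for p in self.network.parameters():
            p.requires_grad_(False)

    def __exit__(self, *exc) -> None:
        for p, prev in zip(self.network.parameters(), self._prev_state):
            p.requires_grad_(prev)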

torchrl/objectives/returns/advantages.py

Lines changed: 29 additions & 6 deletions
@@ -29,6 +29,8 @@
 
 __all__ = ["GAE", "TDLambdaEstimate", "TDEstimate"]
 
+from ..costs.utils import hold_out_net
+
 
 class TDEstimate:
     """Temporal Difference estimate of advantage function.
@@ -89,7 +91,7 @@ def __call__(
         if self.average_rewards:
             reward = reward - reward.mean()
             reward = reward / reward.std().clamp_min(1e-4)
-            tensordict.set_(
+            tensordict.set(
                 "reward", reward
             )  # we must update the rewards if they are used later in the code
 
@@ -106,12 +108,19 @@ def __call__(
         self.value_network(tensordict, **kwargs)
         value = tensordict.get(self.value_key)
 
-        with torch.set_grad_enabled(False):
+        with hold_out_net(self.value_network):
+            # we may still need to pass gradient, but we don't want to assign grads to
+            # value net params
             step_td = step_tensordict(tensordict)
             if target_params is not None:
+                # we assume that target parameters are not differentiable
                 kwargs["params"] = target_params
+            elif "params" in kwargs:
+                kwargs["params"] = [param.detach() for param in kwargs["params"]]
             if target_buffers is not None:
                 kwargs["buffers"] = target_buffers
+            elif "buffers" in kwargs:
+                kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]]
             self.value_network(step_td, **kwargs)
             next_value = step_td.get(self.value_key)
 
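This hunk is the core of the fix: under torch.set_grad_enabled(False) the bootstrap (next-state) value carried no graph at all, so no gradient could ever flow through it, whereas hold_out_net keeps the graph alive and only keeps the value-network parameters out of it. A toy illustration of the difference, using a plain nn.Linear rather than the torchrl classes:

import torch
from torch import nn

value_net = nn.Linear(4, 1)
obs = torch.randn(4, requires_grad=True)  # stands in for a differentiable input

with torch.set_grad_enabled(False):
    v = value_net(obs)
assert not v.requires_grad  # no graph: nothing upstream can receive a gradient

for p in value_net.parameters():  # hold_out_net-style freeze
    p.requires_grad_(False)
v = value_net(obs)
v.sum().backward()
assert obs.grad is not None                                 # gradient reaches the input
assert all(p.grad is None for p in value_net.parameters())  # but not the value net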

@@ -190,7 +199,7 @@ def __call__(
         if self.average_rewards:
             reward = reward - reward.mean()
             reward = reward / reward.std().clamp_min(1e-4)
-            tensordict.set_(
+            tensordict.set(
                 "reward", reward
             )  # we must update the rewards if they are used later in the code
 
@@ -209,12 +218,19 @@ def __call__(
         self.value_network(tensordict, **kwargs)
         value = tensordict.get(self.value_key)
 
-        with torch.set_grad_enabled(False):
+        with hold_out_net(self.value_network):
+            # we may still need to pass gradient, but we don't want to assign grads to
+            # value net params
             step_td = step_tensordict(tensordict)
             if target_params is not None:
+                # we assume that target parameters are not differentiable
                 kwargs["params"] = target_params
+            elif "params" in kwargs:
+                kwargs["params"] = [param.detach() for param in kwargs["params"]]
             if target_buffers is not None:
                 kwargs["buffers"] = target_buffers
+            elif "buffers" in kwargs:
+                kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]]
             self.value_network(step_td, **kwargs)
             next_value = step_td.get(self.value_key)
 
@@ -295,7 +311,7 @@ def __call__(
         if self.average_rewards:
             reward = reward - reward.mean()
             reward = reward / reward.std().clamp_min(1e-4)
-            tensordict.set_(
+            tensordict.set(
                 "reward", reward
             )  # we must update the rewards if they are used later in the code
 
@@ -312,12 +328,19 @@ def __call__(
         self.value_network(tensordict, **kwargs)
         value = tensordict.get("state_value")
 
-        with torch.set_grad_enabled(False):
+        with hold_out_net(self.value_network):
+            # we may still need to pass gradient, but we don't want to assign grads to
+            # value net params
             step_td = step_tensordict(tensordict)
             if target_params is not None:
+                # we assume that target parameters are not differentiable
                 kwargs["params"] = target_params
+            elif "params" in kwargs:
+                kwargs["params"] = [param.detach() for param in kwargs["params"]]
             if target_buffers is not None:
                 kwargs["buffers"] = target_buffers
+            elif "buffers" in kwargs:
+                kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]]
             self.value_network(step_td, **kwargs)
             next_value = step_td.get("state_value")
             done = tensordict.get("done")
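The new elif branches cover the functional call path: when explicit params/buffers are passed and no target parameters are given, they are detached before the next-state value is computed, so the bootstrap term cannot deposit gradients in the value network while the rest of the graph stays differentiable. A minimal illustration of why detach() has this effect (toy tensors, not torchrl code):

import torch

weight = torch.nn.Parameter(torch.randn(3))  # stands in for a value-net parameter
obs = torch.randn(3, requires_grad=True)     # stands in for a differentiable input

next_value = (obs * weight.detach()).sum()   # detached parameter, live input
next_value.backward()
assert obs.grad is not None  # gradient still flows through the input
assert weight.grad is None   # but never into the detached parameter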
