@@ -29,6 +29,8 @@
 
 __all__ = ["GAE", "TDLambdaEstimate", "TDEstimate"]
 
+from ..costs.utils import hold_out_net
+
 
 class TDEstimate:
     """Temporal Difference estimate of advantage function.
@@ -89,7 +91,7 @@ def __call__(
         if self.average_rewards:
             reward = reward - reward.mean()
             reward = reward / reward.std().clamp_min(1e-4)
-            tensordict.set_(
+            tensordict.set(
                 "reward", reward
             )  # we must update the rewards if they are used later in the code
 
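The `set_` to `set` change matters because `TensorDict.set_` copies the value into the tensor already stored under the key (an in-place write, and the key must exist), whereas `set` rebinds the key to the freshly computed tensor, presumably the safer choice for a brand-new normalized reward. A small illustration, assuming the standalone `tensordict` package (at the time of this patch the class lived in torchrl itself):

```python
import torch
from tensordict import TensorDict  # assumed import for this sketch

td = TensorDict({"reward": torch.randn(4, 1)}, batch_size=[4])

reward = td.get("reward")
reward = reward - reward.mean()
reward = reward / reward.std().clamp_min(1e-4)

# `set` rebinds the key to the new, normalized tensor; later reads of
# "reward" see the normalized values
td.set("reward", reward)

# `set_` would instead copy the values into the tensor already stored
# under the key, i.e. an in-place write:
# td.set_("reward", reward)
```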
@@ -106,12 +108,19 @@ def __call__(
             self.value_network(tensordict, **kwargs)
         value = tensordict.get(self.value_key)
 
-        with torch.set_grad_enabled(False):
+        with hold_out_net(self.value_network):
+            # we may still need to pass gradient, but we don't want to assign grads to
+            # value net params
             step_td = step_tensordict(tensordict)
             if target_params is not None:
+                # we assume that target parameters are not differentiable
                 kwargs["params"] = target_params
+            elif "params" in kwargs:
+                kwargs["params"] = [param.detach() for param in kwargs["params"]]
             if target_buffers is not None:
                 kwargs["buffers"] = target_buffers
+            elif "buffers" in kwargs:
+                kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]]
             self.value_network(step_td, **kwargs)
         next_value = step_td.get(self.value_key)
 
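The new `elif` branches handle the functional-call path: when the caller passes explicit `params`/`buffers` and no target parameters are supplied, detached copies are used for the next-state value, so the bootstrap term cannot backpropagate into the value network. A toy demonstration of the effect (the surrounding names are illustrative, not part of the patch):

```python
import torch

# stand-ins for functional parameters forwarded via kwargs["params"]
params = [torch.randn(4, 4, requires_grad=True)]
kwargs = {"params": params}
target_params = None  # no dedicated target network in this scenario

if target_params is not None:
    kwargs["params"] = target_params
elif "params" in kwargs:
    # same fallback as in the patch: detach so that the bootstrapped
    # next-state value contributes no gradient to the value network
    kwargs["params"] = [param.detach() for param in kwargs["params"]]

assert not kwargs["params"][0].requires_grad  # detached copies...
assert params[0].requires_grad                # ...originals untouched
```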
@@ -190,7 +199,7 @@ def __call__(
         if self.average_rewards:
             reward = reward - reward.mean()
             reward = reward / reward.std().clamp_min(1e-4)
-            tensordict.set_(
+            tensordict.set(
                 "reward", reward
             )  # we must update the rewards if they are used later in the code
 
@@ -209,12 +218,19 @@ def __call__(
             self.value_network(tensordict, **kwargs)
         value = tensordict.get(self.value_key)
 
-        with torch.set_grad_enabled(False):
+        with hold_out_net(self.value_network):
+            # we may still need to pass gradient, but we don't want to assign grads to
+            # value net params
             step_td = step_tensordict(tensordict)
             if target_params is not None:
+                # we assume that target parameters are not differentiable
                 kwargs["params"] = target_params
+            elif "params" in kwargs:
+                kwargs["params"] = [param.detach() for param in kwargs["params"]]
             if target_buffers is not None:
                 kwargs["buffers"] = target_buffers
+            elif "buffers" in kwargs:
+                kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]]
             self.value_network(step_td, **kwargs)
         next_value = step_td.get(self.value_key)
 
@@ -295,7 +311,7 @@ def __call__(
         if self.average_rewards:
             reward = reward - reward.mean()
             reward = reward / reward.std().clamp_min(1e-4)
-            tensordict.set_(
+            tensordict.set(
                 "reward", reward
             )  # we must update the rewards if they are used later in the code
 
@@ -312,12 +328,19 @@ def __call__(
             self.value_network(tensordict, **kwargs)
         value = tensordict.get("state_value")
 
-        with torch.set_grad_enabled(False):
+        with hold_out_net(self.value_network):
+            # we may still need to pass gradient, but we don't want to assign grads to
+            # value net params
             step_td = step_tensordict(tensordict)
             if target_params is not None:
+                # we assume that target parameters are not differentiable
                 kwargs["params"] = target_params
+            elif "params" in kwargs:
+                kwargs["params"] = [param.detach() for param in kwargs["params"]]
             if target_buffers is not None:
                 kwargs["buffers"] = target_buffers
+            elif "buffers" in kwargs:
+                kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]]
             self.value_network(step_td, **kwargs)
         next_value = step_td.get("state_value")
         done = tensordict.get("done")
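Taken together, these hunks make all three estimators (TDEstimate, TDLambdaEstimate, GAE) bootstrap against a held-out value function. Sketching the simplest TD(0) case to show why that matters (the symbols and the 0.99 discount below are assumptions, not taken from the patch): the advantage is r + gamma * V(s') - V(s), and V(s') must carry no gradient, so the loss differentiates only through V(s):

```python
import torch

gamma = 0.99  # assumed discount factor

value = torch.randn(4, 1, requires_grad=True)  # V(s): gradients flow here
next_value = torch.randn(4, 1)                 # V(s'): from the held-out net
reward = torch.randn(4, 1)
done = torch.zeros(4, 1, dtype=torch.bool)

# TD(0) advantage: r + gamma * V(s') * (1 - done) - V(s); because
# next_value is detached, backward only differentiates through V(s)
advantage = reward + gamma * next_value * (~done).float() - value
advantage.sum().backward()
assert value.grad is not None
```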