27
27
"EGreedyWrapper" ,
28
28
"EGreedyModule" ,
29
29
"AdditiveGaussianModule" ,
30
- "AdditiveGaussianWrapper" ,
31
30
"OrnsteinUhlenbeckProcessModule" ,
32
31
"OrnsteinUhlenbeckProcessWrapper" ,
33
32
]
@@ -220,42 +219,7 @@ def __init__(
220
219
221
220
222
221
class AdditiveGaussianWrapper (TensorDictModuleWrapper ):
223
- """Additive Gaussian PO wrapper.
224
-
225
- Args:
226
- policy (TensorDictModule): a policy.
227
-
228
- Keyword Args:
229
- sigma_init (scalar, optional): initial epsilon value.
230
- default: 1.0
231
- sigma_end (scalar, optional): final epsilon value.
232
- default: 0.1
233
- annealing_num_steps (int, optional): number of steps it will take for
234
- sigma to reach the :obj:`sigma_end` value.
235
- mean (:obj:`float`, optional): mean of each output element’s normal distribution.
236
- std (:obj:`float`, optional): standard deviation of each output element’s normal distribution.
237
- action_key (NestedKey, optional): if the policy module has more than one output key,
238
- its output spec will be of type Composite. One needs to know where to
239
- find the action spec.
240
- Default is "action".
241
- spec (TensorSpec, optional): if provided, the sampled action will be
242
- projected onto the valid action space once explored. If not provided,
243
- the exploration wrapper will attempt to recover it from the policy.
244
- safe (boolean, optional): if False, the TensorSpec can be None. If it
245
- is set to False but the spec is passed, the projection will still
246
- happen.
247
- Default is True.
248
- device (torch.device, optional): the device where the buffers have to be stored.
249
-
250
- .. note::
251
- Once an environment has been wrapped in :class:`AdditiveGaussianWrapper`, it is
252
- crucial to incorporate a call to :meth:`~.step` in the training loop
253
- to update the exploration factor.
254
- Since it is not easy to capture this omission no warning or exception
255
- will be raised if this is ommitted!
256
-
257
-
258
- """
222
+ """[Deprecated] Additive Gaussian PO wrapper."""
259
223
260
224
def __init__ (
261
225
self ,
@@ -271,105 +235,9 @@ def __init__(
271
235
safe : Optional [bool ] = True ,
272
236
device : torch .device | None = None ,
273
237
):
274
- warnings .warn (
275
- "AdditiveGaussianWrapper is deprecated and will be removed "
276
- "in v0.7. Please use torchrl.modules.AdditiveGaussianModule "
277
- "instead." ,
278
- category = DeprecationWarning ,
279
- )
280
- if device is None and hasattr (policy , "parameters" ):
281
- for p in policy .parameters ():
282
- device = p .device
283
- break
284
-
285
- super ().__init__ (policy )
286
- if sigma_end > sigma_init :
287
- raise RuntimeError ("sigma should decrease over time or be constant" )
288
- self .register_buffer ("sigma_init" , torch .tensor (sigma_init , device = device ))
289
- self .register_buffer ("sigma_end" , torch .tensor (sigma_end , device = device ))
290
- self .annealing_num_steps = annealing_num_steps
291
- self .register_buffer ("mean" , torch .tensor (mean , device = device ))
292
- self .register_buffer ("std" , torch .tensor (std , device = device ))
293
- self .register_buffer (
294
- "sigma" , torch .tensor (sigma_init , dtype = torch .float32 , device = device )
238
+ raise RuntimeError (
239
+ "This module has been removed from TorchRL. Please use torchrl.modules.AdditiveGaussianModule instead."
295
240
)
296
- self .action_key = action_key
297
- self .out_keys = list (self .td_module .out_keys )
298
- if action_key not in self .out_keys :
299
- raise RuntimeError (
300
- f"The action key { action_key } was not found in the td_module out_keys { self .td_module .out_keys } ."
301
- )
302
- if spec is not None :
303
- if not isinstance (spec , Composite ) and len (self .out_keys ) >= 1 :
304
- spec = Composite ({action_key : spec }, shape = spec .shape [:- 1 ])
305
- self ._spec = spec
306
- elif hasattr (self .td_module , "_spec" ):
307
- self ._spec = self .td_module ._spec .clone ()
308
- if action_key not in self ._spec .keys (True , True ):
309
- self ._spec [action_key ] = None
310
- elif hasattr (self .td_module , "spec" ):
311
- self ._spec = self .td_module .spec .clone ()
312
- if action_key not in self ._spec .keys (True , True ):
313
- self ._spec [action_key ] = None
314
- else :
315
- self ._spec = Composite ({key : None for key in policy .out_keys })
316
-
317
- self .safe = safe
318
- if self .safe :
319
- self .register_forward_hook (_forward_hook_safe_action )
320
-
321
- @property
322
- def spec (self ):
323
- return self ._spec
324
-
325
- def step (self , frames : int = 1 ) -> None :
326
- """A step of sigma decay.
327
-
328
- After self.annealing_num_steps, this function is a no-op.
329
-
330
- Args:
331
- frames (int): number of frames since last step.
332
-
333
- """
334
- for _ in range (frames ):
335
- self .sigma .data .copy_ (
336
- torch .maximum (
337
- self .sigma_end ,
338
- self .sigma
339
- - (self .sigma_init - self .sigma_end ) / self .annealing_num_steps ,
340
- ),
341
- )
342
-
343
- def _add_noise (self , action : torch .Tensor ) -> torch .Tensor :
344
- sigma = self .sigma
345
- mean = self .mean .expand (action .shape )
346
- std = self .std .expand (action .shape )
347
- if not mean .dtype .is_floating_point :
348
- mean = mean .to (torch .get_default_dtype ())
349
- if not std .dtype .is_floating_point :
350
- std = std .to (torch .get_default_dtype ())
351
- noise = torch .normal (mean = mean , std = std )
352
- if noise .device != action .device :
353
- noise = noise .to (action .device )
354
- action = action + noise * sigma
355
- spec = self .spec
356
- spec = spec [self .action_key ]
357
- if spec is not None :
358
- action = spec .project (action )
359
- elif self .safe :
360
- raise RuntimeError (
361
- "the action spec must be provided to AdditiveGaussianWrapper unless "
362
- "the `safe` keyword argument is turned off at initialization."
363
- )
364
- return action
365
-
366
- def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
367
- tensordict = self .td_module .forward (tensordict )
368
- if exploration_type () is ExplorationType .RANDOM or exploration_type () is None :
369
- out = tensordict .get (self .action_key )
370
- out = self ._add_noise (out )
371
- tensordict .set (self .action_key , out )
372
- return tensordict
373
241
374
242
375
243
class AdditiveGaussianModule (TensorDictModuleBase ):
0 commit comments