@@ -370,94 +370,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
 
 class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
-    r"""Ornstein-Uhlenbeck exploration policy wrapper.
-
-    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf.
-
-    The OU exploration is to be used with continuous control policies and introduces an auto-correlated exploration
-    noise. This enables a sort of 'structured' exploration.
-
-    Noise equation:
-
-    .. math::
-        noise_t = noise_{t-1} + \theta * (\mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W
-
-    Sigma equation:
-
-    .. math::
-        \sigma_t = \max(\sigma^{min}, \sigma - (\sigma - \sigma^{min}) / n^{\text{steps annealing}} * n^{\text{steps}})
-
-    To keep track of the steps and noise from sample to sample, the :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys
-    will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset,
-    indicating that a new trajectory is being collected. If not, and if the same tensordict is used for consecutive
-    trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of
-    zeroing the tensordict at reset time.
-
-    .. note::
-        Once an environment has been wrapped in :class:`OrnsteinUhlenbeckProcessWrapper`, it is
-        crucial to incorporate a call to :meth:`~.step` in the training loop
-        to update the exploration factor.
-        Since it is not easy to capture this omission, no warning or exception
-        will be raised if it is omitted!
-
-    Args:
-        policy (TensorDictModule): a policy
-
-    Keyword Args:
-        eps_init (scalar): initial epsilon value, determining the amount of noise to be added.
-            default: 1.0
-        eps_end (scalar): final epsilon value, determining the amount of noise to be added.
-            default: 0.1
-        annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value.
-            default: 1000
-        theta (scalar): theta factor in the noise equation.
-            default: 0.15
-        mu (scalar): OU average (mu in the noise equation).
-            default: 0.0
-        sigma (scalar): sigma value in the sigma equation.
-            default: 0.2
-        dt (scalar): dt in the noise equation.
-            default: 0.01
-        x0 (Tensor, ndarray, optional): initial value of the process.
-            default: 0.0
-        sigma_min (number, optional): sigma_min in the sigma equation.
-            default: None
-        n_steps_annealing (int): number of steps for the sigma annealing.
-            default: 1000
-        action_key (NestedKey, optional): key of the action to be modified.
-            default: "action"
-        is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
-            default: "is_init"
-        spec (TensorSpec, optional): if provided, the sampled action will be
-            projected onto the valid action space once explored. If not provided,
-            the exploration wrapper will attempt to recover it from the policy.
-        safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
-            given the :obj:`TensorSpec.project` heuristic.
-            default: True
-        device (torch.device, optional): the device where the buffers have to be stored.
-
-    Examples:
-        >>> import torch
-        >>> from tensordict import TensorDict
-        >>> from torchrl.data import Bounded
-        >>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor
-        >>> torch.manual_seed(0)
-        >>> spec = Bounded(-1, 1, torch.Size([4]))
-        >>> module = torch.nn.Linear(4, 4, bias=False)
-        >>> policy = Actor(module=module, spec=spec)
-        >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy)
-        >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
-        >>> print(explorative_policy(td))
-        TensorDict(
-            fields={
-                _ou_prev_noise: Tensor(torch.Size([10, 4]), dtype=torch.float32),
-                _ou_steps: Tensor(torch.Size([10, 1]), dtype=torch.int64),
-                action: Tensor(torch.Size([10, 4]), dtype=torch.float32),
-                observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
-            batch_size=torch.Size([10]),
-            device=None,
-            is_shared=False)
-    """
+    """[Deprecated] Ornstein-Uhlenbeck exploration policy wrapper."""
 
     def __init__(
         self,
@@ -480,119 +393,9 @@ def __init__(
         key: Optional[NestedKey] = None,
         device: torch.device | None = None,
     ):
-        warnings.warn(
-            "OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed "
-            "in v0.7. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule "
-            "instead.",
-            category=DeprecationWarning,
-        )
-        if device is None and hasattr(policy, "parameters"):
-            for p in policy.parameters():
-                device = p.device
-                break
-        if key is not None:
-            action_key = key
-            warnings.warn(
-                f"the 'key' keyword argument of {type(self)} has been renamed 'action_key'. The 'key' entry will be deprecated soon."
-            )
-        super().__init__(policy)
-        self.ou = _OrnsteinUhlenbeckProcess(
-            theta=theta,
-            mu=mu,
-            sigma=sigma,
-            dt=dt,
-            x0=x0,
-            sigma_min=sigma_min,
-            n_steps_annealing=n_steps_annealing,
-            key=action_key,
-            device=device,
-        )
-        self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
-        self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
-        if self.eps_end > self.eps_init:
-            raise ValueError(
-                "eps should decrease over time or be constant, "
-                f"got eps_init={eps_init} and eps_end={eps_end}"
-            )
-        self.annealing_num_steps = annealing_num_steps
-        self.register_buffer(
-            "eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
+        raise RuntimeError(
+            "OrnsteinUhlenbeckProcessWrapper has been removed. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule instead."
         )
-        self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
-        self.is_init_key = is_init_key
-        noise_key = self.ou.noise_key
-        steps_key = self.ou.steps_key
-
-        if spec is not None:
-            if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
-                spec = Composite({action_key: spec}, shape=spec.shape[:-1])
-            self._spec = spec
-        elif hasattr(self.td_module, "_spec"):
-            self._spec = self.td_module._spec.clone()
-            if action_key not in self._spec.keys(True, True):
-                self._spec[action_key] = None
-        elif hasattr(self.td_module, "spec"):
-            self._spec = self.td_module.spec.clone()
-            if action_key not in self._spec.keys(True, True):
-                self._spec[action_key] = None
-        else:
-            self._spec = Composite({key: None for key in policy.out_keys})
-        ou_specs = {
-            noise_key: None,
-            steps_key: None,
-        }
-        self._spec.update(ou_specs)
-        if len(set(self.out_keys)) != len(self.out_keys):
-            raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
-        self.safe = safe
-        if self.safe:
-            self.register_forward_hook(_forward_hook_safe_action)
-
-    @property
-    def spec(self):
-        return self._spec
-
-    def step(self, frames: int = 1) -> None:
-        """Updates the eps noise factor.
-
-        Args:
-            frames (int): number of frames of the current batch (corresponding to the number of updates to be made).
-
-        """
-        for _ in range(frames):
-            if self.annealing_num_steps > 0:
-                self.eps.data.copy_(
-                    torch.maximum(
-                        self.eps_end,
-                        (
-                            self.eps
-                            - (self.eps_init - self.eps_end) / self.annealing_num_steps
-                        ),
-                    )
-                )
-            else:
-                raise ValueError(
-                    f"{self.__class__.__name__}.step() called when "
-                    f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive "
-                    f"number of frames."
-                )
-
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict = super().forward(tensordict)
-        if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
-            is_init = tensordict.get(self.is_init_key, None)
-            if is_init is None:
-                warnings.warn(
-                    f"The tensordict passed to {self.__class__.__name__} appears to be "
-                    f"missing the '{self.is_init_key}' entry. This entry is used to "
-                    f"reset the noise at the beginning of a trajectory, without it "
-                    f"the behavior of this exploration method is undefined. "
-                    f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
-                    f"To create a '{self.is_init_key}' entry, simply append a torchrl.envs.InitTracker "
-                    f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
-                )
-            tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init)
-        return tensordict
 
 
 class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
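Note: since the constructor above now only raises a RuntimeError pointing users at OrnsteinUhlenbeckProcessModule, a short migration sketch may help. The snippet below mirrors the Examples block removed in this diff, swapping the wrapper for the module-based API; it assumes that OrnsteinUhlenbeckProcessModule takes the action spec as its first argument and is composed with the policy via tensordict.nn.TensorDictSequential, as in the torchrl documentation. It is a sketch, not part of this change.

    # Migration sketch (assumed API, not part of this diff): the OU noise is now a
    # standalone module appended after the policy instead of a wrapper around it.
    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictSequential
    from torchrl.data import Bounded
    from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

    torch.manual_seed(0)
    spec = Bounded(-1, 1, torch.Size([4]))
    module = torch.nn.Linear(4, 4, bias=False)
    policy = Actor(module=module, spec=spec)
    # Compose the deterministic policy with the exploration module.
    ou = OrnsteinUhlenbeckProcessModule(spec=spec)
    explorative_policy = TensorDictSequential(policy, ou)
    td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
    print(explorative_policy(td))
    # As with the removed wrapper, ou.step() should still be called in the
    # training loop to anneal the exploration factor.

As with the wrapper, the noise-tracking entries are written into the tensordict, so the tensordict is still expected to be zeroed at reset (e.g. by the collector classes or an InitTracker transform).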