File tree Expand file tree Collapse file tree 3 files changed +24
-4
lines changed Expand file tree Collapse file tree 3 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -311,11 +311,12 @@ Utils
311
311
:toctree: generated/
312
312
:template: rl_template_noinherit.rst
313
313
314
+ HardUpdate
315
+ SoftUpdate
316
+ ValueEstimators
317
+ default_value_kwargs
314
318
distance_loss
319
+ group_optimizers
315
320
hold_out_net
316
321
hold_out_params
317
322
next_state_value
318
- SoftUpdate
319
- HardUpdate
320
- ValueEstimators
321
- default_value_kwargs
Original file line number Diff line number Diff line change 23
23
from .utils import (
24
24
default_value_kwargs ,
25
25
distance_loss ,
26
+ group_optimizers ,
26
27
HardUpdate ,
27
28
hold_out_net ,
28
29
hold_out_params ,
Original file line number Diff line number Diff line change @@ -590,3 +590,21 @@ def _clip_value_loss(
590
590
# Chose the most pessimistic value prediction between clipped and non-clipped
591
591
loss_value = torch .max (loss_value , loss_value_clipped )
592
592
return loss_value , clip_fraction
593
+
594
+
595
+ def group_optimizers (* optimizers : torch .optim .Optimizer ) -> torch .optim .Optimizer :
596
+ """Groups multiple optimizers into a single one.
597
+
598
+ All optimizers are expected to have the same type.
599
+ """
600
+ cls = None
601
+ params = []
602
+ for optimizer in optimizers :
603
+ if optimizer is None :
604
+ continue
605
+ if cls is None :
606
+ cls = type (optimizer )
607
+ if cls is not type (optimizer ):
608
+ raise ValueError ("Cannot group optimizers of different type." )
609
+ params .extend (optimizer .param_groups )
610
+ return cls (params )
You can’t perform that action at this time.
0 commit comments