Commit 7829bd3

[Minor,Feature] group_optimizers

ghstack-source-id: 81a94ed
Pull Request resolved: #2577

Author: Vincent Moens
Parent: 7bc84d1

3 files changed: 24 additions, 4 deletions

docs/source/reference/objectives.rst (5 additions, 4 deletions)

@@ -311,11 +311,12 @@ Utils
     :toctree: generated/
     :template: rl_template_noinherit.rst

+    HardUpdate
+    SoftUpdate
+    ValueEstimators
+    default_value_kwargs
     distance_loss
+    group_optimizers
     hold_out_net
     hold_out_params
     next_state_value
-    SoftUpdate
-    HardUpdate
-    ValueEstimators
-    default_value_kwargs

torchrl/objectives/__init__.py (1 addition, 0 deletions)

@@ -23,6 +23,7 @@
 from .utils import (
     default_value_kwargs,
     distance_loss,
+    group_optimizers,
     HardUpdate,
     hold_out_net,
     hold_out_params,

torchrl/objectives/utils.py (18 additions, 0 deletions)

@@ -590,3 +590,21 @@ def _clip_value_loss(
     # Chose the most pessimistic value prediction between clipped and non-clipped
     loss_value = torch.max(loss_value, loss_value_clipped)
     return loss_value, clip_fraction
+
+
+def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
+    """Groups multiple optimizers into a single one.
+
+    All optimizers are expected to have the same type.
+    """
+    cls = None
+    params = []
+    for optimizer in optimizers:
+        if optimizer is None:
+            continue
+        if cls is None:
+            cls = type(optimizer)
+        if cls is not type(optimizer):
+            raise ValueError("Cannot group optimizers of different type.")
+        params.extend(optimizer.param_groups)
+    return cls(params)
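
For context, a minimal usage sketch of the new helper. The import path follows this commit's file layout (torchrl.objectives.utils); the actor/critic modules and learning rates are illustrative assumptions, not part of the commit.

# Usage sketch (assumed import path per this commit; modules and
# hyperparameters below are illustrative, not from the commit).
import torch
from torch import nn
from torchrl.objectives.utils import group_optimizers

actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# Two optimizers of the same type, possibly with different hyperparameters.
opt_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

# One Adam whose param_groups are the union of both originals'.
opt = group_optimizers(opt_actor, opt_critic)

loss = actor(torch.randn(8, 4)).sum() + critic(torch.randn(8, 4)).sum()
loss.backward()
opt.step()       # a single call updates both modules, each with its own lr
opt.zero_grad()

Note that the grouped optimizer is constructed fresh from the collected param_groups, so per-group hyperparameters are preserved but any internal state of the original optimizers (e.g. Adam's momentum buffers) is not carried over; passing optimizers of different types raises a ValueError.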
