Commit 607ebc5

[Refactor] Rename Recorder and LogReward (#2616)
1 parent d22266d commit 607ebc5

8 files changed (+94, -38 lines)


docs/source/reference/trainers.rst

Lines changed: 4 additions & 4 deletions
@@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
   constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such.
 
 - **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
-  some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
-  logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
+  some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
+  logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
   data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
   should be displayed on the progression bar printed on the training log.
 
@@ -174,9 +174,9 @@ Trainer and hooks
     BatchSubSampler
     ClearCudaCache
     CountFramesLog
-    LogReward
+    LogScalar
     OptimizerHook
-    Recorder
+    LogValidationReward
     ReplayBufferTrainer
     RewardNormalizer
     SelectKeys

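The contract described above is all a logging hook needs to satisfy: it receives the collected batch as a ``TensorDict`` and returns a dictionary (or ``None``), with the optional ``"log_pbar"`` entry deciding whether the value shows up on the progress bar. A minimal sketch of a custom hook following that contract (``log_reward_std`` is a hypothetical example; ``trainer`` is assumed to be built elsewhere):

from tensordict import TensorDict


def log_reward_std(batch: TensorDict) -> dict:
    # Hypothetical logging hook: derive one scalar from the batch and return it
    # in a dict; "log_pbar" asks the trainer to show the value on the progress bar.
    rewards = batch.get(("next", "reward"))
    return {"reward_std": rewards.std().item(), "log_pbar": False}


# Registration on an existing trainer, exactly like the built-in hooks:
# trainer.register_op("pre_steps_log", log_reward_std)
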
sota-implementations/redq/utils.py

Lines changed: 5 additions & 5 deletions
@@ -81,8 +81,8 @@
     BatchSubSampler,
     ClearCudaCache,
     CountFramesLog,
-    LogReward,
-    Recorder,
+    LogScalar,
+    LogValidationReward,
     ReplayBufferTrainer,
     RewardNormalizer,
     Trainer,
@@ -331,7 +331,7 @@ def make_trainer(
 
     if recorder is not None:
         # create recorder object
-        recorder_obj = Recorder(
+        recorder_obj = LogValidationReward(
             record_frames=cfg.logger.record_frames,
             frame_skip=cfg.env.frame_skip,
             policy_exploration=policy_exploration,
@@ -347,7 +347,7 @@ def make_trainer(
         # call recorder - could be removed
         recorder_obj(None)
         # create explorative recorder - could be optional
-        recorder_obj_explore = Recorder(
+        recorder_obj_explore = LogValidationReward(
             record_frames=cfg.logger.record_frames,
             frame_skip=cfg.env.frame_skip,
             policy_exploration=policy_exploration,
@@ -369,7 +369,7 @@ def make_trainer(
         "post_steps", UpdateWeights(collector, update_weights_interval=1)
     )
 
-    trainer.register_op("pre_steps_log", LogReward())
+    trainer.register_op("pre_steps_log", LogScalar())
     trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))
 
     return trainer

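For orientation, a hedged sketch of what the renamed validation hook amounts to in a script like this one. The constructor arguments mirror those passed inside ``make_trainer`` above; ``cfg``, ``policy_exploration``, the evaluation env and the ``trainer`` are assumed to come from the surrounding script, and the ``record_interval`` config field name is an assumption:

from torchrl.trainers import LogValidationReward

# Sketch only: cfg, policy_exploration, eval_env and trainer are assumed to
# exist in the surrounding script, as in make_trainer above.
recorder_obj = LogValidationReward(
    record_interval=cfg.logger.record_interval,  # assumed config field
    record_frames=cfg.logger.record_frames,
    frame_skip=cfg.env.frame_skip,
    policy_exploration=policy_exploration,
    environment=eval_env,
    log_pbar=True,
)
recorder_obj.register(trainer)
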
test/test_trainer.py

Lines changed: 13 additions & 14 deletions
@@ -35,14 +35,14 @@
     TensorDictReplayBuffer,
 )
 from torchrl.envs.libs.gym import _has_gym
-from torchrl.trainers import Recorder, Trainer
+from torchrl.trainers import LogValidationReward, Trainer
 from torchrl.trainers.helpers import transformed_env_constructor
 from torchrl.trainers.trainers import (
     _has_tqdm,
     _has_ts,
     BatchSubSampler,
     CountFramesLog,
-    LogReward,
+    LogScalar,
     mask_batch,
     OptimizerHook,
     ReplayBufferTrainer,
@@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar):
         trainer = mocking_trainer()
         trainer.collected_frames = 0
 
-        log_reward = LogReward(logname, log_pbar=pbar)
+        log_reward = LogScalar(logname, log_pbar=pbar)
         trainer.register_op("pre_steps_log", log_reward)
         td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
         trainer._pre_steps_log_hook(td)
@@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar):
         trainer = mocking_trainer()
         trainer.collected_frames = 0
 
-        log_reward = LogReward(logname, log_pbar=pbar)
+        log_reward = LogScalar(logname, log_pbar=pbar)
         log_reward.register(trainer)
         td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
         trainer._pre_steps_log_hook(td)
@@ -873,7 +873,7 @@ def test_recorder(self, N=8):
             logger=logger,
         )()
 
-        recorder = Recorder(
+        recorder = LogValidationReward(
             record_frames=args.record_frames,
             frame_skip=args.frame_skip,
             policy_exploration=None,
@@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8):
         os.environ["CKPT_BACKEND"] = backend
         state_dict_has_been_called = [False]
         load_state_dict_has_been_called = [False]
-        Recorder.state_dict, Recorder_state_dict = _fun_checker(
-            Recorder.state_dict, state_dict_has_been_called
+        LogValidationReward.state_dict, Recorder_state_dict = _fun_checker(
+            LogValidationReward.state_dict, state_dict_has_been_called
+        )
+        (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker(
+            LogValidationReward.load_state_dict, load_state_dict_has_been_called
         )
-        (
-            Recorder.load_state_dict,
-            Recorder_load_state_dict,
-        ) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called)
 
         args = self._get_args()
 
@@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname):
             )()
             environment.rollout(2)
 
-            recorder = Recorder(
+            recorder = LogValidationReward(
                 record_frames=args.record_frames,
                 frame_skip=args.frame_skip,
                 policy_exploration=None,
@@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname):
         assert recorder2._count == 8
         assert state_dict_has_been_called[0]
         assert load_state_dict_has_been_called[0]
-        Recorder.state_dict = Recorder_state_dict
-        Recorder.load_state_dict = Recorder_load_state_dict
+        LogValidationReward.state_dict = Recorder_state_dict
+        LogValidationReward.load_state_dict = Recorder_load_state_dict
 
 
 def test_updateweights():

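Outside the test harness, the same behaviour can be exercised directly, since the hook is a plain callable: build a batch containing a reward entry and call the hook. A small sketch (the exact returned values are what the tests above compare against the logger output):

import torch
from tensordict import TensorDict
from torchrl.trainers import LogScalar

log_hook = LogScalar(logname="r_training", log_pbar=False)
# Default reward key is ("next", "reward"); here every reward is 1.
td = TensorDict({("next", "reward"): torch.ones(3, 1)}, batch_size=[3])
out = log_hook(td)
print(out)  # expected to look like {"r_training": 1.0, "log_pbar": False}
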
torchrl/trainers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
     ClearCudaCache,
     CountFramesLog,
     LogReward,
+    LogScalar,
+    LogValidationReward,
     mask_batch,
     OptimizerHook,
     Recorder,

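Both the old and the new names are exported here, so existing imports keep working while the deprecated classes warn at construction time. A short check of the behaviour the shims promise (assumes a build that contains this commit):

import warnings

from torchrl.trainers import LogReward, LogScalar

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hook = LogReward()  # deprecated alias, scheduled for removal in v0.9
assert isinstance(hook, LogScalar)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
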
torchrl/trainers/helpers/trainers.py

Lines changed: 5 additions & 5 deletions
@@ -25,8 +25,8 @@
     BatchSubSampler,
     ClearCudaCache,
     CountFramesLog,
-    LogReward,
-    Recorder,
+    LogScalar,
+    LogValidationReward,
     ReplayBufferTrainer,
     RewardNormalizer,
     SelectKeys,
@@ -259,7 +259,7 @@ def make_trainer(
 
     if recorder is not None:
         # create recorder object
-        recorder_obj = Recorder(
+        recorder_obj = LogValidationReward(
             record_frames=cfg.record_frames,
             frame_skip=cfg.frame_skip,
             policy_exploration=policy_exploration,
@@ -275,7 +275,7 @@ def make_trainer(
         # call recorder - could be removed
         recorder_obj(None)
         # create explorative recorder - could be optional
-        recorder_obj_explore = Recorder(
+        recorder_obj_explore = LogValidationReward(
             record_frames=cfg.record_frames,
             frame_skip=cfg.frame_skip,
             policy_exploration=policy_exploration,
@@ -297,7 +297,7 @@ def make_trainer(
         "post_steps", UpdateWeights(collector, update_weights_interval=1)
     )
 
-    trainer.register_op("pre_steps_log", LogReward())
+    trainer.register_op("pre_steps_log", LogScalar())
     trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))
 
     return trainer

torchrl/trainers/trainers.py

Lines changed: 58 additions & 3 deletions
@@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs):
         torch.cuda.empty_cache()
 
 
-class LogReward(TrainerHookBase):
+class LogScalar(TrainerHookBase):
     """Reward logger hook.
 
     Args:
@@ -833,7 +833,7 @@ class LogReward(TrainerHookBase):
             in the input batch. Defaults to ``("next", "reward")``
 
     Examples:
-        >>> log_reward = LogReward(("next", "reward"))
+        >>> log_reward = LogScalar(("next", "reward"))
         >>> trainer.register_op("pre_steps_log", log_reward)
 
     """
@@ -870,6 +870,23 @@ def register(self, trainer: Trainer, name: str = "log_reward"):
         trainer.register_module(name, self)
 
 
+class LogReward(LogScalar):
+    """Deprecated class. Use LogScalar instead."""
+
+    def __init__(
+        self,
+        logname="r_training",
+        log_pbar: bool = False,
+        reward_key: Union[str, tuple] = None,
+    ):
+        warnings.warn(
+            "The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key)
+
+
 class RewardNormalizer(TrainerHookBase):
     """Reward normalizer hook.
 
@@ -1127,7 +1144,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"):
         trainer.register_module(name, self)
 
 
-class Recorder(TrainerHookBase):
+class LogValidationReward(TrainerHookBase):
     """Recorder hook for :class:`~torchrl.trainers.Trainer`.
 
     Args:
@@ -1264,6 +1281,44 @@ def register(self, trainer: Trainer, name: str = "recorder"):
         )
 
 
+class Recorder(LogValidationReward):
+    """Deprecated class. Use LogValidationReward instead."""
+
+    def __init__(
+        self,
+        *,
+        record_interval: int,
+        record_frames: int,
+        frame_skip: int = 1,
+        policy_exploration: TensorDictModule,
+        environment: EnvBase = None,
+        exploration_type: ExplorationType = ExplorationType.RANDOM,
+        log_keys: Optional[List[Union[str, Tuple[str]]]] = None,
+        out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None,
+        suffix: Optional[str] = None,
+        log_pbar: bool = False,
+        recorder: EnvBase = None,
+    ) -> None:
+        warnings.warn(
+            "The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(
+            record_interval=record_interval,
+            record_frames=record_frames,
+            frame_skip=frame_skip,
+            policy_exploration=policy_exploration,
+            environment=environment,
+            exploration_type=exploration_type,
+            log_keys=log_keys,
+            out_keys=out_keys,
+            suffix=suffix,
+            log_pbar=log_pbar,
+            recorder=recorder,
+        )
+
+
 class UpdateWeights(TrainerHookBase):
     """A collector weights update hook class.
 

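Both compatibility classes added above use the same warn-and-delegate pattern: the old name subclasses the new one, emits a ``DeprecationWarning`` once at construction time, then forwards every argument. A distilled, standalone sketch of that pattern (the names here are illustrative, not torchrl APIs):

import warnings


class NewHook:
    def __init__(self, logname: str = "r_training", log_pbar: bool = False):
        self.logname = logname
        self.log_pbar = log_pbar


class OldHook(NewHook):
    """Deprecated alias of ``NewHook``."""

    def __init__(self, logname: str = "r_training", log_pbar: bool = False):
        warnings.warn(
            "'OldHook' is deprecated, use 'NewHook' instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not at this shim
        )
        super().__init__(logname=logname, log_pbar=log_pbar)
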
tutorials/sphinx-tutorials/coding_ddpg.py

Lines changed: 3 additions & 3 deletions
@@ -883,12 +883,12 @@ def make_ddpg_actor(
 #
 # As the training data is obtained using some exploration strategy, the true
 # performance of our algorithm needs to be assessed in deterministic mode. We
-# do this using a dedicated class, ``Recorder``, which executes the policy in
+# do this using a dedicated class, ``LogValidationReward``, which executes the policy in
 # the environment at a given frequency and returns some statistics obtained
 # from these simulations.
 #
 # The following helper function builds this object:
-from torchrl.trainers import Recorder
+from torchrl.trainers import LogValidationReward
 
 
 def make_recorder(actor_model_explore, transform_state_dict, record_interval):
@@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
     )  # must be instantiated to load the state dict
     environment.transform[2].load_state_dict(transform_state_dict)
 
-    recorder_obj = Recorder(
+    recorder_obj = LogValidationReward(
         record_frames=1000,
         policy_exploration=actor_model_explore,
         environment=environment,

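A hedged sketch of the helper the tutorial describes, with the renamed class. ``make_env`` and ``actor`` are stand-ins for objects the tutorial builds earlier, and the transform state-dict loading step is elided:

from torchrl.trainers import LogValidationReward


def make_recorder(actor, record_interval, make_env):
    # Evaluation copy of the environment; the tutorial also loads the
    # observation-normalization state dict into it at this point.
    environment = make_env()
    return LogValidationReward(
        record_frames=1000,  # frames per validation rollout
        record_interval=record_interval,
        policy_exploration=actor,
        environment=environment,
        log_pbar=True,
    )
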
tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 4 additions & 4 deletions
@@ -140,8 +140,8 @@
 from torchrl.objectives import DQNLoss, SoftUpdate
 from torchrl.record.loggers.csv import CSVLogger
 from torchrl.trainers import (
-    LogReward,
-    Recorder,
+    LogScalar,
+    LogValidationReward,
     ReplayBufferTrainer,
     Trainer,
     UpdateWeights,
@@ -666,7 +666,7 @@ def get_loss_module(actor, gamma):
 buffer_hook.register(trainer)
 weight_updater = UpdateWeights(collector, update_weights_interval=1)
 weight_updater.register(trainer)
-recorder = Recorder(
+recorder = LogValidationReward(
     record_interval=100,  # log every 100 optimization steps
     record_frames=1000,  # maximum number of frames in the record
     frame_skip=1,
@@ -704,7 +704,7 @@ def get_loss_module(actor, gamma):
 # This will be reflected by the `total_rewards` value displayed in the
 # progress bar.
 #
-log_reward = LogReward(log_pbar=True)
+log_reward = LogScalar(log_pbar=True)
 log_reward.register(trainer)
 
 ###############################################################################
