
Commit 8ab85c6

[Refactor] Move loggers to torchrl.record (#854)
1 parent 8efbb26 commit 8ab85c6
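
Downstream code now imports the logger helpers and classes from torchrl.record.loggers instead of torchrl.trainers.loggers. A minimal usage sketch of the new import path, assuming a TorchRL build that includes this commit; the model name, experiment name, and log directory below are hypothetical:

from torchrl.record.loggers import generate_exp_name, get_logger

# Build a unique experiment name and instantiate a CSV logger from the new namespace.
exp_name = generate_exp_name("DDPG", "demo_run")
logger = get_logger(logger_type="csv", logger_name="./logs", experiment_name=exp_name)
# log_scalar is part of the Logger interface (not shown in this diff).
logger.log_scalar("reward", 0.0, step=0)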

File tree

23 files changed: 36 additions & 33 deletions

docs/source/reference/trainers.rst

Lines changed: 1 addition & 1 deletion

@@ -193,7 +193,7 @@ Utils
 Loggers
 -------
 
-.. currentmodule:: torchrl.trainers.loggers
+.. currentmodule:: torchrl.record.loggers
 
 .. autosummary::
     :toctree: generated/

examples/a2c/a2c.py

Lines changed: 1 addition & 1 deletion

@@ -11,6 +11,7 @@
 from torchrl.envs.transforms import RewardScaling
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.objectives.value import TDEstimate
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_onpolicy,
     OnPolicyCollectorConfig,
@@ -27,7 +28,6 @@
 from torchrl.trainers.helpers.losses import A2CLossConfig, make_a2c_loss
 from torchrl.trainers.helpers.models import A2CModelConfig, make_a2c_model
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 
 config_fields = [
     (config_field.name, config_field.type, config_field)

examples/ddpg/ddpg.py

Lines changed: 1 addition & 1 deletion

@@ -13,6 +13,7 @@
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
 from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
     OffPolicyCollectorConfig,
@@ -30,7 +31,6 @@
 from torchrl.trainers.helpers.models import DDPGModelConfig, make_ddpg_actor
 from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 
 config_fields = [
     (config_field.name, config_field.type, config_field)

examples/dqn/dqn.py

Lines changed: 1 addition & 1 deletion

@@ -12,6 +12,7 @@
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.modules import EGreedyWrapper
 from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
     OffPolicyCollectorConfig,
@@ -29,7 +30,6 @@
 from torchrl.trainers.helpers.models import DiscreteModelConfig, make_dqn_actor
 from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 
 config_fields = [
     (config_field.name, config_field.type, config_field)

examples/dreamer/dreamer.py

Lines changed: 1 addition & 1 deletion

@@ -27,6 +27,7 @@
     DreamerModelLoss,
     DreamerValueLoss,
 )
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
     OffPolicyCollectorConfig,
@@ -40,7 +41,6 @@
 from torchrl.trainers.helpers.models import DreamerConfig, make_dreamer
 from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 from torchrl.trainers.trainers import Recorder, RewardNormalizer
 
 config_fields = [

examples/dreamer/dreamer_utils.py

Lines changed: 1 addition & 1 deletion

@@ -24,8 +24,8 @@
     TransformedEnv,
 )
 from torchrl.envs.transforms.transforms import FlattenObservation, TensorDictPrimer
+from torchrl.record.loggers import Logger
 from torchrl.record.recorder import VideoRecorder
-from torchrl.trainers.loggers import Logger
 
 __all__ = [
     "transformed_env_constructor",

examples/ppo/ppo.py

Lines changed: 1 addition & 1 deletion

@@ -13,6 +13,7 @@
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.objectives.value import GAE
 from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_onpolicy,
     OnPolicyCollectorConfig,
@@ -29,7 +30,6 @@
 from torchrl.trainers.helpers.losses import make_ppo_loss, PPOLossConfig
 from torchrl.trainers.helpers.models import make_ppo_model, PPOModelConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 
 config_fields = [
     (config_field.name, config_field.type, config_field)

examples/redq/redq.py

Lines changed: 1 addition & 1 deletion

@@ -15,6 +15,7 @@
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
 from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
     OffPolicyCollectorConfig,
@@ -32,7 +33,6 @@
 from torchrl.trainers.helpers.models import make_redq_model, REDQModelConfig
 from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 
 config_fields = [
     (config_field.name, config_field.type, config_field)

examples/sac/sac.py

Lines changed: 1 addition & 1 deletion

@@ -15,6 +15,7 @@
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
 from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
 from torchrl.trainers.helpers.collectors import (
     make_collector_offpolicy,
     OffPolicyCollectorConfig,
@@ -32,7 +33,6 @@
 from torchrl.trainers.helpers.models import make_sac_model, SACModelConfig
 from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
 
 config_fields = [
     (config_field.name, config_field.type, config_field)

examples/td3/td3.py

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@
 
 from torchrl.objectives import SoftUpdate
 from torchrl.objectives.td3 import TD3Loss
-from torchrl.trainers.loggers.utils import generate_exp_name, get_logger
+from torchrl.record.loggers import generate_exp_name, get_logger
 
 
 def env_maker(task, frame_skip=1, device="cpu", from_pixels=False):

test/test_loggers.py

Lines changed: 9 additions & 4 deletions

@@ -12,10 +12,15 @@
 
 import pytest
 import torch
-from torchrl.trainers.loggers.csv import CSVLogger
-from torchrl.trainers.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
-from torchrl.trainers.loggers.tensorboard import _has_tb, TensorboardLogger
-from torchrl.trainers.loggers.wandb import _has_wandb, WandbLogger
+from torchrl.record.loggers import (
+    CSVLogger,
+    MLFlowLogger,
+    TensorboardLogger,
+    WandbLogger,
+)
+from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv
+from torchrl.record.loggers.tensorboard import _has_tb
+from torchrl.record.loggers.wandb import _has_wandb
 
 if _has_tv:
     import torchvision

test/test_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 
 try:
     from tensorboard.backend.event_processing import event_accumulator
-    from torchrl.trainers.loggers.tensorboard import TensorboardLogger
+    from torchrl.record.loggers import TensorboardLogger
 
     _has_tb = True
 except ImportError:
File renamed without changes.
File renamed without changes.

torchrl/trainers/loggers/mlflow.py renamed to torchrl/record/loggers/mlflow.py

Lines changed: 7 additions & 9 deletions

@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-import warnings
 from tempfile import TemporaryDirectory
 from typing import Any, Dict, Optional
 
@@ -18,22 +17,21 @@
 
 from .common import Logger
 
-_has_mlflow = False
+MLFLOW_ERR = None
 try:
     import mlflow
 
     _has_mlflow = True
-except ImportError:
-    warnings.warn("mlflow could not be imported")
-    _has_omgaconf = False
+except ImportError as err:
+    _has_mlflow = False
+    MLFLOW_ERR = err
+
 try:
     from omegaconf import OmegaConf
 
     _has_omgaconf = True
 except ImportError:
-    warnings.warn(
-        "OmegaConf could not be imported. Cannot log hydra configs without OmegaConf"
-    )
+    _has_omgaconf = False
 
 
 class MLFlowLogger(Logger):
@@ -67,7 +65,7 @@ def _create_experiment(self) -> "mlflow.ActiveRun":
             mlflow.ActiveRun: The mlflow experiment object.
         """
         if not _has_mlflow:
-            raise ImportError("MLFlow is not installed")
+            raise ImportError("MLFlow is not installed") from MLFLOW_ERR
         self.id = mlflow.create_experiment(**self._mlflow_kwargs)
         return mlflow.start_run(experiment_id=self.id)
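
The refactored mlflow.py also changes how a missing optional dependency is reported: instead of warning at import time, the ImportError is stored and chained when an MLFlow run is actually requested. A generic sketch of that deferred-import pattern, using a hypothetical optional backend named somebackend:

SOMEBACKEND_ERR = None
try:
    import somebackend  # hypothetical optional dependency

    _has_somebackend = True
except ImportError as err:
    _has_somebackend = False
    SOMEBACKEND_ERR = err  # keep the original error so it can be chained later


def create_backend_run():
    # Fail only when the backend is actually needed, pointing back at the root cause.
    if not _has_somebackend:
        raise ImportError("somebackend is not installed") from SOMEBACKEND_ERR
    return somebackend.start_run()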

torchrl/trainers/loggers/utils.py renamed to torchrl/record/loggers/utils.py

Lines changed: 5 additions & 5 deletions

@@ -9,7 +9,7 @@
 import uuid
 from datetime import datetime
 
-from torchrl.trainers.loggers.common import Logger
+from torchrl.record.loggers.common import Logger
 
 
 def generate_exp_name(model_name: str, experiment_name: str) -> str:
@@ -37,22 +37,22 @@ def get_logger(
         kwargs (dict[str]): might contain either `wandb_kwargs` or `mlflow_kwargs`
     """
     if logger_type == "tensorboard":
-        from torchrl.trainers.loggers.tensorboard import TensorboardLogger
+        from torchrl.record.loggers.tensorboard import TensorboardLogger
 
         logger = TensorboardLogger(log_dir=logger_name, exp_name=experiment_name)
     elif logger_type == "csv":
-        from torchrl.trainers.loggers.csv import CSVLogger
+        from torchrl.record.loggers.csv import CSVLogger
 
         logger = CSVLogger(log_dir=logger_name, exp_name=experiment_name)
     elif logger_type == "wandb":
-        from torchrl.trainers.loggers.wandb import WandbLogger
+        from torchrl.record.loggers.wandb import WandbLogger
 
         wandb_kwargs = kwargs.get("wandb_kwargs", {})
         logger = WandbLogger(
             log_dir=logger_name, exp_name=experiment_name, **wandb_kwargs
         )
     elif logger_type == "mlflow":
-        from torchrl.trainers.loggers.mlflow import MLFlowLogger
+        from torchrl.record.loggers.mlflow import MLFlowLogger
 
         mlflow_kwargs = kwargs.get("mlflow_kwargs", {})
         logger = MLFlowLogger(
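
get_logger therefore stays the single entry point for every backend; only its import location changes. A hedged example of calling it with a backend-specific kwargs dict, assuming get_logger accepts extra keyword arguments as suggested by kwargs.get(...) in the body above (the wandb project name is hypothetical):

from torchrl.record.loggers import get_logger

# Backend-specific options are forwarded as a dict, e.g. wandb_kwargs or mlflow_kwargs.
logger = get_logger(
    logger_type="wandb",
    logger_name="./logs",
    experiment_name="demo_run",
    wandb_kwargs={"project": "my_project"},  # hypothetical project name
)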
File renamed without changes.

torchrl/record/recorder.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 from tensordict.tensordict import TensorDictBase
 
 from torchrl.envs.transforms import ObservationTransform, Transform
-from torchrl.trainers.loggers import Logger
+from torchrl.record.loggers import Logger
 
 
 class VideoRecorder(ObservationTransform):

torchrl/trainers/helpers/envs.py

Lines changed: 1 addition & 1 deletion

@@ -30,8 +30,8 @@
     VecNorm,
 )
 from torchrl.envs.transforms.transforms import FlattenObservation, gSDENoise
+from torchrl.record.loggers import Logger
 from torchrl.record.recorder import VideoRecorder
-from torchrl.trainers.loggers import Logger
 
 LIBS = {
     "gym": GymEnv,

torchrl/trainers/helpers/trainers.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from torchrl.modules import reset_noise, SafeModule
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import TargetNetUpdater
-from torchrl.trainers.loggers import Logger
+from torchrl.record.loggers import Logger
 from torchrl.trainers.trainers import (
     BatchSubSampler,
     ClearCudaCache,

torchrl/trainers/trainers.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import SafeModule
 from torchrl.objectives.common import LossModule
-from torchrl.trainers.loggers import Logger
+from torchrl.record.loggers import Logger
 
 try:
     from tqdm import tqdm
