Commit ee58306

[Refactor] A less verbose torchrl (#973)

1 parent: bea2474

This commit gates torchrl's informational print statements behind a VERBOSE flag, read once from the environment in torchrl/_utils.py, and upgrades two diagnostic prints to warnings.warn.
10 files changed: +58, -40 lines


torchrl/_utils.py (4 additions, 0 deletions)

@@ -1,13 +1,17 @@
 import collections
+
 import math
 import os
 import time
 import warnings
+from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module

 import numpy as np

+VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
+

 class timeit:
     """A dirty but easy to use decorator for profiling code."""

torchrl/collectors/collectors.py (2 additions, 2 deletions)

@@ -23,7 +23,7 @@
 from torch import multiprocessing as mp
 from torch.utils.data import IterableDataset

-from torchrl._utils import _check_for_faulty_process, prod
+from torchrl._utils import _check_for_faulty_process, prod, VERBOSE
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
@@ -1662,7 +1662,7 @@ def _main_async_collector(
     idx: int = 0,
     exploration_mode: str = DEFAULT_EXPLORATION_MODE,
     reset_when_done: bool = True,
-    verbose: bool = False,
+    verbose: bool = VERBOSE,
 ) -> None:
     pipe_parent.close()
     #  init variables that will be cleared when closing
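
One subtlety with `verbose: bool = VERBOSE`: Python evaluates default argument values once, at function-definition (i.e., import) time, so changing the environment variable after torchrl is imported has no effect on the default. A toy sketch (names are illustrative, not torchrl API):

```python
import os

os.environ["FLAG"] = "0"
FLAG = os.environ["FLAG"] == "1"  # read once, "at import time"

def collect(verbose: bool = FLAG):  # default is bound here, once
    if verbose:
        print("collecting verbosely")

os.environ["FLAG"] = "1"  # too late: the default was already captured
collect()                 # prints nothing
collect(verbose=True)     # an explicit argument still overrides it
```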

torchrl/data/replay_buffers/storages.py (19 additions, 14 deletions)

@@ -15,7 +15,7 @@
 from tensordict.prototype import is_tensorclass
 from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase

-from torchrl._utils import _CKPT_BACKEND
+from torchrl._utils import _CKPT_BACKEND, VERBOSE
 from torchrl.data.replay_buffers.utils import INT_CLASSES

 try:
@@ -224,7 +224,8 @@ def load_state_dict(self, state_dict):
         self._len = state_dict["_len"]

     def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
-        print("Creating a TensorStorage...")
+        if VERBOSE:
+            print("Creating a TensorStorage...")
         if isinstance(data, torch.Tensor):
             # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
             out = torch.empty(
@@ -353,16 +354,18 @@ def load_state_dict(self, state_dict):
         self._len = state_dict["_len"]

     def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
-        print("Creating a MemmapStorage...")
+        if VERBOSE:
+            print("Creating a MemmapStorage...")
         if isinstance(data, torch.Tensor):
             # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
             out = MemmapTensor(
                 self.max_size, *data.shape, device=self.device, dtype=data.dtype
             )
             filesize = os.path.getsize(out.filename) / 1024 / 1024
-            print(
-                f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
-            )
+            if VERBOSE:
+                print(
+                    f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
+                )
         elif is_tensorclass(data):
             out = (
                 data.clone()
@@ -374,12 +377,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
                 out.items(include_nested=True, leaves_only=True), key=str
             ):
                 filesize = os.path.getsize(tensor.filename) / 1024 / 1024
-                print(
-                    f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
-                )
+                if VERBOSE:
+                    print(
+                        f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
+                    )
         else:
-            # out = TensorDict({}, [self.max_size, *data.shape])
-            print("The storage is being created: ")
+            if VERBOSE:
+                print("The storage is being created: ")
             out = (
                 data.clone()
                 .expand(self.max_size, *data.shape)
@@ -390,9 +394,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
                 out.items(include_nested=True, leaves_only=True), key=str
             ):
                 filesize = os.path.getsize(tensor.filename) / 1024 / 1024
-                print(
-                    f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
-                )
+                if VERBOSE:
+                    print(
+                        f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
+                    )
         self._storage = out
         self.initialized = True
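
The same `if VERBOSE: print(...)` pair now appears at several call sites in this file. A hypothetical alternative (not part of this commit) would be a single gated helper along these lines:

```python
import os
from distutils.util import strtobool

VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))

def vprint(*args, **kwargs):
    """print() that stays silent unless the VERBOSE env flag is set."""
    if VERBOSE:
        print(*args, **kwargs)

vprint("Creating a MemmapStorage...")  # no output unless VERBOSE=1
```

One reason to prefer the inline guard the commit actually uses: it also skips formatting the f-string when verbosity is off, whereas `vprint(f"...")` builds the message before the call.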

torchrl/envs/libs/dm_control.py (4 additions, 1 deletion)

@@ -19,13 +19,16 @@
     UnboundedDiscreteTensorSpec,
 )

+from ..._utils import VERBOSE
+
 from ...data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
 from ..gym_like import GymLikeEnv

 if torch.has_cuda and torch.cuda.device_count() > 1:
     n = torch.cuda.device_count() - 1
     os.environ["EGL_DEVICE_ID"] = str(1 + (os.getpid() % n))
-    print("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"])
+    if VERBOSE:
+        print("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"])

 try:
torchrl/envs/vec_env.py (2 additions, 2 deletions)

@@ -21,7 +21,7 @@
 from tensordict import TensorDict
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
 from torch import multiprocessing as mp
-from torchrl._utils import _check_for_faulty_process
+from torchrl._utils import _check_for_faulty_process, VERBOSE
 from torchrl.data import (
     CompositeSpec,
     DiscreteTensorSpec,
@@ -126,7 +126,7 @@ class _BatchedEnv(EnvBase):

     """

-    _verbose: bool = False
+    _verbose: bool = VERBOSE
     _excluded_wrapped_keys = [
         "is_closed",
         "parent_channels",

torchrl/modules/models/exploration.py (2 additions, 2 deletions)

@@ -2,8 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
 import math
+import warnings
 from typing import Optional, Sequence, Union

 import torch
@@ -390,7 +390,7 @@ def forward(self, mu, state, _eps_gSDE):

         sigma = (sigma * state.unsqueeze(-2)).pow(2).sum(-1).clamp_min(1e-5).sqrt()
         if not torch.isfinite(sigma).all():
-            print("inf sigma")
+            warnings.warn("inf sigma")

         if self.transform is not None:
             action = self.transform(action)
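
Switching from `print` to `warnings.warn` makes the message controllable through the standard warning filters: it can be silenced or escalated to an exception, and by default it is shown only once per call site rather than on every loop iteration. A small sketch of what that enables:

```python
import warnings

# The message argument of filterwarnings is a regex matched against
# the start of the warning text.
warnings.filterwarnings("ignore", message="inf sigma")   # silence it
# warnings.filterwarnings("error", message="inf sigma")  # or make it raise

warnings.warn("inf sigma")  # suppressed by the filter above
```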

torchrl/modules/tensordict_module/common.py (1 addition, 1 deletion)

@@ -20,7 +20,7 @@

     _has_functorch = True
 except ImportError:
-    print(
+    warnings.warn(
         "failed to import functorch. TorchRL's features that do not require "
         "functional programming should work, but functionality and performance "
         "may be affected. Consider installing functorch and/or upgrating pytorch."

torchrl/trainers/helpers/envs.py (9 additions, 6 deletions)

@@ -8,6 +8,7 @@

 import torch

+from torchrl._utils import VERBOSE
 from torchrl.envs import ParallelEnv
 from torchrl.envs.common import EnvBase
 from torchrl.envs.env_creator import env_creator, EnvCreator
@@ -394,7 +395,8 @@ def get_stats_random_rollout(
         cfg=cfg, use_env_creator=False, stats={"loc": 0.0, "scale": 1.0}
     )()

-    print("computing state stats")
+    if VERBOSE:
+        print("computing state stats")
     if not hasattr(cfg, "init_env_steps"):
         raise AttributeError("init_env_steps missing from arguments.")

@@ -426,11 +428,12 @@
     m[s == 0] = 0.0
     s[s == 0] = 1.0

-    print(
-        f"stats computed for {val_stats.numel()} steps. Got: \n"
-        f"loc = {m}, \n"
-        f"scale = {s}"
-    )
+    if VERBOSE:
+        print(
+            f"stats computed for {val_stats.numel()} steps. Got: \n"
+            f"loc = {m}, \n"
+            f"scale = {s}"
+        )
     if not torch.isfinite(m).all():
         raise RuntimeError("non-finite values found in mean")
     if not torch.isfinite(s).all():
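
The surrounding context (`m[s == 0] = 0.0; s[s == 0] = 1.0`) guards observation normalization against constant dimensions, whose standard deviation is zero. A small standalone illustration:

```python
import torch

x = torch.tensor([[1.0, 5.0], [1.0, 7.0]])  # first feature is constant
m, s = x.mean(0), x.std(0)

# Without the guard, (x - m) / s would divide by zero on feature 0.
m[s == 0] = 0.0
s[s == 0] = 1.0

print((x - m) / s)  # the constant feature passes through unchanged
```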

torchrl/trainers/helpers/trainers.py (12 additions, 10 deletions)

@@ -12,6 +12,7 @@
 from torch import optim
 from torch.optim.lr_scheduler import CosineAnnealingLR

+from torchrl._utils import VERBOSE
 from torchrl.collectors.collectors import _DataCollector
 from torchrl.data import ReplayBuffer
 from torchrl.envs.common import EnvBase
@@ -168,16 +169,17 @@ def make_trainer(
     else:
         raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}")

-    print(
-        f"collector = {collector}; \n"
-        f"loss_module = {loss_module}; \n"
-        f"recorder = {recorder}; \n"
-        f"target_net_updater = {target_net_updater}; \n"
-        f"policy_exploration = {policy_exploration}; \n"
-        f"replay_buffer = {replay_buffer}; \n"
-        f"logger = {logger}; \n"
-        f"cfg = {cfg}; \n"
-    )
+    if VERBOSE:
+        print(
+            f"collector = {collector}; \n"
+            f"loss_module = {loss_module}; \n"
+            f"recorder = {recorder}; \n"
+            f"target_net_updater = {target_net_updater}; \n"
+            f"policy_exploration = {policy_exploration}; \n"
+            f"replay_buffer = {replay_buffer}; \n"
+            f"logger = {logger}; \n"
+            f"cfg = {cfg}; \n"
+        )

     if logger is not None:
         # log hyperparams

torchrl/trainers/trainers.py (3 additions, 2 deletions)

@@ -19,7 +19,7 @@
 from tensordict.utils import expand_right
 from torch import nn, optim

-from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict
+from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict, VERBOSE
 from torchrl.collectors.collectors import _DataCollector
 from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
 from torchrl.data.utils import DEVICE_TYPING
@@ -451,7 +451,8 @@ def __del__(self):
         self.collector.shutdown()

     def shutdown(self):
-        print("shutting down collector")
+        if VERBOSE:
+            print("shutting down collector")
         self.collector.shutdown()

     def optim_steps(self, batch: TensorDictBase) -> None:
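
End to end, the flag is read once when `torchrl._utils` is first imported, so it must be set before that import: either from the shell, e.g. `VERBOSE=1 python train.py` (script name illustrative), or in Python:

```python
import os

# Must happen before the first torchrl import, since torchrl._utils
# reads the variable at import time.
os.environ["VERBOSE"] = "1"

import torchrl  # noqa: E402  (imported after the env var on purpose)
```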
