Skip to content

Commit cd488b4

Browse files
sosmondvmoens
andauthored
[Feature] Remove wild imports in the library (#642)
* init * missing docstrings * Add explicit imports for subpackages * fix lint * Revert "fix lint" This reverts commit 040e156. * fix lint Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 49af5da commit cd488b4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+1130
-886
lines changed

build_tools/setup_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .extension import * # noqa
6+
from .extension import get_ext_modules, CMakeBuild # noqa

build_tools/setup_helpers/extension.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
from setuptools import Extension
1515
from setuptools.command.build_ext import build_ext
1616

17-
__all__ = [
18-
"get_ext_modules",
19-
"CMakeBuild",
20-
]
2117

2218
_THIS_DIR = Path(__file__).parent.resolve()
2319
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()

torchrl/collectors/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,9 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .collectors import *
6+
from .collectors import (
7+
SyncDataCollector,
8+
aSyncDataCollector,
9+
MultiaSyncDataCollector,
10+
MultiSyncDataCollector,
11+
)

torchrl/collectors/collectors.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,16 @@
2020
from torch import multiprocessing as mp
2121
from torch.utils.data import IterableDataset
2222

23+
from torchrl.envs.transforms import TransformedEnv
2324
from torchrl.envs.utils import set_exploration_mode, step_mdp
2425
from .._utils import _check_for_faulty_process, prod
25-
from ..modules.tensordict_module import ProbabilisticTensorDictModule, TensorDictModule
26-
from .utils import split_trajectories
27-
28-
__all__ = [
29-
"SyncDataCollector",
30-
"aSyncDataCollector",
31-
"MultiaSyncDataCollector",
32-
"MultiSyncDataCollector",
33-
]
34-
35-
from torchrl.envs.transforms import TransformedEnv
3626
from ..data import TensorSpec
3727
from ..data.tensordict.tensordict import TensorDictBase, TensorDict
3828
from ..data.utils import CloudpickleWrapper, DEVICE_TYPING
3929
from ..envs.common import EnvBase
4030
from ..envs.vec_env import _BatchedEnv
31+
from ..modules.tensordict_module import ProbabilisticTensorDictModule, TensorDictModule
32+
from .utils import split_trajectories
4133

4234
_TIMEOUT = 1.0
4335
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
@@ -47,6 +39,8 @@
4739

4840

4941
class RandomPolicy:
42+
"""A random policy for data collectors."""
43+
5044
def __init__(self, action_spec: TensorSpec):
5145
"""Random policy for a given action_spec.
5246
@@ -73,6 +67,7 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
7367

7468

7569
def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
70+
"""Maps the tensors to CPU through a nested dictionary."""
7671
return OrderedDict(
7772
**{
7873
k: recursive_map_to_cpu(item)

torchrl/data/__init__.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,38 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .postprocs import *
7-
from .replay_buffers import *
8-
from .tensor_specs import *
9-
from .tensordict import *
6+
from .postprocs import MultiStep
7+
from .replay_buffers import (
8+
ReplayBuffer,
9+
PrioritizedReplayBuffer,
10+
TensorDictReplayBuffer,
11+
TensorDictPrioritizedReplayBuffer,
12+
Storage,
13+
ListStorage,
14+
LazyMemmapStorage,
15+
LazyTensorStorage,
16+
)
17+
from .tensor_specs import (
18+
TensorSpec,
19+
BoundedTensorSpec,
20+
OneHotDiscreteTensorSpec,
21+
UnboundedContinuousTensorSpec,
22+
UnboundedDiscreteTensorSpec,
23+
NdBoundedTensorSpec,
24+
NdUnboundedContinuousTensorSpec,
25+
NdUnboundedDiscreteTensorSpec,
26+
BinaryDiscreteTensorSpec,
27+
MultOneHotDiscreteTensorSpec,
28+
DiscreteTensorSpec,
29+
CompositeSpec,
30+
DEVICE_TYPING,
31+
)
32+
from .tensordict import (
33+
MemmapTensor,
34+
set_transfer_ownership,
35+
TensorDict,
36+
SubTensorDict,
37+
merge_tensordicts,
38+
LazyStackedTensorDict,
39+
SavedTensorDict,
40+
)

torchrl/data/postprocs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .postprocs import *
6+
from .postprocs import MultiStep

torchrl/data/postprocs/postprocs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from torchrl.data.tensordict.tensordict import TensorDictBase
1515
from torchrl.data.utils import expand_as_right
1616

17-
__all__ = ["MultiStep"]
18-
1917

2018
def _conv1d_reward(
2119
reward: torch.Tensor, gammas: torch.Tensor, n_steps_max: int
@@ -76,7 +74,7 @@ def _get_steps_to_next_obs(nonterminal: torch.Tensor, n_steps_max: int) -> torch
7674
return steps_to_next_obs
7775

7876

79-
def select_and_repeat(
77+
def _select_and_repeat(
8078
tensor: torch.Tensor,
8179
terminal: torch.Tensor,
8280
post_terminal: torch.Tensor,
@@ -203,7 +201,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
203201
for key, item in selected_td.items():
204202
tensordict.set_(
205203
key,
206-
select_and_repeat(
204+
_select_and_repeat(
207205
item,
208206
terminal,
209207
post_terminal,

torchrl/data/replay_buffers/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .replay_buffers import *
7-
from .storages import *
6+
from .replay_buffers import (
7+
ReplayBuffer,
8+
PrioritizedReplayBuffer,
9+
TensorDictReplayBuffer,
10+
TensorDictPrioritizedReplayBuffer,
11+
)
12+
from .storages import Storage, ListStorage, LazyMemmapStorage, LazyTensorStorage

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@
3232
)
3333
from torchrl.data.utils import DEVICE_TYPING
3434

35-
__all__ = [
36-
"ReplayBuffer",
37-
"PrioritizedReplayBuffer",
38-
"TensorDictReplayBuffer",
39-
"TensorDictPrioritizedReplayBuffer",
40-
]
41-
4235

4336
def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
4437
"""Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together.
@@ -730,6 +723,13 @@ def sample(self, size: int, return_weight: bool = False) -> TensorDictBase:
730723

731724

732725
class InPlaceSampler:
726+
"""A sampler to write tennsordicts in-place.
727+
728+
To be used cautiously as this may lead to unexpected behaviour (i.e. tensordicts
729+
overwritten during execution).
730+
731+
"""
732+
733733
def __init__(self, device: Optional[DEVICE_TYPING] = None):
734734
self.out = None
735735
if device is None:

torchrl/data/replay_buffers/storages.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
except ImportError:
2424
_has_ts = False
2525

26-
__all__ = ["Storage", "ListStorage", "LazyMemmapStorage", "LazyTensorStorage"]
27-
2826

2927
class Storage:
3028
"""A Storage is the container of a replay buffer.

0 commit comments

Comments
 (0)