Skip to content

Commit 2defe3e

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 3bc83f9 + 2cdb5a3 commit 2defe3e

File tree

29 files changed

+642
-14
lines changed

29 files changed

+642
-14
lines changed

build_tools/setup_helpers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from .extension import CMakeBuild, get_ext_modules # noqa
7+
8+
__all__ = ["CMakeBuild", "get_ext_modules"]

test/test_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@
6060
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
6161
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
6262
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
63+
from torchrl.envs.transforms.rlhf import as_padded_tensor
6364
from torchrl.envs.transforms.transforms import (
6465
AutoResetEnv,
6566
AutoResetTransform,
6667
Tokenizer,
6768
Transform,
6869
UnsqueezeTransform,
6970
)
70-
from torchrl.envs.transforms.rlhf import as_padded_tensor
7171
from torchrl.envs.utils import (
7272
_StepMDP,
7373
_terminated_or_truncated,

torchrl/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@
4646
)
4747

4848

49-
import torchrl.collectors
50-
import torchrl.data
51-
import torchrl.envs
52-
import torchrl.modules
53-
import torchrl.objectives
54-
import torchrl.trainers
5549
from torchrl._utils import (
5650
auto_unwrap_transformed_env,
5751
compile_with_warmup,
@@ -107,3 +101,11 @@ def _inv(self):
107101

108102

109103
ComposeTransform.inv = _inv
104+
105+
__all__ = [
106+
"auto_unwrap_transformed_env",
107+
"compile_with_warmup",
108+
"implement_for",
109+
"set_auto_unwrap_transformed_env",
110+
"timeit",
111+
]

torchrl/collectors/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,12 @@
1212
MultiSyncDataCollector,
1313
SyncDataCollector,
1414
)
15+
16+
__all__ = [
17+
"RandomPolicy",
18+
"aSyncDataCollector",
19+
"DataCollectorBase",
20+
"MultiaSyncDataCollector",
21+
"MultiSyncDataCollector",
22+
"SyncDataCollector",
23+
]

torchrl/collectors/distributed/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,12 @@
88
from .rpc import RPCDataCollector
99
from .sync import DistributedSyncDataCollector
1010
from .utils import submitit_delayed_launcher
11+
12+
__all__ = [
13+
"DEFAULT_SLURM_CONF",
14+
"DistributedDataCollector",
15+
"RayCollector",
16+
"RPCDataCollector",
17+
"DistributedSyncDataCollector",
18+
"submitit_delayed_launcher",
19+
]

torchrl/data/map/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,15 @@
77
from .query import HashToInt, QueryModule
88
from .tdstorage import TensorDictMap, TensorMap
99
from .tree import MCTSForest, Tree
10+
11+
__all__ = [
12+
"BinaryToDecimal",
13+
"RandomProjectionHash",
14+
"SipHash",
15+
"HashToInt",
16+
"QueryModule",
17+
"TensorDictMap",
18+
"TensorMap",
19+
"MCTSForest",
20+
"Tree",
21+
]

torchrl/data/postprocs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from .postprocs import MultiStep
7+
8+
__all__ = ["MultiStep"]

torchrl/data/replay_buffers/__init__.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,46 @@
4848
Writer,
4949
WriterEnsemble,
5050
)
51+
52+
__all__ = [
53+
"FlatStorageCheckpointer",
54+
"H5StorageCheckpointer",
55+
"ListStorageCheckpointer",
56+
"NestedStorageCheckpointer",
57+
"StorageCheckpointerBase",
58+
"StorageEnsembleCheckpointer",
59+
"TensorStorageCheckpointer",
60+
"PrioritizedReplayBuffer",
61+
"RemoteTensorDictReplayBuffer",
62+
"ReplayBuffer",
63+
"ReplayBufferEnsemble",
64+
"TensorDictPrioritizedReplayBuffer",
65+
"TensorDictReplayBuffer",
66+
"PrioritizedSampler",
67+
"PrioritizedSliceSampler",
68+
"RandomSampler",
69+
"Sampler",
70+
"SamplerEnsemble",
71+
"SamplerWithoutReplacement",
72+
"SliceSampler",
73+
"SliceSamplerWithoutReplacement",
74+
"LazyMemmapStorage",
75+
"LazyStackStorage",
76+
"LazyTensorStorage",
77+
"ListStorage",
78+
"Storage",
79+
"StorageEnsemble",
80+
"TensorStorage",
81+
"Flat2TED",
82+
"H5Combine",
83+
"H5Split",
84+
"Nested2TED",
85+
"TED2Flat",
86+
"TED2Nested",
87+
"ImmutableDatasetWriter",
88+
"RoundRobinWriter",
89+
"TensorDictMaxValueWriter",
90+
"TensorDictRoundRobinWriter",
91+
"Writer",
92+
"WriterEnsemble",
93+
]

torchrl/data/rlhf/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,17 @@
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
1414
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
15+
16+
__all__ = [
17+
"create_infinite_iterator",
18+
"get_dataloader",
19+
"TensorDictTokenizer",
20+
"TokenizedDatasetLoader",
21+
"PromptData",
22+
"PromptTensorDictTokenizer",
23+
"PairwiseDataset",
24+
"RewardData",
25+
"AdaptiveKLController",
26+
"ConstantKLController",
27+
"RolloutFromModel",
28+
]

torchrl/envs/__init__.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,129 @@
121121
set_exploration_type,
122122
step_mdp,
123123
)
124+
125+
__all__ = [
126+
"ParallelEnv",
127+
"SerialEnv",
128+
"EnvBase",
129+
"EnvMetaData",
130+
"make_tensordict",
131+
"ChessEnv",
132+
"LLMHashingEnv",
133+
"PendulumEnv",
134+
"TicTacToeEnv",
135+
"env_creator",
136+
"EnvCreator",
137+
"get_env_metadata",
138+
"default_info_dict_reader",
139+
"GymLikeEnv",
140+
"BraxEnv",
141+
"BraxWrapper",
142+
"DMControlEnv",
143+
"DMControlWrapper",
144+
"gym_backend",
145+
"GymEnv",
146+
"GymWrapper",
147+
"HabitatEnv",
148+
"IsaacGymEnv",
149+
"IsaacGymWrapper",
150+
"JumanjiEnv",
151+
"JumanjiWrapper",
152+
"MeltingpotEnv",
153+
"MeltingpotWrapper",
154+
"MOGymEnv",
155+
"MOGymWrapper",
156+
"MultiThreadedEnv",
157+
"MultiThreadedEnvWrapper",
158+
"OpenMLEnv",
159+
"OpenSpielEnv",
160+
"OpenSpielWrapper",
161+
"PettingZooEnv",
162+
"PettingZooWrapper",
163+
"register_gym_spec_conversion",
164+
"RoboHiveEnv",
165+
"set_gym_backend",
166+
"SMACv2Env",
167+
"SMACv2Wrapper",
168+
"UnityMLAgentsEnv",
169+
"UnityMLAgentsWrapper",
170+
"VmasEnv",
171+
"VmasWrapper",
172+
"DreamerDecoder",
173+
"DreamerEnv",
174+
"ModelBasedEnvBase",
175+
"ActionDiscretizer",
176+
"ActionMask",
177+
"AutoResetEnv",
178+
"AutoResetTransform",
179+
"BatchSizeTransform",
180+
"BinarizeReward",
181+
"BurnInTransform",
182+
"CatFrames",
183+
"CatTensors",
184+
"CenterCrop",
185+
"ClipTransform",
186+
"Compose",
187+
"ConditionalSkip",
188+
"Crop",
189+
"DeviceCastTransform",
190+
"DiscreteActionProjection",
191+
"DoubleToFloat",
192+
"DTypeCastTransform",
193+
"EndOfLifeTransform",
194+
"ExcludeTransform",
195+
"FiniteTensorDictCheck",
196+
"FlattenObservation",
197+
"FrameSkipTransform",
198+
"GrayScale",
199+
"gSDENoise",
200+
"Hash",
201+
"InitTracker",
202+
"KLRewardTransform",
203+
"LineariseRewards",
204+
"MultiAction",
205+
"MultiStepTransform",
206+
"NoopResetEnv",
207+
"ObservationNorm",
208+
"ObservationTransform",
209+
"PermuteTransform",
210+
"PinMemoryTransform",
211+
"R3MTransform",
212+
"RandomCropTensorDict",
213+
"RemoveEmptySpecs",
214+
"RenameTransform",
215+
"Resize",
216+
"Reward2GoTransform",
217+
"RewardClipping",
218+
"RewardScaling",
219+
"RewardSum",
220+
"SelectTransform",
221+
"SignTransform",
222+
"SqueezeTransform",
223+
"Stack",
224+
"StepCounter",
225+
"TargetReturn",
226+
"TensorDictPrimer",
227+
"TimeMaxPool",
228+
"Timer",
229+
"Tokenizer",
230+
"ToTensorImage",
231+
"TrajCounter",
232+
"Transform",
233+
"TransformedEnv",
234+
"UnaryTransform",
235+
"UnsqueezeTransform",
236+
"VC1Transform",
237+
"VecGymEnvTransform",
238+
"VecNorm",
239+
"VIPRewardTransform",
240+
"VIPTransform",
241+
"check_env_specs",
242+
"check_marl_grouping",
243+
"exploration_type",
244+
"ExplorationType",
245+
"make_composite_from_td",
246+
"MarlGroupMapType",
247+
"set_exploration_type",
248+
"step_mdp",
249+
]

0 commit comments

Comments
 (0)