Skip to content

Commit 0df5d7b

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent f6d00ca commit 0df5d7b

File tree

26 files changed

+602
-4
lines changed

26 files changed

+602
-4
lines changed

torchrl/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@
4646
)
4747

4848

49+
import torchrl.collectors
50+
import torchrl.data
51+
import torchrl.envs
52+
import torchrl.modules
53+
import torchrl.objectives
54+
import torchrl.trainers
55+
from torchrl._utils import (
56+
auto_unwrap_transformed_env,
57+
compile_with_warmup,
58+
implement_for,
59+
set_auto_unwrap_transformed_env,
60+
timeit,
61+
)
62+
4963
# Filter warnings in subprocesses: True by default given the multiple optional
5064
# deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.
5165
filter_warnings_subprocess = True

torchrl/collectors/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,13 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from torchrl.envs.utils import RandomPolicy
7+
8+
from .collectors import (
9+
aSyncDataCollector,
10+
DataCollectorBase,
11+
MultiaSyncDataCollector,
12+
MultiSyncDataCollector,
13+
SyncDataCollector,
14+
)

torchrl/collectors/distributed/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .generic import DEFAULT_SLURM_CONF, DistributedDataCollector
7+
from .ray import RayCollector
8+
from .rpc import RPCDataCollector
9+
from .sync import DistributedSyncDataCollector
10+
from .utils import submitit_delayed_launcher

torchrl/data/map/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
7+
from .query import HashToInt, QueryModule
8+
from .tdstorage import TensorDictMap, TensorMap
9+
from .tree import MCTSForest, Tree

torchrl/data/postprocs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .postprocs import MultiStep

torchrl/data/replay_buffers/__init__.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,49 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .checkpointers import (
7+
FlatStorageCheckpointer,
8+
H5StorageCheckpointer,
9+
ListStorageCheckpointer,
10+
NestedStorageCheckpointer,
11+
StorageCheckpointerBase,
12+
StorageEnsembleCheckpointer,
13+
TensorStorageCheckpointer,
14+
)
15+
from .replay_buffers import (
16+
PrioritizedReplayBuffer,
17+
RemoteTensorDictReplayBuffer,
18+
ReplayBuffer,
19+
ReplayBufferEnsemble,
20+
TensorDictPrioritizedReplayBuffer,
21+
TensorDictReplayBuffer,
22+
)
23+
from .samplers import (
24+
PrioritizedSampler,
25+
PrioritizedSliceSampler,
26+
RandomSampler,
27+
Sampler,
28+
SamplerEnsemble,
29+
SamplerWithoutReplacement,
30+
SliceSampler,
31+
SliceSamplerWithoutReplacement,
32+
)
33+
from .storages import (
34+
LazyMemmapStorage,
35+
LazyStackStorage,
36+
LazyTensorStorage,
37+
ListStorage,
38+
Storage,
39+
StorageEnsemble,
40+
TensorStorage,
41+
)
42+
from .utils import Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested
43+
from .writers import (
44+
ImmutableDatasetWriter,
45+
RoundRobinWriter,
46+
TensorDictMaxValueWriter,
47+
TensorDictRoundRobinWriter,
48+
Writer,
49+
WriterEnsemble,
50+
)

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def __init__(
416416
max_size: int | None = None,
417417
*,
418418
compilable: bool = False,
419-
stack_dim: int = -1,
419+
stack_dim: int = 0,
420420
):
421421
super().__init__(max_size=max_size, compilable=compilable)
422422
self.stack_dim = stack_dim

torchrl/data/rlhf/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,13 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .dataset import (
7+
create_infinite_iterator,
8+
get_dataloader,
9+
TensorDictTokenizer,
10+
TokenizedDatasetLoader,
11+
)
12+
from .prompt import PromptData, PromptTensorDictTokenizer
13+
from .reward import PairwiseDataset, RewardData
14+
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel

torchrl/envs/__init__.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,122 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .batched_envs import ParallelEnv, SerialEnv
7+
from .common import EnvBase, EnvMetaData, make_tensordict
8+
from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
9+
from .env_creator import env_creator, EnvCreator, get_env_metadata
10+
from .gym_like import default_info_dict_reader, GymLikeEnv
11+
from .libs import (
12+
BraxEnv,
13+
BraxWrapper,
14+
DMControlEnv,
15+
DMControlWrapper,
16+
gym_backend,
17+
GymEnv,
18+
GymWrapper,
19+
HabitatEnv,
20+
IsaacGymEnv,
21+
IsaacGymWrapper,
22+
JumanjiEnv,
23+
JumanjiWrapper,
24+
MeltingpotEnv,
25+
MeltingpotWrapper,
26+
MOGymEnv,
27+
MOGymWrapper,
28+
MultiThreadedEnv,
29+
MultiThreadedEnvWrapper,
30+
OpenMLEnv,
31+
OpenSpielEnv,
32+
OpenSpielWrapper,
33+
PettingZooEnv,
34+
PettingZooWrapper,
35+
register_gym_spec_conversion,
36+
RoboHiveEnv,
37+
set_gym_backend,
38+
SMACv2Env,
39+
SMACv2Wrapper,
40+
UnityMLAgentsEnv,
41+
UnityMLAgentsWrapper,
42+
VmasEnv,
43+
VmasWrapper,
44+
)
45+
from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase
46+
from .transforms import (
47+
ActionDiscretizer,
48+
ActionMask,
49+
AutoResetEnv,
50+
AutoResetTransform,
51+
BatchSizeTransform,
52+
BinarizeReward,
53+
BurnInTransform,
54+
CatFrames,
55+
CatTensors,
56+
CenterCrop,
57+
ClipTransform,
58+
Compose,
59+
ConditionalSkip,
60+
Crop,
61+
DeviceCastTransform,
62+
DiscreteActionProjection,
63+
DoubleToFloat,
64+
DTypeCastTransform,
65+
EndOfLifeTransform,
66+
ExcludeTransform,
67+
FiniteTensorDictCheck,
68+
FlattenObservation,
69+
FrameSkipTransform,
70+
GrayScale,
71+
gSDENoise,
72+
Hash,
73+
InitTracker,
74+
KLRewardTransform,
75+
LineariseRewards,
76+
MultiAction,
77+
MultiStepTransform,
78+
NoopResetEnv,
79+
ObservationNorm,
80+
ObservationTransform,
81+
PermuteTransform,
82+
PinMemoryTransform,
83+
R3MTransform,
84+
RandomCropTensorDict,
85+
RemoveEmptySpecs,
86+
RenameTransform,
87+
Resize,
88+
Reward2GoTransform,
89+
RewardClipping,
90+
RewardScaling,
91+
RewardSum,
92+
SelectTransform,
93+
SignTransform,
94+
SqueezeTransform,
95+
Stack,
96+
StepCounter,
97+
TargetReturn,
98+
TensorDictPrimer,
99+
TimeMaxPool,
100+
Timer,
101+
Tokenizer,
102+
ToTensorImage,
103+
TrajCounter,
104+
Transform,
105+
TransformedEnv,
106+
UnaryTransform,
107+
UnsqueezeTransform,
108+
VC1Transform,
109+
VecGymEnvTransform,
110+
VecNorm,
111+
VIPRewardTransform,
112+
VIPTransform,
113+
)
114+
from .utils import (
115+
check_env_specs,
116+
check_marl_grouping,
117+
exploration_type,
118+
ExplorationType,
119+
make_composite_from_td,
120+
MarlGroupMapType,
121+
set_exploration_type,
122+
step_mdp,
123+
)

torchrl/envs/custom/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .chess import ChessEnv
7+
from .llm import LLMHashingEnv
8+
from .pendulum import PendulumEnv
9+
from .tictactoeenv import TicTacToeEnv

0 commit comments

Comments
 (0)