Skip to content

Commit d7a6812

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 00b2088 commit d7a6812

File tree

6 files changed

+207
-6
lines changed

6 files changed

+207
-6
lines changed

torchrl/data/__init__.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,197 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
from .map import (
7+
BinaryToDecimal,
8+
HashToInt,
9+
MCTSForest,
10+
QueryModule,
11+
RandomProjectionHash,
12+
SipHash,
13+
TensorDictMap,
14+
TensorMap,
15+
Tree,
16+
)
17+
from .postprocs import MultiStep
18+
from .replay_buffers import (
19+
Flat2TED,
20+
FlatStorageCheckpointer,
21+
H5Combine,
22+
H5Split,
23+
H5StorageCheckpointer,
24+
ImmutableDatasetWriter,
25+
LazyMemmapStorage,
26+
LazyStackStorage,
27+
LazyTensorStorage,
28+
ListStorage,
29+
ListStorageCheckpointer,
30+
Nested2TED,
31+
NestedStorageCheckpointer,
32+
PrioritizedReplayBuffer,
33+
PrioritizedSampler,
34+
PrioritizedSliceSampler,
35+
RandomSampler,
36+
RemoteTensorDictReplayBuffer,
37+
ReplayBuffer,
38+
ReplayBufferEnsemble,
39+
RoundRobinWriter,
40+
SamplerEnsemble,
41+
SamplerWithoutReplacement,
42+
SliceSampler,
43+
SliceSamplerWithoutReplacement,
44+
Storage,
45+
StorageCheckpointerBase,
46+
StorageEnsemble,
47+
StorageEnsembleCheckpointer,
48+
TED2Flat,
49+
TED2Nested,
50+
TensorDictMaxValueWriter,
51+
TensorDictPrioritizedReplayBuffer,
52+
TensorDictReplayBuffer,
53+
TensorDictRoundRobinWriter,
54+
TensorStorage,
55+
TensorStorageCheckpointer,
56+
Writer,
57+
WriterEnsemble,
58+
)
59+
from .rlhf import (
60+
AdaptiveKLController,
61+
ConstantKLController,
62+
create_infinite_iterator,
63+
get_dataloader,
64+
PairwiseDataset,
65+
PromptData,
66+
PromptTensorDictTokenizer,
67+
RewardData,
68+
RolloutFromModel,
69+
TensorDictTokenizer,
70+
TokenizedDatasetLoader,
71+
)
72+
from .tensor_specs import (
73+
Binary,
74+
BinaryDiscreteTensorSpec,
75+
Bounded,
76+
BoundedContinuous,
77+
BoundedTensorSpec,
78+
Categorical,
79+
Choice,
80+
Composite,
81+
CompositeSpec,
82+
DEVICE_TYPING,
83+
DiscreteTensorSpec,
84+
LazyStackedCompositeSpec,
85+
LazyStackedTensorSpec,
86+
MultiCategorical,
87+
MultiDiscreteTensorSpec,
88+
MultiOneHot,
89+
MultiOneHotDiscreteTensorSpec,
90+
NonTensor,
91+
NonTensorSpec,
92+
OneHot,
93+
OneHotDiscreteTensorSpec,
94+
Stacked,
95+
StackedComposite,
96+
TensorSpec,
97+
Unbounded,
98+
UnboundedContinuous,
99+
UnboundedContinuousTensorSpec,
100+
UnboundedDiscrete,
101+
UnboundedDiscreteTensorSpec,
102+
)
103+
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
104+
105+
__all__ = [
106+
"BinaryToDecimal",
107+
"HashToInt",
108+
"MCTSForest",
109+
"QueryModule",
110+
"RandomProjectionHash",
111+
"SipHash",
112+
"TensorDictMap",
113+
"TensorMap",
114+
"Tree",
115+
"MultiStep",
116+
"Flat2TED",
117+
"FlatStorageCheckpointer",
118+
"H5Combine",
119+
"H5Split",
120+
"H5StorageCheckpointer",
121+
"ImmutableDatasetWriter",
122+
"LazyMemmapStorage",
123+
"LazyStackStorage",
124+
"LazyTensorStorage",
125+
"ListStorage",
126+
"ListStorageCheckpointer",
127+
"Nested2TED",
128+
"NestedStorageCheckpointer",
129+
"PrioritizedReplayBuffer",
130+
"PrioritizedSampler",
131+
"PrioritizedSliceSampler",
132+
"RandomSampler",
133+
"RemoteTensorDictReplayBuffer",
134+
"ReplayBuffer",
135+
"ReplayBufferEnsemble",
136+
"RoundRobinWriter",
137+
"SamplerEnsemble",
138+
"SamplerWithoutReplacement",
139+
"SliceSampler",
140+
"SliceSamplerWithoutReplacement",
141+
"Storage",
142+
"StorageCheckpointerBase",
143+
"StorageEnsemble",
144+
"StorageEnsembleCheckpointer",
145+
"TED2Flat",
146+
"TED2Nested",
147+
"TensorDictMaxValueWriter",
148+
"TensorDictPrioritizedReplayBuffer",
149+
"TensorDictReplayBuffer",
150+
"TensorDictRoundRobinWriter",
151+
"TensorStorage",
152+
"TensorStorageCheckpointer",
153+
"Writer",
154+
"WriterEnsemble",
155+
"AdaptiveKLController",
156+
"ConstantKLController",
157+
"create_infinite_iterator",
158+
"get_dataloader",
159+
"PairwiseDataset",
160+
"PromptData",
161+
"PromptTensorDictTokenizer",
162+
"RewardData",
163+
"RolloutFromModel",
164+
"TensorDictTokenizer",
165+
"TokenizedDatasetLoader",
166+
"Binary",
167+
"BinaryDiscreteTensorSpec",
168+
"Bounded",
169+
"BoundedContinuous",
170+
"BoundedTensorSpec",
171+
"Categorical",
172+
"Choice",
173+
"Composite",
174+
"CompositeSpec",
175+
"DEVICE_TYPING",
176+
"DiscreteTensorSpec",
177+
"LazyStackedCompositeSpec",
178+
"LazyStackedTensorSpec",
179+
"MultiCategorical",
180+
"MultiDiscreteTensorSpec",
181+
"MultiOneHot",
182+
"MultiOneHotDiscreteTensorSpec",
183+
"NonTensor",
184+
"NonTensorSpec",
185+
"OneHot",
186+
"OneHotDiscreteTensorSpec",
187+
"Stacked",
188+
"StackedComposite",
189+
"TensorSpec",
190+
"Unbounded",
191+
"UnboundedContinuous",
192+
"UnboundedContinuousTensorSpec",
193+
"UnboundedDiscrete",
194+
"UnboundedDiscreteTensorSpec",
195+
"check_no_exclusive_keys",
196+
"consolidate_spec",
197+
"contains_lazy_spec",
198+
]

torchrl/envs/custom/chess.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010

1111
import torch
1212
from tensordict import TensorDict, TensorDictBase
13-
from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded
13+
from torchrl.data.tensor_specs import (
14+
Binary,
15+
Bounded,
16+
Categorical,
17+
Composite,
18+
NonTensor,
19+
Unbounded,
20+
)
1421
from torchrl.envs import EnvBase
1522
from torchrl.envs.common import _EnvPostInit
1623
from torchrl.envs.utils import _classproperty

torchrl/envs/custom/llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from tensordict.tensorclass import NonTensorData, NonTensorStack
1212
from tensordict.utils import _zip_strict
1313
from torch.utils.data import DataLoader
14-
from torchrl.data import (
14+
from torchrl.data.map.hash import SipHash
15+
from torchrl.data.tensor_specs import (
1516
Bounded,
1617
Categorical as CategoricalSpec,
1718
Composite,
1819
NonTensor,
19-
SipHash,
2020
TensorSpec,
2121
Unbounded,
2222
)

torchrl/envs/libs/isaacgym.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313
import torch
1414
from tensordict import TensorDictBase
15-
from torchrl.data import Composite
15+
from torchrl.data.tensor_specs import Composite
1616
from torchrl.envs.libs.gym import GymWrapper
1717
from torchrl.envs.utils import _classproperty, make_composite_from_td
1818

torchrl/envs/libs/meltingpot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
from tensordict import TensorDict, TensorDictBase
1212

13-
from torchrl.data import Categorical, Composite, TensorSpec
13+
from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec
1414
from torchrl.envs.common import _EnvWrapper
1515
from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform
1616
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

torchrl/envs/libs/openml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import torch
88
from tensordict import TensorDict, TensorDictBase
9-
from torchrl.data.replay_buffers import SamplerWithoutReplacement
9+
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
1010

1111
from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
1212
from torchrl.envs.common import EnvBase

0 commit comments

Comments
 (0)