Skip to content

Commit 8512476

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 6bb023d commit 8512476

File tree

12 files changed

+133
-130
lines changed

12 files changed

+133
-130
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
echo '::endgroup::'
3636
3737
echo '::group::Install lint tools'
38-
pip install --progress-bar=off pre-commit
38+
pip install --progress-bar=off pre-commit autoflake
3939
echo '::endgroup::'
4040
4141
echo '::group::Lint Python source and configs'

build_tools/setup_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .extension import CMakeBuild, get_ext_modules # noqa
6+
from .extension import CMakeBuild, get_ext_modules
77

88
__all__ = ["CMakeBuild", "get_ext_modules"]

check_future_imports.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

docs/source/reference/envs.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ Recorders are transforms that register data as they come in, for logging purpose
12201220

12211221
Helpers
12221222
-------
1223-
.. currentmodule:: torchrl.envs.utils
1223+
.. currentmodule:: torchrl.envs
12241224

12251225
.. autosummary::
12261226
:toctree: generated/

docs/source/reference/objectives.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ auto-completion to make their choice.
111111
:template: rl_template_noinherit.rst
112112

113113
LossModule
114+
add_random_module
114115

115116
DQN
116117
---

test/test_cost.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
TD3BCLoss,
101101
TD3Loss,
102102
)
103-
from torchrl.objectives.common import LossModule
103+
from torchrl.objectives.common import add_random_module, LossModule
104104
from torchrl.objectives.deprecated import DoubleREDQLoss_deprecated, REDQLoss_deprecated
105105
from torchrl.objectives.redq import REDQLoss
106106
from torchrl.objectives.reinforce import ReinforceLoss
@@ -16162,6 +16162,15 @@ def _composite_log_prob(self):
1616216162
yield
1616316163
setter.unset()
1616416164

16165+
def test_add_random_module(self):
16166+
class MyMod(nn.Module):
16167+
...
16168+
16169+
add_random_module(MyMod)
16170+
import torchrl.objectives.utils
16171+
16172+
assert MyMod in torchrl.objectives.utils.RANDOM_MODULE_LIST
16173+
1616516174
def test_standardization(self):
1616616175
t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6)
1616716176
std_t0 = _standardize(t, exclude_dims=(1, 3))

torchrl/envs/__init__.py

Lines changed: 70 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .batched_envs import ParallelEnv, SerialEnv
77
from .common import EnvBase, EnvMetaData, make_tensordict
8-
from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
8+
from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
99
from .env_creator import env_creator, EnvCreator, get_env_metadata
1010
from .gym_like import default_info_dict_reader, GymLikeEnv
1111
from .libs import (
@@ -46,6 +46,8 @@
4646
from .transforms import (
4747
ActionDiscretizer,
4848
ActionMask,
49+
as_nested_tensor,
50+
as_padded_tensor,
4951
AutoResetEnv,
5052
AutoResetTransform,
5153
BatchSizeTransform,
@@ -58,6 +60,7 @@
5860
Compose,
5961
ConditionalSkip,
6062
Crop,
63+
DataLoadingPrimer,
6164
DeviceCastTransform,
6265
DiscreteActionProjection,
6366
DoubleToFloat,
@@ -116,134 +119,144 @@
116119
check_marl_grouping,
117120
exploration_type,
118121
ExplorationType,
122+
get_available_libraries,
119123
make_composite_from_td,
120124
MarlGroupMapType,
125+
RandomPolicy,
121126
set_exploration_type,
122127
step_mdp,
128+
terminated_or_truncated,
123129
)
124130

125131
__all__ = [
126-
"ParallelEnv",
127-
"SerialEnv",
128-
"EnvBase",
129-
"EnvMetaData",
130-
"make_tensordict",
131-
"ChessEnv",
132-
"LLMHashingEnv",
133-
"PendulumEnv",
134-
"TicTacToeEnv",
135-
"env_creator",
136-
"EnvCreator",
137-
"get_env_metadata",
138-
"default_info_dict_reader",
139-
"GymLikeEnv",
140-
"BraxEnv",
141-
"BraxWrapper",
142-
"DMControlEnv",
143-
"DMControlWrapper",
144-
"gym_backend",
145-
"GymEnv",
146-
"GymWrapper",
147-
"HabitatEnv",
148-
"IsaacGymEnv",
149-
"IsaacGymWrapper",
150-
"JumanjiEnv",
151-
"JumanjiWrapper",
152-
"MeltingpotEnv",
153-
"MeltingpotWrapper",
154-
"MOGymEnv",
155-
"MOGymWrapper",
156-
"MultiThreadedEnv",
157-
"MultiThreadedEnvWrapper",
158-
"OpenMLEnv",
159-
"OpenSpielEnv",
160-
"OpenSpielWrapper",
161-
"PettingZooEnv",
162-
"PettingZooWrapper",
163-
"register_gym_spec_conversion",
164-
"RoboHiveEnv",
165-
"set_gym_backend",
166-
"SMACv2Env",
167-
"SMACv2Wrapper",
168-
"UnityMLAgentsEnv",
169-
"UnityMLAgentsWrapper",
170-
"VmasEnv",
171-
"VmasWrapper",
172-
"DreamerDecoder",
173-
"DreamerEnv",
174-
"ModelBasedEnvBase",
175132
"ActionDiscretizer",
176133
"ActionMask",
177134
"AutoResetEnv",
178135
"AutoResetTransform",
179136
"BatchSizeTransform",
180137
"BinarizeReward",
138+
"BraxEnv",
139+
"BraxWrapper",
181140
"BurnInTransform",
182141
"CatFrames",
183142
"CatTensors",
184143
"CenterCrop",
144+
"ChessEnv",
185145
"ClipTransform",
186146
"Compose",
187147
"ConditionalSkip",
188148
"Crop",
149+
"DMControlEnv",
150+
"DMControlWrapper",
151+
"DTypeCastTransform",
152+
"DataLoadingPrimer",
189153
"DeviceCastTransform",
190154
"DiscreteActionProjection",
191155
"DoubleToFloat",
192-
"DTypeCastTransform",
156+
"DreamerDecoder",
157+
"DreamerEnv",
193158
"EndOfLifeTransform",
159+
"EnvBase",
160+
"EnvCreator",
161+
"EnvMetaData",
194162
"ExcludeTransform",
163+
"ExplorationType",
195164
"FiniteTensorDictCheck",
196165
"FlattenObservation",
197166
"FrameSkipTransform",
198167
"GrayScale",
199-
"gSDENoise",
168+
"GymEnv",
169+
"GymLikeEnv",
170+
"GymWrapper",
171+
"HabitatEnv",
200172
"Hash",
201173
"InitTracker",
174+
"IsaacGymEnv",
175+
"IsaacGymWrapper",
176+
"JumanjiEnv",
177+
"JumanjiWrapper",
202178
"KLRewardTransform",
179+
"LLMEnv",
180+
"LLMHashingEnv",
203181
"LineariseRewards",
182+
"MOGymEnv",
183+
"MOGymWrapper",
184+
"MarlGroupMapType",
185+
"MeltingpotEnv",
186+
"MeltingpotWrapper",
187+
"ModelBasedEnvBase",
204188
"MultiAction",
205189
"MultiStepTransform",
190+
"MultiThreadedEnv",
191+
"MultiThreadedEnvWrapper",
206192
"NoopResetEnv",
207193
"ObservationNorm",
208194
"ObservationTransform",
195+
"OpenMLEnv",
196+
"OpenSpielEnv",
197+
"OpenSpielWrapper",
198+
"ParallelEnv",
199+
"PendulumEnv",
209200
"PermuteTransform",
201+
"PettingZooEnv",
202+
"PettingZooWrapper",
210203
"PinMemoryTransform",
211204
"R3MTransform",
212205
"RandomCropTensorDict",
206+
"RandomPolicy",
213207
"RemoveEmptySpecs",
214208
"RenameTransform",
215209
"Resize",
216210
"Reward2GoTransform",
217211
"RewardClipping",
218212
"RewardScaling",
219213
"RewardSum",
214+
"RoboHiveEnv",
215+
"SMACv2Env",
216+
"SMACv2Wrapper",
220217
"SelectTransform",
218+
"SerialEnv",
221219
"SignTransform",
222220
"SqueezeTransform",
223221
"Stack",
224222
"StepCounter",
225223
"TargetReturn",
226224
"TensorDictPrimer",
225+
"TicTacToeEnv",
227226
"TimeMaxPool",
228227
"Timer",
229-
"Tokenizer",
230228
"ToTensorImage",
229+
"Tokenizer",
231230
"TrajCounter",
232231
"Transform",
233232
"TransformedEnv",
234233
"UnaryTransform",
234+
"UnityMLAgentsEnv",
235+
"UnityMLAgentsWrapper",
235236
"UnsqueezeTransform",
236237
"VC1Transform",
237-
"VecGymEnvTransform",
238-
"VecNorm",
239238
"VIPRewardTransform",
240239
"VIPTransform",
240+
"VecGymEnvTransform",
241+
"VecNorm",
242+
"VmasEnv",
243+
"VmasWrapper",
244+
"as_nested_tensor",
245+
"as_padded_tensor",
241246
"check_env_specs",
242247
"check_marl_grouping",
248+
"default_info_dict_reader",
249+
"env_creator",
243250
"exploration_type",
244-
"ExplorationType",
251+
"gSDENoise",
252+
"get_available_libraries",
253+
"get_env_metadata",
254+
"gym_backend",
245255
"make_composite_from_td",
246-
"MarlGroupMapType",
256+
"make_tensordict",
257+
"register_gym_spec_conversion",
247258
"set_exploration_type",
259+
"set_gym_backend",
248260
"step_mdp",
261+
"terminated_or_truncated",
249262
]

torchrl/envs/custom/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from .chess import ChessEnv
7-
from .llm import LLMHashingEnv
7+
from .llm import LLMEnv, LLMHashingEnv
88
from .pendulum import PendulumEnv
99
from .tictactoeenv import TicTacToeEnv
1010

11-
__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv"]
11+
__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv", "LLMEnv"]

0 commit comments

Comments
 (0)