Skip to content

Commit 6bb023d

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent a137732 commit 6bb023d

File tree

9 files changed

+5
-20
lines changed

9 files changed

+5
-20
lines changed

torchrl/envs/libs/brax.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import importlib.util
88
import warnings
99

10-
1110
import torch
1211
from packaging import version
1312
from tensordict import TensorDict, TensorDictBase

torchrl/envs/libs/jax_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def _tree_flatten(x, batch_size: torch.Size):
4141
}
4242

4343

44-
def _ndarray_to_tensor(
45-
value: jnp.ndarray | np.ndarray # noqa: F821
46-
) -> torch.Tensor:
44+
def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa: F821
4745
from jax import dlpack as jax_dlpack, numpy as jnp
4846

4947
# JAX arrays generated by jax.vmap would have Numpy dtypes.

torchrl/envs/model_based/dreamer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
87
import torch
98
from tensordict import TensorDict
109
from tensordict.nn import TensorDictModule

torchrl/envs/transforms/functional.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
87
from torch import Tensor
98

109

torchrl/modules/distributions/discrete.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ def entropy(self):
128128
return -p_log_p.sum(-1)
129129

130130
@_one_hot_wrapper(D.Categorical)
131-
def sample(
132-
self, sample_shape: torch.Size | Sequence | None = None
133-
) -> torch.Tensor:
131+
def sample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
134132
...
135133

136134
def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:

torchrl/modules/distributions/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
87
import torch
98
from torch import autograd, distributions as d
109
from torch.distributions import Independent, Transform, TransformedDistribution

torchrl/objectives/value/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
87
import torch
98

109
from tensordict import TensorDictBase

torchrl/trainers/helpers/collectors.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,7 @@ def _make_collector(
250250

251251
def make_collector_offpolicy(
252252
make_env: Callable[[], EnvBase],
253-
actor_model_explore: (
254-
TensorDictModuleWrapper | ProbabilisticTensorDictSequential
255-
),
253+
actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential),
256254
cfg: DictConfig, # noqa: F821
257255
make_env_kwargs: dict | None = None,
258256
) -> DataCollectorBase:
@@ -314,9 +312,7 @@ def make_collector_offpolicy(
314312

315313
def make_collector_onpolicy(
316314
make_env: Callable[[], EnvBase],
317-
actor_model_explore: (
318-
TensorDictModuleWrapper | ProbabilisticTensorDictSequential
319-
),
315+
actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential),
320316
cfg: DictConfig, # noqa: F821
321317
make_env_kwargs: dict | None = None,
322318
) -> DataCollectorBase:

torchrl/trainers/helpers/trainers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@ def make_trainer(
8282
loss_module: LossModule,
8383
recorder: EnvBase | None = None,
8484
target_net_updater: TargetNetUpdater | None = None,
85-
policy_exploration: None | (
86-
TensorDictModuleWrapper | TensorDictModule
87-
) = None,
85+
policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None,
8886
replay_buffer: ReplayBuffer | None = None,
8987
logger: Logger | None = None,
9088
cfg: DictConfig = None, # noqa: F821

0 commit comments

Comments
 (0)