Skip to content

Commit b1cc796

Browse files
author
Vincent Moens
authored
[Feature] Fine control over devices in collectors (#1835)
1 parent 6277226 commit b1cc796

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+2334
-986
lines changed

benchmarks/ecosystem/gym_env_throughput.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def make(envname=envname, gym_backend=gym_backend):
115115
frames_per_batch=1024,
116116
total_frames=num_workers * 10_000,
117117
device=device,
118-
storing_device=device,
119118
)
120119
pbar = tqdm.tqdm(total=num_workers * 10_000)
121120
total_frames = 0
@@ -178,7 +177,6 @@ def make_env(envname=envname, gym_backend=gym_backend):
178177
frames_per_batch=1024,
179178
total_frames=num_workers * 10_000,
180179
device=device,
181-
storing_device=device,
182180
)
183181
pbar = tqdm.tqdm(total=num_workers * 10_000)
184182
total_frames = 0
@@ -222,7 +220,6 @@ def make_env(
222220
total_frames=num_workers * 10_000,
223221
num_sub_threads=num_workers // num_collectors,
224222
device=device,
225-
storing_device=device,
226223
)
227224
pbar = tqdm.tqdm(total=num_workers * 10_000)
228225
total_frames = 0
@@ -260,7 +257,6 @@ def make_env(envname=envname, gym_backend=gym_backend):
260257
frames_per_batch=1024,
261258
total_frames=num_workers * 10_000,
262259
device=device,
263-
storing_device=device,
264260
)
265261
pbar = tqdm.tqdm(total=num_workers * 10_000)
266262
total_frames = 0

test/mocking_classes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
import torch
88
import torch.nn as nn
9-
from tensordict.tensordict import TensorDict, TensorDictBase
9+
from tensordict import TensorDict, TensorDictBase
10+
from tensordict.nn import TensorDictModuleBase
1011
from tensordict.utils import expand_right, NestedKey
1112

1213
from torchrl.data.tensor_specs import (
@@ -229,6 +230,7 @@ def _step(self, tensordict):
229230
"observation": n.clone(),
230231
},
231232
batch_size=[],
233+
device=self.device,
232234
)
233235

234236
def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
@@ -240,7 +242,9 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
240242
done = self.counter >= self.max_val
241243
done = torch.tensor([done], dtype=torch.bool, device=self.device)
242244
return TensorDict(
243-
{"done": done, "terminated": done.clone(), "observation": n}, []
245+
{"done": done, "terminated": done.clone(), "observation": n},
246+
[],
247+
device=self.device,
244248
)
245249

246250
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
@@ -1374,8 +1378,9 @@ def _step(
13741378
return tensordict
13751379

13761380

1377-
class HeteroCountingEnvPolicy:
1381+
class HeterogeneousCountingEnvPolicy(TensorDictModuleBase):
13781382
def __init__(self, full_action_spec: TensorSpec, count: bool = True):
1383+
super().__init__()
13791384
self.full_action_spec = full_action_spec
13801385
self.count = count
13811386

@@ -1386,7 +1391,7 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
13861391
return td.update(action_td)
13871392

13881393

1389-
class HeteroCountingEnv(EnvBase):
1394+
class HeterogeneousCountingEnv(EnvBase):
13901395
"""A heterogeneous, counting Env."""
13911396

13921397
def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
@@ -1569,13 +1574,14 @@ def _set_seed(self, seed: Optional[int]):
15691574
torch.manual_seed(seed)
15701575

15711576

1572-
class MultiKeyCountingEnvPolicy:
1577+
class MultiKeyCountingEnvPolicy(TensorDictModuleBase):
15731578
def __init__(
15741579
self,
15751580
full_action_spec: TensorSpec,
15761581
count: bool = True,
15771582
deterministic: bool = False,
15781583
):
1584+
super().__init__()
15791585
if not deterministic and not count:
15801586
raise ValueError("Not counting policy is always deterministic")
15811587

0 commit comments

Comments
 (0)