Skip to content

Commit b2cec25

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent a12530b commit b2cec25

File tree

3 files changed

+296
-56
lines changed

3 files changed

+296
-56
lines changed

test/test_transforms.py

Lines changed: 90 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import tensordict.tensordict
2222
import torch
23+
from tensordict.nn import WrapModule
2324

2425
from torchrl.collectors import MultiSyncDataCollector
2526

@@ -13208,7 +13209,9 @@ def test_single_trans_env_check(self):
1320813209
r = env.rollout(1000, policy_odd, break_when_all_done=True)
1320913210
assert r.shape[0] == 15
1321013211
assert (r["action"] == 0).all()
13211-
assert (r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)).all()
13212+
assert (
13213+
r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
13214+
).all()
1321213215
assert r["next", "done"].any()
1321313216

1321413217
# Player 1
@@ -13221,58 +13224,104 @@ def test_single_trans_env_check(self):
1322113224
r = env.rollout(1000, policy_even, break_when_all_done=True)
1322213225
assert r.shape[0] == 16
1322313226
assert (r["action"] == 1).all()
13224-
assert (r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)).all()
13227+
assert (
13228+
r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
13229+
).all()
1322513230
assert r["next", "done"].any()
1322613231

13232+
def _create_policy_odd(self, base_env):
13233+
return WrapModule(
13234+
lambda td, base_env=base_env: td.set(
13235+
"action", base_env.action_spec_unbatched.zero(td.shape)
13236+
),
13237+
out_keys=["action"],
13238+
)
1322713239

13228-
def test_trans_serial_env_check(self):
13229-
def make_env(max_count):
13230-
def make():
13231-
base_env = CountingEnv(max_steps=max_count)
13232-
transforms =
13233-
return base_env.append_transform(transforms)
13234-
return make
13235-
13236-
base_env = SerialEnv(3,
13237-
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)])
13238-
condition = lambda td: ((td.get("step_count") % 2) == 0)
13239-
policy_odd = lambda td, base_env=base_env: td.set("action", base_env.action_spec.zero())
13240-
policy_even = lambda td, base_env=base_env: td.set("action", base_env.action_spec.one())
13241-
env = base_env.append_transform(Compose(
13242-
StepCounter(),
13243-
ConditionalPolicySwitch(condition=condition, policy=policy_even),
13244-
))
13245-
r = env.rollout(100, break_when_all_done=False)
13246-
print(r["step_count"].squeeze())
13240+
def _create_policy_even(self, base_env):
13241+
return WrapModule(
13242+
lambda td, base_env=base_env: td.set(
13243+
"action", base_env.action_spec_unbatched.one(td.shape)
13244+
),
13245+
out_keys=["action"],
13246+
)
13247+
13248+
def _create_transforms(self, condition, policy_even):
13249+
return Compose(
13250+
StepCounter(),
13251+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
13252+
)
1324713253

13254+
def _make_env(self, max_count, env_cls):
13255+
torch.manual_seed(0)
13256+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
13257+
base_env = env_cls(max_steps=max_count)
13258+
policy_even = self._create_policy_even(base_env)
13259+
transforms = self._create_transforms(condition, policy_even)
13260+
return base_env.append_transform(transforms)
13261+
13262+
def _test_env(self, env, policy_odd):
13263+
env.check_env_specs()
13264+
env.set_seed(0)
13265+
r = env.rollout(100, policy_odd, break_when_any_done=False)
13266+
# Check results are independent: one reset / step in one env should not impact results in another
13267+
r0, r1, r2 = r.unbind(0)
13268+
r0_split = r0.split(6)
13269+
assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
13270+
r1_split = r1.split(7)
13271+
assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
13272+
r2_split = r2.split(8)
13273+
assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))
13274+
13275+
def test_trans_serial_env_check(self):
13276+
torch.manual_seed(0)
13277+
base_env = SerialEnv(
13278+
3,
13279+
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
13280+
batch_locked=False,
13281+
)
13282+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
13283+
policy_odd = self._create_policy_odd(base_env)
13284+
policy_even = self._create_policy_even(base_env)
13285+
transforms = self._create_transforms(condition, policy_even)
13286+
env = base_env.append_transform(transforms)
13287+
self._test_env(env, policy_odd)
1324813288

1324913289
def test_trans_parallel_env_check(self):
13250-
"""tests that a transformed paprallel env (TransformedEnv(ParallelEnv(N, lambda: env()), transform)) passes the check_env_specs test."""
13251-
raise NotImplementedError
13290+
torch.manual_seed(0)
13291+
base_env = ParallelEnv(
13292+
3,
13293+
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
13294+
batch_locked=False,
13295+
mp_start_method=mp_ctx,
13296+
)
13297+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
13298+
policy_odd = self._create_policy_odd(base_env)
13299+
policy_even = self._create_policy_even(base_env)
13300+
transforms = self._create_transforms(condition, policy_even)
13301+
env = base_env.append_transform(transforms)
13302+
self._test_env(env, policy_odd)
1325213303

1325313304
def test_serial_trans_env_check(self):
13254-
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
13255-
# Player 0
13256-
policy_odd = lambda td: td.set("action", env.action_spec.zero())
13257-
policy_even = lambda td: td.set("action", env.action_spec.one())
13305+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
13306+
policy_odd = self._create_policy_odd(CountingEnv())
13307+
1325813308
def make_env(max_count):
13259-
def make():
13260-
base_env = CountingEnv(max_steps=max_count)
13261-
transforms = Compose(
13262-
StepCounter(),
13263-
ConditionalPolicySwitch(condition=condition, policy=policy_even),
13264-
)
13265-
return base_env.append_transform(transforms)
13266-
return make
13309+
return partial(self._make_env, max_count, CountingEnv)
1326713310

13268-
env = SerialEnv(3,
13269-
[make_env(6), make_env(7), make_env(8)])
13270-
r = env.rollout(100, break_when_all_done=False)
13271-
print(r["step_count"].squeeze())
13311+
env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
13312+
self._test_env(env, policy_odd)
1327213313

1327313314
def test_parallel_trans_env_check(self):
13274-
"""tests that a parallel transformed env (ParallelEnv(N, lambda: TransformedEnv(env, transform))) passes the check_env_specs test."""
13275-
raise NotImplementedError
13315+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
13316+
policy_odd = self._create_policy_odd(CountingEnv())
13317+
13318+
def make_env(max_count):
13319+
return partial(self._make_env, max_count, CountingEnv)
13320+
13321+
env = ParallelEnv(
13322+
3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
13323+
)
13324+
self._test_env(env, policy_odd)
1327613325

1327713326
def test_transform_no_env(self):
1327813327
"""tests the transform on dummy data, without an env."""

torchrl/envs/batched_envs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ class BatchedEnvBase(EnvBase):
191191
one of the environment has dynamic specs.
192192
193193
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
194+
batch_locked (bool, optional): if provided, will override the ``batch_locked`` attribute of the
195+
nested environments. `batch_locked=False` may allow for partial steps.
194196
195197
.. note::
196198
One can pass keyword arguments to each sub-environments using the following
@@ -305,6 +307,7 @@ def __init__(
305307
non_blocking: bool = False,
306308
mp_start_method: str = None,
307309
use_buffers: bool = None,
310+
batch_locked: bool | None = None,
308311
):
309312
super().__init__(device=device)
310313
self.serial_for_single = serial_for_single
@@ -344,6 +347,7 @@ def __init__(
344347

345348
# if share_individual_td is None, we will assess later if the output can be stacked
346349
self.share_individual_td = share_individual_td
350+
self._batch_locked = batch_locked
347351
self._share_memory = shared_memory
348352
self._memmap = memmap
349353
self.allow_step_when_done = allow_step_when_done
@@ -610,8 +614,8 @@ def map_device(key, value, device_map=device_map):
610614
self._env_tensordict.named_apply(
611615
map_device, nested_keys=True, filter_empty=True
612616
)
613-
614-
self._batch_locked = meta_data.batch_locked
617+
if self._batch_locked is None:
618+
self._batch_locked = meta_data.batch_locked
615619
else:
616620
self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
617621
devices = set()
@@ -652,7 +656,8 @@ def map_device(key, value, device_map=device_map):
652656
self._env_tensordict = torch.stack(
653657
[meta_data.tensordict for meta_data in meta_data], 0
654658
)
655-
self._batch_locked = meta_data[0].batch_locked
659+
if self._batch_locked is None:
660+
self._batch_locked = meta_data[0].batch_locked
656661
self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
657662

658663
def state_dict(self) -> OrderedDict:

0 commit comments

Comments
 (0)