Skip to content

Commit 35a7813

Browse files
author
Vincent Moens
committed
[Feature] flexible batch_locked for jumanji
ghstack-source-id: e356b65
Pull Request resolved: #2382
1 parent 14b63e4 commit 35a7813

File tree

4 files changed

+169
-39
lines changed

4 files changed

+169
-39
lines changed

test/test_libs.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1605,7 +1605,7 @@ def test_jumanji_seeding(self, envname):
16051605

16061606
@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
16071607
def test_jumanji_batch_size(self, envname, batch_size):
1608-
env = JumanjiEnv(envname, batch_size=batch_size)
1608+
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
16091609
env.set_seed(0)
16101610
tdreset = env.reset()
16111611
tdrollout = env.rollout(max_steps=50)
@@ -1616,7 +1616,7 @@ def test_jumanji_batch_size(self, envname, batch_size):
16161616

16171617
@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
16181618
def test_jumanji_spec_rollout(self, envname, batch_size):
1619-
env = JumanjiEnv(envname, batch_size=batch_size)
1619+
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
16201620
env.set_seed(0)
16211621
check_env_specs(env)
16221622

@@ -1627,7 +1627,7 @@ def test_jumanji_consistency(self, envname, batch_size):
16271627
import numpy as onp
16281628
from torchrl.envs.libs.jax_utils import _tree_flatten
16291629

1630-
env = JumanjiEnv(envname, batch_size=batch_size)
1630+
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
16311631
obs_keys = list(env.observation_spec.keys(True))
16321632
env.set_seed(1)
16331633
rollout = env.rollout(10)
@@ -1665,19 +1665,38 @@ def test_jumanji_consistency(self, envname, batch_size):
16651665
@pytest.mark.parametrize("batch_size", [[3], []])
16661666
def test_jumanji_rendering(self, envname, batch_size):
16671667
# check that this works with a batch-size
1668-
env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size)
1668+
env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size, jit=True)
16691669
env.set_seed(0)
16701670
env.transform.transform_observation_spec(env.base_env.observation_spec)
16711671

16721672
r = env.rollout(10)
16731673
pixels = r["pixels"]
16741674
if not isinstance(pixels, torch.Tensor):
16751675
pixels = torch.as_tensor(np.asarray(pixels))
1676+
assert batch_size
1677+
else:
1678+
assert not batch_size
16761679
assert pixels.unique().numel() > 1
16771680
assert pixels.dtype == torch.uint8
16781681

16791682
check_env_specs(env)
16801683

1684+
@pytest.mark.parametrize("jit", [True, False])
1685+
def test_jumanji_batch_unlocked(self, envname, jit):
1686+
torch.manual_seed(0)
1687+
env = JumanjiEnv(envname, jit=jit)
1688+
env.set_seed(0)
1689+
assert not env.batch_locked
1690+
reset = env.reset(TensorDict(batch_size=[16]))
1691+
assert reset.batch_size == (16,)
1692+
env.rand_step(reset)
1693+
r = env.rollout(
1694+
2000, auto_reset=False, tensordict=reset, break_when_all_done=True
1695+
)
1696+
assert r.batch_size[0] == 16
1697+
done = r["next", "done"]
1698+
assert done.any(-2).all() or (r.shape[-1] == 2000)
1699+
16811700

16821701
ENVPOOL_CLASSIC_CONTROL_ENVS = [
16831702
PENDULUM_VERSIONED(),

torchrl/envs/common.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,21 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
15001500
"""
15011501
# sanity check
15021502
self._assert_tensordict_shape(tensordict)
1503+
partial_steps = None
1504+
1505+
if not self.batch_locked:
1506+
# Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
1507+
partial_steps = tensordict.get("_step", None)
1508+
if partial_steps is not None:
1509+
if partial_steps.all():
1510+
partial_steps = None
1511+
else:
1512+
tensordict_batch_size = tensordict.batch_size
1513+
partial_steps = partial_steps.view(tensordict_batch_size)
1514+
tensordict = tensordict[partial_steps]
1515+
else:
1516+
tensordict_batch_size = self.batch_size
1517+
15031518
next_preset = tensordict.get("next", None)
15041519

15051520
next_tensordict = self._step(tensordict)
@@ -1512,6 +1527,10 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
15121527
next_preset.exclude(*next_tensordict.keys(True, True))
15131528
)
15141529
tensordict.set("next", next_tensordict)
1530+
if partial_steps is not None:
1531+
result = tensordict.new_zeros(tensordict_batch_size)
1532+
result[partial_steps] = tensordict
1533+
return result
15151534
return tensordict
15161535

15171536
@classmethod
@@ -2731,7 +2750,7 @@ def _rollout_stop_early(
27312750
if break_when_all_done:
27322751
if partial_steps is not True:
27332752
# At least one partial step has been done
2734-
del td_append["_partial_steps"]
2753+
del td_append["_step"]
27352754
td_append = torch.where(
27362755
partial_steps.view(td_append.shape), td_append, tensordicts[-1]
27372756
)
@@ -2757,17 +2776,17 @@ def _rollout_stop_early(
27572776
_terminated_or_truncated(
27582777
tensordict,
27592778
full_done_spec=self.output_spec["full_done_spec"],
2760-
key="_partial_steps",
2779+
key="_step",
27612780
write_full_false=False,
27622781
)
2763-
partial_step_curr = tensordict.get("_partial_steps", None)
2782+
partial_step_curr = tensordict.get("_step", None)
27642783
if partial_step_curr is not None:
27652784
partial_step_curr = ~partial_step_curr
27662785
partial_steps = partial_steps & partial_step_curr
27672786
if partial_steps is not True:
27682787
if not partial_steps.any():
27692788
break
2770-
tensordict.set("_partial_steps", partial_steps)
2789+
tensordict.set("_step", partial_steps)
27712790

27722791
if callback is not None:
27732792
callback(self, tensordict)

torchrl/envs/libs/jax_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,21 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase:
102102
return None
103103

104104

105-
def _tensordict_to_object(tensordict: TensorDictBase, object_example):
105+
def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None):
106106
"""Converts a TensorDict to a namedtuple or a dataclass."""
107107
from jax import dlpack as jax_dlpack, numpy as jnp
108108

109+
if batch_size is None:
110+
batch_size = []
109111
t = {}
110112
_fields = _get_object_fields(object_example)
111113
for name, example in _fields.items():
112114
value = tensordict.get(name, None)
113115
if isinstance(value, TensorDictBase):
114-
t[name] = _tensordict_to_object(value, example)
116+
t[name] = _tensordict_to_object(value, example, batch_size=batch_size)
115117
elif value is None:
116118
if isinstance(example, dict):
117-
t[name] = _tensordict_to_object({}, example)
119+
t[name] = _tensordict_to_object({}, example, batch_size=batch_size)
118120
else:
119121
t[name] = None
120122
else:
@@ -140,7 +142,9 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example):
140142
t[name] = value
141143
else:
142144
value = jnp.reshape(value, tuple(shape))
143-
t[name] = value.view(example.dtype).reshape(example.shape)
145+
t[name] = value.view(example.dtype).reshape(
146+
(*batch_size, *example.shape)
147+
)
144148
return type(object_example)(**t)
145149

146150

0 commit comments

Comments
 (0)