Skip to content

Commit ef1bf20

Browse files
[Feature] Added batch_lock attribute in EnvBase (#399)
1 parent ea66abf commit ef1bf20

File tree

7 files changed

+248
-8
lines changed

7 files changed

+248
-8
lines changed

test/mocking_classes.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,79 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa
126126
return self.step(tensordict)
127127

128128

129+
class MockBatchedLockedEnv(EnvBase):
130+
"""Mocks an env whose batch_size defines the size of the output tensordict"""
131+
132+
def __init__(self, device, batch_size=None):
133+
super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size)
134+
self.action_spec = NdUnboundedContinuousTensorSpec((1,))
135+
self.input_spec = CompositeSpec(
136+
action=NdUnboundedContinuousTensorSpec((1,)),
137+
observation=NdUnboundedContinuousTensorSpec((1,)),
138+
)
139+
self.observation_spec = CompositeSpec(
140+
next_observation=NdUnboundedContinuousTensorSpec((1,))
141+
)
142+
self.reward_spec = NdUnboundedContinuousTensorSpec((1,))
143+
self.counter = 0
144+
145+
set_seed = MockSerialEnv.set_seed
146+
rand_step = MockSerialEnv.rand_step
147+
148+
def _step(self, tensordict):
149+
self.counter += 1
150+
# We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
151+
n = (
152+
torch.full(tensordict.batch_size, self.counter)
153+
.to(self.device)
154+
.to(torch.get_default_dtype())
155+
)
156+
done = self.counter >= self.max_val
157+
done = torch.full(
158+
tensordict.batch_size, done, dtype=torch.bool, device=self.device
159+
)
160+
161+
return TensorDict(
162+
{"reward": n, "done": done, "next_observation": n}, tensordict.batch_size
163+
)
164+
165+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
166+
self.max_val = max(self.counter + 100, self.counter * 2)
167+
if tensordict is None:
168+
batch_size = self.batch_size
169+
else:
170+
batch_size = tensordict.batch_size
171+
172+
n = (
173+
torch.full(batch_size, self.counter)
174+
.to(self.device)
175+
.to(torch.get_default_dtype())
176+
)
177+
done = self.counter >= self.max_val
178+
done = torch.full(batch_size, done, dtype=torch.bool, device=self.device)
179+
180+
return TensorDict(
181+
{"reward": n, "done": done, "next_observation": n}, batch_size
182+
)
183+
184+
185+
class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
186+
"""Mocks an env whose batch_size does not define the size of the output tensordict.
187+
188+
The size of the output tensordict is defined by the input tensordict itself.
189+
190+
"""
191+
192+
def __init__(self, device, batch_size=None):
193+
super(MockBatchedUnLockedEnv, self).__init__(
194+
batch_size=batch_size, device=device
195+
)
196+
197+
@classmethod
198+
def __new__(cls, *args, **kwargs):
199+
return super().__new__(cls, *args, _batch_locked=False, **kwargs)
200+
201+
129202
class DiscreteActionVecMockEnv(_MockEnv):
130203
size = 7
131204
observation_spec = CompositeSpec(

test/test_env.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
DiscreteActionVecMockEnv,
1717
MockSerialEnv,
1818
DiscreteActionConvMockEnv,
19+
MockBatchedLockedEnv,
20+
MockBatchedUnLockedEnv,
1921
)
2022
from scipy.stats import chisquare
2123
from torch import nn
@@ -992,6 +994,58 @@ def test_steptensordict(
992994
assert out is next_tensordict
993995

994996

997+
@pytest.mark.parametrize("device", get_available_devices())
998+
def test_batch_locked(device):
999+
env = MockBatchedLockedEnv(device)
1000+
assert env.batch_locked
1001+
1002+
with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
1003+
env.batch_locked = False
1004+
td = env.reset()
1005+
td["action"] = env.action_spec.rand(env.batch_size)
1006+
td_expanded = td.expand(2).clone()
1007+
td = env.step(td)
1008+
1009+
with pytest.raises(
1010+
RuntimeError, match="Expected a tensordict with shape==env.shape, "
1011+
):
1012+
env.step(td_expanded)
1013+
1014+
1015+
@pytest.mark.parametrize("device", get_available_devices())
1016+
def test_batch_unlocked(device):
1017+
env = MockBatchedUnLockedEnv(device)
1018+
assert not env.batch_locked
1019+
1020+
with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
1021+
env.batch_locked = False
1022+
td = env.reset()
1023+
td["action"] = env.action_spec.rand(env.batch_size)
1024+
td_expanded = td.expand(2).clone()
1025+
td = env.step(td)
1026+
1027+
env.step(td_expanded)
1028+
1029+
1030+
@pytest.mark.parametrize("device", get_available_devices())
1031+
def test_batch_unlocked_with_batch_size(device):
1032+
env = MockBatchedUnLockedEnv(device, batch_size=torch.Size([2]))
1033+
assert not env.batch_locked
1034+
1035+
with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
1036+
env.batch_locked = False
1037+
1038+
td = env.reset()
1039+
td["action"] = env.action_spec.rand(env.batch_size)
1040+
td_expanded = td.expand(2, 2).reshape(-1).to_tensordict()
1041+
td = env.step(td)
1042+
1043+
with pytest.raises(
1044+
RuntimeError, match="Expected a tensordict with shape==env.shape, "
1045+
):
1046+
env.step(td_expanded)
1047+
1048+
9951049
if __name__ == "__main__":
9961050
args, unknown = argparse.ArgumentParser().parse_known_args()
9971051
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_transforms.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
import pytest
99
import torch
1010
from _utils_internal import get_available_devices
11-
from mocking_classes import ContinuousActionVecMockEnv, DiscreteActionConvMockEnvNumpy
11+
from mocking_classes import (
12+
ContinuousActionVecMockEnv,
13+
DiscreteActionConvMockEnvNumpy,
14+
MockBatchedLockedEnv,
15+
MockBatchedUnLockedEnv,
16+
)
1217
from torch import Tensor
1318
from torch import multiprocessing as mp
1419
from torchrl import prod
@@ -35,6 +40,7 @@
3540
RewardScaling,
3641
BinarizeReward,
3742
R3MTransform,
43+
RewardClipping,
3844
)
3945
from torchrl.envs.libs.gym import _has_gym, GymEnv
4046
from torchrl.envs.transforms import VecNorm, TransformedEnv
@@ -1365,6 +1371,75 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
13651371
assert set(expected_keys) == set(transformed_env.rollout(3).keys())
13661372

13671373

1374+
@pytest.mark.parametrize("device", get_available_devices())
1375+
def test_batch_locked_transformed(device):
1376+
env = TransformedEnv(
1377+
MockBatchedLockedEnv(device),
1378+
Compose(
1379+
ObservationNorm(keys_in=["next_observation"], loc=0.5, scale=1.1),
1380+
RewardClipping(0, 0.1),
1381+
),
1382+
)
1383+
assert env.batch_locked
1384+
1385+
with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
1386+
env.batch_locked = False
1387+
td = env.reset()
1388+
td["action"] = env.action_spec.rand(env.batch_size)
1389+
td_expanded = td.expand(2).clone()
1390+
td = env.step(td)
1391+
1392+
with pytest.raises(
1393+
RuntimeError, match="Expected a tensordict with shape==env.shape, "
1394+
):
1395+
env.step(td_expanded)
1396+
1397+
1398+
@pytest.mark.parametrize("device", get_available_devices())
1399+
def test_batch_unlocked_transformed(device):
1400+
env = TransformedEnv(
1401+
MockBatchedUnLockedEnv(device),
1402+
Compose(
1403+
ObservationNorm(keys_in=["next_observation"], loc=0.5, scale=1.1),
1404+
RewardClipping(0, 0.1),
1405+
),
1406+
)
1407+
assert not env.batch_locked
1408+
1409+
with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
1410+
env.batch_locked = False
1411+
td = env.reset()
1412+
td["action"] = env.action_spec.rand(env.batch_size)
1413+
td_expanded = td.expand(2).clone()
1414+
td = env.step(td)
1415+
env.step(td_expanded)
1416+
1417+
1418+
@pytest.mark.parametrize("device", get_available_devices())
1419+
def test_batch_unlocked_with_batch_size_transformed(device):
1420+
env = TransformedEnv(
1421+
MockBatchedUnLockedEnv(device, batch_size=torch.Size([2])),
1422+
Compose(
1423+
ObservationNorm(keys_in=["next_observation"], loc=0.5, scale=1.1),
1424+
RewardClipping(0, 0.1),
1425+
),
1426+
)
1427+
assert not env.batch_locked
1428+
1429+
with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
1430+
env.batch_locked = False
1431+
1432+
td = env.reset()
1433+
td["action"] = env.action_spec.rand(env.batch_size)
1434+
td_expanded = td.expand(2, 2).reshape(-1).to_tensordict()
1435+
td = env.step(td)
1436+
1437+
with pytest.raises(
1438+
RuntimeError, match="Expected a tensordict with shape==env.shape, "
1439+
):
1440+
env.step(td_expanded)
1441+
1442+
13681443
if __name__ == "__main__":
13691444
args, unknown = argparse.ArgumentParser().parse_known_args()
13701445
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/common.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ def __init__(
4949
batch_size: torch.Size,
5050
env_str: str,
5151
device: torch.device,
52+
batch_locked: bool = True,
5253
):
5354
self.tensordict = tensordict
5455
self.specs = specs
5556
self.batch_size = batch_size
5657
self.env_str = env_str
5758
self.device = device
59+
self.batch_locked = batch_locked
5860

5961
@staticmethod
6062
def build_metadata_from_env(env) -> EnvMetaData:
@@ -64,19 +66,27 @@ def build_metadata_from_env(env) -> EnvMetaData:
6466
batch_size = env.batch_size
6567
env_str = str(env)
6668
device = env.device
67-
return EnvMetaData(tensordict, specs, batch_size, env_str, device)
69+
batch_locked = env.batch_locked
70+
return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked)
6871

6972
def expand(self, *size: int) -> EnvMetaData:
7073
tensordict = self.tensordict.expand(*size).to_tensordict()
7174
batch_size = torch.Size([*size])
7275
return EnvMetaData(
73-
tensordict, self.specs, batch_size, self.env_str, self.device
76+
tensordict,
77+
self.specs,
78+
batch_size,
79+
self.env_str,
80+
self.device,
81+
self.batch_locked,
7482
)
7583

7684
def to(self, device: DEVICE_TYPING) -> EnvMetaData:
7785
tensordict = self.tensordict.to(device)
7886
specs = self.specs.to(device)
79-
return EnvMetaData(tensordict, specs, self.batch_size, self.env_str, device)
87+
return EnvMetaData(
88+
tensordict, specs, self.batch_size, self.env_str, device, self.batch_locked
89+
)
8090

8191
def __setstate__(self, state):
8292
state["tensordict"] = state["tensordict"].to_tensordict().to(state["device"])
@@ -218,10 +228,24 @@ def __init__(
218228
self.batch_size = torch.Size([])
219229

220230
@classmethod
221-
def __new__(cls, *args, **kwargs):
231+
def __new__(cls, *args, _batch_locked=True, **kwargs):
222232
cls._inplace_update = True
233+
cls._batch_locked = _batch_locked
223234
return super().__new__(cls)
224235

236+
@property
237+
def batch_locked(self) -> bool:
238+
"""
239+
Whether the environnement can be used with a batch size different from the one it was initialized with or not.
240+
If True, the env needs to be used with a tensordict having the same batch size as the env.
241+
batch_locked is an immutable property.
242+
"""
243+
return self._batch_locked
244+
245+
@batch_locked.setter
246+
def batch_locked(self, value: bool) -> None:
247+
raise RuntimeError("batch_locked is a read-only property")
248+
225249
@property
226250
def action_spec(self) -> TensorSpec:
227251
return self._action_spec
@@ -272,6 +296,8 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
272296
"""
273297

274298
# sanity check
299+
self._assert_tensordict_shape(tensordict)
300+
275301
if tensordict.get("action").dtype is not self.action_spec.dtype:
276302
raise TypeError(
277303
f"expected action.dtype to be {self.action_spec.dtype} "
@@ -408,7 +434,9 @@ def set_state(self):
408434
raise NotImplementedError
409435

410436
def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
411-
if tensordict.batch_size != self.batch_size:
437+
if tensordict.batch_size != self.batch_size and (
438+
self.batch_locked or self.batch_size != torch.Size([])
439+
):
412440
raise RuntimeError(
413441
f"Expected a tensordict with shape==env.shape, "
414442
f"got {tensordict.batch_size} and {self.batch_size}"
@@ -531,7 +559,9 @@ def policy(td):
531559
else:
532560
raise Exception("reset env before calling rollout!")
533561

534-
out_td = torch.stack(tensordicts, len(self.batch_size))
562+
batch_size = self.batch_size if tensordict is None else tensordict.batch_size
563+
564+
out_td = torch.stack(tensordicts, len(batch_size))
535565
if return_contiguous:
536566
return out_td.contiguous()
537567
return out_td

torchrl/envs/gym_like.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class GymLikeEnv(_EnvWrapper):
7878
@classmethod
7979
def __new__(cls, *args, **kwargs):
8080
cls._info_dict_reader = None
81-
return super().__new__(cls, *args, **kwargs)
81+
return super().__new__(cls, *args, _batch_locked=True, **kwargs)
8282

8383
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
8484
action = tensordict.get("action")

torchrl/envs/transforms/transforms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def __init__(
312312
device = kwargs["device"]
313313
super().__init__(**kwargs)
314314
self._set_env(env, device)
315+
self._inplace_update = env._inplace_update
315316
if transform is None:
316317
transform = Compose()
317318
transform.set_parent(self)
@@ -328,6 +329,11 @@ def __init__(
328329
self._observation_spec = None
329330
self.batch_size = self.base_env.batch_size
330331

332+
def __new__(cls, env, *args, **kwargs):
333+
return super().__new__(
334+
cls, env, *args, _batch_locked=env.batch_locked, **kwargs
335+
)
336+
331337
def _set_env(self, env: EnvBase, device) -> None:
332338
self.base_env = env.to(device)
333339
# updates need not be inplace, as transforms may modify values out-place

torchrl/envs/vec_env.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def _set_properties(self):
280280
self._dummy_env_str = self.meta_data.env_str
281281
self._device = self.meta_data.device
282282
self._env_tensordict = self.meta_data.tensordict
283+
self._batch_locked = self.meta_data.batch_locked
283284
else:
284285
self._batch_size = torch.Size(
285286
[self.num_workers, *self.meta_data[0].batch_size]
@@ -300,6 +301,7 @@ def _set_properties(self):
300301
self._env_tensordict = torch.stack(
301302
[meta_data.tensordict for meta_data in self.meta_data], 0
302303
)
304+
self._batch_locked = self.meta_data[0].batch_locked
303305

304306
def state_dict(self) -> OrderedDict:
305307
raise NotImplementedError

0 commit comments

Comments
 (0)