
Commit f4709c1

Author: Vincent Moens
[BugFix] Compatibility of tensordict primers with batched envs (specifically for LSTM and GRU)
ghstack-source-id: e1da58e
Pull Request resolved: #2668
1 parent 133d709 commit f4709c1

File tree

6 files changed (+191, -48 lines)

sota-implementations/decision_transformer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
     )
 
     # copy action from the input tensordict to the output
-    transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec))
+    transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))
 
     transformed_env.append_transform(DoubleToFloat())
     obsnorm = ObservationNorm(
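
The primer now receives the env's full Composite action spec (base_env.full_action_spec) rather than a single leaf spec hard-coded under the "action" key. A minimal sketch of the new pattern, assuming Gym and the Pendulum-v1 env are available (any TorchRL env would do):

    from torchrl.envs import GymEnv, TransformedEnv
    from torchrl.envs.transforms import TensorDictPrimer

    base_env = GymEnv("Pendulum-v1")
    transformed_env = TransformedEnv(base_env)
    # full_action_spec is a Composite, so the primer inherits the spec's own
    # key structure instead of assuming a single "action" entry.
    transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))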

test/test_tensordictmodules.py

Lines changed: 80 additions & 25 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import functools
 import os
 
 import pytest
@@ -12,6 +13,7 @@
 import torchrl.modules
 from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
 from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
+from tensordict.utils import assert_close
 from torch import nn
 from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
 from torchrl.envs import (
@@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based):
     @pytest.mark.parametrize("python_based", [True, False])
     @pytest.mark.parametrize("parallel", [True, False])
     @pytest.mark.parametrize("heterogeneous", [True, False])
-    def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
+    @pytest.mark.parametrize("within", [False, True])
+    def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within):
         from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv
 
         torch.manual_seed(0)
+        num_envs = 3
         device = "cuda" if torch.cuda.device_count() else "cpu"
         # tests that hidden states are carried over with parallel envs
         lstm_module = LSTMModule(
@@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
         else:
             cls = SerialEnv
 
-        def create_transformed_env():
-            primer = lstm_module.make_tensordict_primer()
-            env = DiscreteActionVecMockEnv(
-                categorical_action_encoding=True, device=device
+        if within:
+
+            def create_transformed_env():
+                primer = lstm_module.make_tensordict_primer()
+                env = DiscreteActionVecMockEnv(
+                    categorical_action_encoding=True, device=device
+                )
+                env = TransformedEnv(env)
+                env.append_transform(InitTracker())
+                env.append_transform(primer)
+                return env
+
+        else:
+            create_transformed_env = functools.partial(
+                DiscreteActionVecMockEnv,
+                categorical_action_encoding=True,
+                device=device,
             )
-            env = TransformedEnv(env)
-            env.append_transform(InitTracker())
-            env.append_transform(primer)
-            return env
 
         if heterogeneous:
             create_transformed_env = [
-                EnvCreator(create_transformed_env),
-                EnvCreator(create_transformed_env),
+                EnvCreator(create_transformed_env) for _ in range(num_envs)
             ]
         env = cls(
             create_env_fn=create_transformed_env,
-            num_workers=2,
+            num_workers=num_envs,
         )
+        if not within:
+            env = env.append_transform(InitTracker())
+            env.append_transform(lstm_module.make_tensordict_primer())
 
         mlp = TensorDictModule(
             MLP(
@@ -1002,6 +1017,19 @@ def create_transformed_env():
         data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
         assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
         assert (data.get("recurrent_state_c") != 0.0).any()
+        return data
+
+    @pytest.mark.parametrize("python_based", [True, False])
+    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize("heterogeneous", [True, False])
+    def test_lstm_parallel_within(self, python_based, parallel, heterogeneous):
+        out_within = self.test_lstm_parallel_env(
+            python_based, parallel, heterogeneous, within=True
+        )
+        out_not_within = self.test_lstm_parallel_env(
+            python_based, parallel, heterogeneous, within=False
+        )
+        assert_close(out_within, out_not_within)
 
     @pytest.mark.skipif(
         not _has_functorch, reason="vmap can only be used with functorch"
@@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based):
     @pytest.mark.parametrize("python_based", [True, False])
     @pytest.mark.parametrize("parallel", [True, False])
     @pytest.mark.parametrize("heterogeneous", [True, False])
-    def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
+    @pytest.mark.parametrize("within", [False, True])
+    def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within):
         from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv
 
         torch.manual_seed(0)
+        num_workers = 3
 
         device = "cuda" if torch.cuda.device_count() else "cpu"
         # tests that hidden states are carried over with parallel envs
@@ -1347,30 +1377,42 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
             python_based=python_based,
         )
 
-        def create_transformed_env():
-            primer = gru_module.make_tensordict_primer()
-            env = DiscreteActionVecMockEnv(
-                categorical_action_encoding=True, device=device
+        if within:
+
+            def create_transformed_env():
+                primer = gru_module.make_tensordict_primer()
+                env = DiscreteActionVecMockEnv(
+                    categorical_action_encoding=True, device=device
+                )
+                env = TransformedEnv(env)
+                env.append_transform(InitTracker())
+                env.append_transform(primer)
+                return env
+
+        else:
+            create_transformed_env = functools.partial(
+                DiscreteActionVecMockEnv,
+                categorical_action_encoding=True,
+                device=device,
             )
-            env = TransformedEnv(env)
-            env.append_transform(InitTracker())
-            env.append_transform(primer)
-            return env
 
         if parallel:
             cls = ParallelEnv
         else:
             cls = SerialEnv
         if heterogeneous:
             create_transformed_env = [
-                EnvCreator(create_transformed_env),
-                EnvCreator(create_transformed_env),
+                EnvCreator(create_transformed_env) for _ in range(num_workers)
             ]
 
-        env = cls(
+        env: ParallelEnv | SerialEnv = cls(
             create_env_fn=create_transformed_env,
-            num_workers=2,
+            num_workers=num_workers,
         )
+        if not within:
+            primer = gru_module.make_tensordict_primer()
+            env = env.append_transform(InitTracker())
+            env.append_transform(primer)
 
         mlp = TensorDictModule(
             MLP(
@@ -1396,6 +1438,19 @@ def create_transformed_env():
         data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
         assert (data.get("recurrent_state") != 0.0).any()
         assert (data.get(("next", "recurrent_state")) != 0.0).all()
+        return data
+
+    @pytest.mark.parametrize("python_based", [True, False])
+    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize("heterogeneous", [True, False])
+    def test_gru_parallel_within(self, python_based, parallel, heterogeneous):
+        out_within = self.test_gru_parallel_env(
+            python_based, parallel, heterogeneous, within=True
+        )
+        out_not_within = self.test_gru_parallel_env(
+            python_based, parallel, heterogeneous, within=False
+        )
+        assert_close(out_within, out_not_within)
 
     @pytest.mark.skipif(
         not _has_functorch, reason="vmap can only be used with functorch"
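
The new `within` parametrization exercises both placements of the recurrent primer: built into each worker env, or appended once to the batched env; the companion `test_*_parallel_within` tests assert that both placements produce identical rollouts. A hedged sketch of the second placement, assuming Gym's Pendulum-v1 in place of the mock env used by the tests:

    import functools

    from torchrl.envs import GymEnv, InitTracker, ParallelEnv
    from torchrl.modules import LSTMModule

    lstm_module = LSTMModule(
        input_size=3, hidden_size=8, in_key="observation", out_key="embed"
    )
    make_env = functools.partial(GymEnv, "Pendulum-v1")
    env = ParallelEnv(3, make_env)
    # Transforms are applied once, to the batched env, not inside each worker.
    env = env.append_transform(InitTracker())
    env.append_transform(lstm_module.make_tensordict_primer())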

test/test_transforms.py

Lines changed: 36 additions & 9 deletions
@@ -7408,7 +7408,7 @@ def make_env():
     def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
         env = TransformedEnv(
             maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
-            TensorDictPrimer(mykey=Unbounded([2, 4])),
+            TensorDictPrimer(mykey=Unbounded([4])),
         )
         try:
             check_env_specs(env)
@@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
             pass
 
     @pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
-    def test_trans_serial_env_check(self, spec_shape):
-        env = TransformedEnv(
-            SerialEnv(2, ContinuousActionVecMockEnv),
-            TensorDictPrimer(mykey=Unbounded(spec_shape)),
-        )
+    @pytest.mark.parametrize("expand_specs", [True, False, None])
+    def test_trans_serial_env_check(self, spec_shape, expand_specs):
+        if expand_specs is None:
+            with pytest.warns(FutureWarning, match=""):
+                env = TransformedEnv(
+                    SerialEnv(2, ContinuousActionVecMockEnv),
+                    TensorDictPrimer(
+                        mykey=Unbounded(spec_shape), expand_specs=expand_specs
+                    ),
+                )
+                env.observation_spec
+        elif expand_specs is True:
+            shape = spec_shape[:-1]
+            env = TransformedEnv(
+                SerialEnv(2, ContinuousActionVecMockEnv),
+                TensorDictPrimer(
+                    Composite(mykey=Unbounded(spec_shape), shape=shape),
+                    expand_specs=expand_specs,
+                ),
+            )
+        else:
+            # If we don't expand, we can't use [4]
+            env = TransformedEnv(
+                SerialEnv(2, ContinuousActionVecMockEnv),
+                TensorDictPrimer(
+                    mykey=Unbounded(spec_shape), expand_specs=expand_specs
+                ),
+            )
+            if spec_shape == [4]:
+                with pytest.raises(ValueError):
+                    env.observation_spec
+                return
+
         check_env_specs(env)
         assert "mykey" in env.reset().keys()
         r = env.rollout(3)
@@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env):
         transform = KLRewardTransform(actor, out_keys=out_key)
         return Compose(
             TensorDictPrimer(
-                primers={
-                    "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
-                }
+                sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]),
+                shape=base_env.shape,
             ),
             transform,
         )
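
The new `expand_specs` argument controls what happens when the primer's spec shape does not match the batched env's batch size: `True` expands the spec, `False` requires the shape to be settable (and raises a ValueError otherwise), and `None` keeps the old try/expand behaviour with a FutureWarning. A hedged sketch of the `True` branch, assuming Gym's Pendulum-v1 in place of the mock env:

    from torchrl.data.tensor_specs import Composite, Unbounded
    from torchrl.envs import GymEnv, SerialEnv, TransformedEnv
    from torchrl.envs.transforms import TensorDictPrimer

    env = TransformedEnv(
        SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
        # The [4] spec carries no leading batch dim, so expand_specs=True
        # expands it to [2, 4] to match the SerialEnv(2) batch size.
        TensorDictPrimer(
            Composite(mykey=Unbounded([4])),
            expand_specs=True,
        ),
    )
    env.rollout(3)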

torchrl/envs/batched_envs.py

Lines changed: 32 additions & 4 deletions
@@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
             # We keep track of which keys are present to let the worker know what
             # should be passed to the env (we don't want to pass done states for instance)
             next_td_keys = list(next_td_passthrough.keys(True, True))
+            next_shared_tensordict_parent = shared_tensordict_parent.get("next")
+
+            # We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset.
+            # The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of
+            # the batched env but part of the specs of a transformed batched env.
+            # If that is the case, `update_` will fail to find the entries to update.
+            # What we do instead is keeping the tensors on the side and putting them back after completing _step.
+            keys_to_update, keys_to_copy = zip(
+                *[
+                    (key, None)
+                    if key in next_shared_tensordict_parent.keys(True, True)
+                    else (None, key)
+                    for key in next_td_keys
+                ]
+            )
+            keys_to_update = [key for key in keys_to_update if key is not None]
+            keys_to_copy = [key for key in keys_to_copy if key is not None]
             data = [
-                {"next_td_passthrough_keys": next_td_keys}
+                {"next_td_passthrough_keys": keys_to_update}
                 for _ in range(self.num_workers)
             ]
-            shared_tensordict_parent.get("next").update_(
-                next_td_passthrough, non_blocking=self.non_blocking
-            )
+            if keys_to_update:
+                next_shared_tensordict_parent.update_(
+                    next_td_passthrough,
+                    non_blocking=self.non_blocking,
+                    keys_to_update=keys_to_update,
+                )
+            if keys_to_copy:
+                next_td_passthrough = next_td_passthrough.select(*keys_to_copy)
+            else:
+                next_td_passthrough = None
         else:
+            next_td_passthrough = None
             data = [{} for _ in range(self.num_workers)]
 
         if self._non_tensor_keys:
@@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor):
                 LazyStackedTensorDict(*non_tensor_tds),
                 keys_to_update=self._non_tensor_keys,
             )
+        if next_td_passthrough is not None:
+            out.update(next_td_passthrough)
+
         self._sync_w2m()
         if partial_steps is not None:
             result = out.new_zeros(tensordict_save.shape)
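
The zip-based partition above splits the passthrough keys in a single pass: keys already present in the shared "next" buffer are updated in place for the workers, while the rest are set aside and merged back into the output once _step completes. A standalone sketch of the idiom, with illustrative key names only:

    # Keys assumed present in the shared buffer (illustrative only).
    buffer_keys = {"observation", "reward", "done"}
    incoming_keys = ["observation", "recurrent_state", "done"]

    # Pair each key with None on the side it does not belong to...
    keys_to_update, keys_to_copy = zip(
        *[
            (key, None) if key in buffer_keys else (None, key)
            for key in incoming_keys
        ]
    )
    # ...then drop the Nones from both sides.
    keys_to_update = [key for key in keys_to_update if key is not None]
    keys_to_copy = [key for key in keys_to_copy if key is not None]
    assert keys_to_update == ["observation", "done"]
    assert keys_to_copy == ["recurrent_state"]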

torchrl/envs/transforms/transforms.py

Lines changed: 30 additions & 7 deletions
@@ -4984,6 +4984,7 @@ def __init__(
         | Dict[NestedKey, float]
         | Dict[NestedKey, Callable] = None,
         reset_key: NestedKey | None = None,
+        expand_specs: bool = None,
         **kwargs,
     ):
         self.device = kwargs.pop("device", None)
@@ -4995,8 +4996,16 @@ def __init__(
             )
             kwargs = primers
         if not isinstance(kwargs, Composite):
-            kwargs = Composite(kwargs)
-        self.primers = kwargs
+            shape = kwargs.pop("shape", None)
+            device = kwargs.pop("device", None)
+            if "batch_size" in kwargs.keys():
+                extra_kwargs = {"batch_size": kwargs.pop("batch_size")}
+            else:
+                extra_kwargs = {}
+            primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs)
+        self.primers = primers
+        self.expand_specs = expand_specs
+
         if random and default_value:
             raise ValueError(
                 "Setting random to True and providing a default_value are incompatible."
@@ -5089,12 +5098,26 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             )
 
         if self.primers.shape != observation_spec.shape:
-            try:
-                # We try to set the primer shape to the observation spec shape
-                self.primers.shape = observation_spec.shape
-            except ValueError:
-                # If we fail, we expand them to that shape
+            if self.expand_specs:
                 self.primers = self._expand_shape(self.primers)
+            elif self.expand_specs is None:
+                warnings.warn(
+                    f"expand_specs wasn't specified in the {type(self).__name__} constructor. "
+                    f"The current behaviour is that the transform will attempt to set the shape of the composite "
+                    f"spec, and if this can't be done it will be expanded. "
+                    f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size "
+                    f"will raise an exception.",
+                    category=FutureWarning,
+                )
+                try:
+                    # We try to set the primer shape to the observation spec shape
+                    self.primers.shape = observation_spec.shape
+                except ValueError:
+                    # If we fail, we expand them to that shape
+                    self.primers = self._expand_shape(self.primers)
+            else:
+                self.primers.shape = observation_spec.shape
+
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
        return observation_spec
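
On the constructor side, the reserved keywords `shape` and `batch_size` are now routed to the Composite that wraps the keyword specs instead of being read as primer entries, which is what the updated KLRewardTransform test (shape=base_env.shape) relies on. A hedged sketch, assuming a (2,)-batched env:

    from torchrl.data.tensor_specs import Unbounded
    from torchrl.envs.transforms import TensorDictPrimer

    primer = TensorDictPrimer(
        sample_log_prob=Unbounded(shape=(2,)),
        shape=(2,),  # batch shape of the wrapping Composite, not a spec entry
    )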
