Skip to content

Commit 2127e60

Browse files
ymwdalex and Zhe Sun authored
[BugFix] Add transform_observation_spec _R3MNet (#443)
* Add transform_observation_spec function and related test cases for _R3MNet
* Improve unit tests: 1) check observation_spec from rollouts; 2) integrate transform_observation_spec tests inside TestR3M test suite; 3) fix some errors

Co-authored-by: Zhe Sun <zhesun@fb.com>
1 parent fd97614 commit 2127e60

File tree

3 files changed

+97
-0
lines changed

3 files changed

+97
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,6 @@ version.py
169169
torchrl/version.py
170170
wandb
171171
outputs
172+
173+
# PyCharm
174+
.idea

test/test_transforms.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NdBoundedTensorSpec,
1717
CompositeSpec,
1818
UnboundedContinuousTensorSpec,
19+
NdUnboundedContinuousTensorSpec,
1920
)
2021
from torchrl.data import TensorDict
2122
from torchrl.envs import EnvCreator, SerialEnv
@@ -37,6 +38,7 @@
3738
)
3839
from torchrl.envs.libs.gym import _has_gym, GymEnv
3940
from torchrl.envs.transforms import VecNorm, TransformedEnv
41+
from torchrl.envs.transforms.r3m import _R3MNet
4042
from torchrl.envs.transforms.transforms import (
4143
_has_tv,
4244
NoopResetEnv,
@@ -1293,6 +1295,75 @@ def test_r3m_parallel(self, model, device):
12931295
transformed_env.close()
12941296
del transformed_env
12951297

1298+
@pytest.mark.parametrize("del_keys", [True, False])
@pytest.mark.parametrize(
    "in_keys",
    [["next_pixels"], ["next_pixels_1", "next_pixels_2", "next_pixels_3"]],
)
@pytest.mark.parametrize(
    "out_keys",
    [["next_r3m_vec"], ["next_r3m_vec_1", "next_r3m_vec_2", "next_r3m_vec_3"]],
)
def test_r3mnet_transform_observation_spec(
    self, in_keys, out_keys, del_keys, device, model
):
    """Check the spec returned by ``_R3MNet.transform_observation_spec``.

    Every output key must carry the shape, dtype and device of an
    ``NdUnboundedContinuousTensorSpec`` of size ``r3m_net.outdim``. With
    ``del_keys=True`` the input keys must be absent from the result;
    otherwise they must still match the specs they started with.
    """
    r3m_net = _R3MNet(in_keys, out_keys, model, del_keys)

    pixel_specs = {}
    for name in in_keys:
        pixel_specs[name] = NdBoundedTensorSpec(-1, 1, (3, 16, 16), device)
    observation_spec = CompositeSpec(**pixel_specs)

    # Build the expected spec *before* calling the transform, as the
    # reference entries for the input keys are read from ``observation_spec``.
    expected_entries = {}
    if not del_keys:
        for name in in_keys:
            expected_entries[name] = observation_spec[name]
    for name in out_keys:
        expected_entries[name] = NdUnboundedContinuousTensorSpec(
            r3m_net.outdim, device
        )
    exp_ts = CompositeSpec(**expected_entries)

    observation_spec_out = r3m_net.transform_observation_spec(observation_spec)

    if del_keys:
        for name in in_keys:
            assert name not in observation_spec_out
        keys_to_compare = out_keys
    else:
        keys_to_compare = in_keys + out_keys

    for name in keys_to_compare:
        actual, reference = observation_spec_out[name], exp_ts[name]
        assert actual.shape == reference.shape
        assert actual.dtype == reference.dtype
        assert actual.device == reference.device
1345+
1346+
@pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]])
def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
    """Compare the keys announced by the env specs with a real rollout.

    The expected key set is the union of the input-spec keys, the
    observation-spec keys, the observation-spec keys with their ``"next_"``
    prefix removed, plus ``"reward"`` and ``"done"``. It must equal the key
    set of a 3-step rollout of the transformed environment.
    """
    keys_in = ["next_pixels"]
    keys_out = ["next_vec"]
    r3m = R3MTransform(
        model,
        keys_in=keys_in,
        keys_out=keys_out,
        tensor_pixels_keys=tensor_pixels_key,
    )
    base_env = DiscreteActionConvMockEnvNumpy().to(device)
    transformed_env = TransformedEnv(base_env, r3m)

    prefix = "next_"
    observation_keys = list(transformed_env.observation_spec.keys())
    # NOTE: ``str.strip("next_")`` removes any of the characters n/e/x/t/_ from
    # both ends (e.g. "next_observation" -> "observatio"); slice off the exact
    # prefix instead so arbitrary key names are handled correctly.
    unprefixed_keys = [
        key[len(prefix):] if key.startswith(prefix) else key
        for key in observation_keys
    ]
    expected_keys = (
        list(transformed_env.input_spec.keys())
        + observation_keys
        + unprefixed_keys
        + ["reward"]
        + ["done"]
    )
    assert set(expected_keys) == set(transformed_env.rollout(3).keys())
1366+
12961367

12971368
if __name__ == "__main__":
12981369
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/envs/transforms/r3m.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from torch.nn import Identity
66

77
from torchrl.data import TensorDict, DEVICE_TYPING
8+
from torchrl.data.tensor_specs import (
9+
TensorSpec,
10+
CompositeSpec,
11+
NdUnboundedContinuousTensorSpec,
12+
)
813
from torchrl.envs.transforms import (
914
ToTensorImage,
1015
Compose,
@@ -75,6 +80,24 @@ def _apply_transform(self, obs: torch.Tensor) -> None:
7580
out = out.view(*shape, *out.shape[1:])
7681
return out
7782

83+
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
    """Update ``observation_spec`` to describe the output of this module.

    For every key in ``self.keys_out`` an ``NdUnboundedContinuousTensorSpec``
    of shape ``[self.outdim]`` is registered, placed on the device of the
    first input key found in the spec. When ``self.del_keys`` is true, the
    matching input keys are removed.

    Args:
        observation_spec: the spec to transform; must be a ``CompositeSpec``.
            It is modified in place.

    Returns:
        The same ``observation_spec`` object, after modification.

    Raises:
        ValueError: if ``observation_spec`` is not a ``CompositeSpec``, or if
            none of ``self.keys_in`` is present in it.
    """
    if not isinstance(observation_spec, CompositeSpec):
        raise ValueError("_R3MNet can only infer CompositeSpec")

    keys = [key for key in observation_spec._specs.keys() if key in self.keys_in]
    if not keys:
        # Without this guard ``keys[0]`` below fails with an opaque IndexError.
        raise ValueError(
            f"_R3MNet could not find any of its input keys {list(self.keys_in)} "
            f"in the observation spec (available keys: "
            f"{list(observation_spec._specs.keys())})"
        )
    # The output specs inherit the device of the first matching input entry.
    device = observation_spec[keys[0]].device

    if self.del_keys:
        for key_in in keys:
            del observation_spec[key_in]

    for key_out in self.keys_out:
        observation_spec[key_out] = NdUnboundedContinuousTensorSpec(
            shape=torch.Size([self.outdim]), device=device
        )

    return observation_spec
100+
78101
@staticmethod
79102
def _load_weights(model_name, r3m_instance, dir_prefix):
80103
if model_name not in ("r3m_50", "r3m_34", "r3m_18"):

0 commit comments

Comments
 (0)