|
16 | 16 | NdBoundedTensorSpec,
|
17 | 17 | CompositeSpec,
|
18 | 18 | UnboundedContinuousTensorSpec,
|
| 19 | + NdUnboundedContinuousTensorSpec, |
19 | 20 | )
|
20 | 21 | from torchrl.data import TensorDict
|
21 | 22 | from torchrl.envs import EnvCreator, SerialEnv
|
|
37 | 38 | )
|
38 | 39 | from torchrl.envs.libs.gym import _has_gym, GymEnv
|
39 | 40 | from torchrl.envs.transforms import VecNorm, TransformedEnv
|
| 41 | +from torchrl.envs.transforms.r3m import _R3MNet |
40 | 42 | from torchrl.envs.transforms.transforms import (
|
41 | 43 | _has_tv,
|
42 | 44 | NoopResetEnv,
|
@@ -1293,6 +1295,75 @@ def test_r3m_parallel(self, model, device):
|
1293 | 1295 | transformed_env.close()
|
1294 | 1296 | del transformed_env
|
1295 | 1297 |
|
@pytest.mark.parametrize("del_keys", [True, False])
@pytest.mark.parametrize(
    "in_keys",
    [["next_pixels"], ["next_pixels_1", "next_pixels_2", "next_pixels_3"]],
)
@pytest.mark.parametrize(
    "out_keys",
    [["next_r3m_vec"], ["next_r3m_vec_1", "next_r3m_vec_2", "next_r3m_vec_3"]],
)
def test_r3mnet_transform_observation_spec(
    self, in_keys, out_keys, del_keys, device, model
):
    """Check the spec returned by ``_R3MNet.transform_observation_spec``.

    Every key in ``out_keys`` must match an
    ``NdUnboundedContinuousTensorSpec`` built from ``r3m_net.outdim`` in
    shape, dtype and device. When ``del_keys`` is true the ``in_keys`` must
    be absent from the returned spec; otherwise they must match the input
    spec entries.
    """
    r3m_net = _R3MNet(in_keys, out_keys, model, del_keys)

    observation_spec = CompositeSpec(
        **{key: NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys}
    )

    # Build the expectation before calling the transform, so that it is
    # captured from the spec as it was on input.
    if del_keys:
        expected_entries = {}
    else:
        expected_entries = {key: observation_spec[key] for key in in_keys}
    for key in out_keys:
        expected_entries[key] = NdUnboundedContinuousTensorSpec(
            r3m_net.outdim, device
        )
    exp_ts = CompositeSpec(**expected_entries)

    observation_spec_out = r3m_net.transform_observation_spec(observation_spec)

    if del_keys:
        for key in in_keys:
            assert key not in observation_spec_out
        compared_keys = out_keys
    else:
        compared_keys = in_keys + out_keys

    for key in compared_keys:
        produced = observation_spec_out[key]
        wanted = exp_ts[key]
        assert produced.shape == wanted.shape
        assert produced.dtype == wanted.dtype
        assert produced.device == wanted.device
| 1345 | + |
@pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]])
def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
    """Check that a rollout's keys match those predicted from the env specs.

    The expected key set is the union of the transformed env's input-spec
    keys, its observation-spec keys, the observation-spec keys with a
    leading ``"next_"`` removed, plus ``"reward"`` and ``"done"``. It is
    compared against the keys of a 3-step rollout.
    """
    keys_in = ["next_pixels"]
    keys_out = ["next_vec"]
    r3m = R3MTransform(
        model,
        keys_in=keys_in,
        keys_out=keys_out,
        tensor_pixels_keys=tensor_pixels_key,
    )
    base_env = DiscreteActionConvMockEnvNumpy().to(device)
    transformed_env = TransformedEnv(base_env, r3m)

    prefix = "next_"
    observation_keys = list(transformed_env.observation_spec.keys())
    # Remove the literal prefix only. ``str.strip("next_")`` would treat its
    # argument as a set of characters and trim them from both ends, which
    # mangles keys such as "next_observation" into "observatio".
    unprefixed_keys = [
        key[len(prefix):] if key.startswith(prefix) else key
        for key in observation_keys
    ]
    expected_keys = (
        list(transformed_env.input_spec.keys())
        + observation_keys
        + unprefixed_keys
        + ["reward"]
        + ["done"]
    )
    assert set(expected_keys) == set(transformed_env.rollout(3).keys())
| 1366 | + |
1296 | 1367 |
|
1297 | 1368 | if __name__ == "__main__":
|
1298 | 1369 | args, unknown = argparse.ArgumentParser().parse_known_args()
|
|
0 commit comments