|
52 | 52 | FrameSkipTransform,
|
53 | 53 | GrayScale,
|
54 | 54 | gSDENoise,
|
| 55 | + InitTracker, |
55 | 56 | NoopResetEnv,
|
56 | 57 | ObservationNorm,
|
57 | 58 | ParallelEnv,
|
58 | 59 | PinMemoryTransform,
|
59 | 60 | R3MTransform,
|
60 | 61 | RandomCropTensorDict,
|
| 62 | + RenameTransform, |
61 | 63 | Resize,
|
62 | 64 | RewardClipping,
|
63 | 65 | RewardScaling,
|
|
76 | 78 | from torchrl.envs.libs.gym import _has_gym, GymEnv
|
77 | 79 | from torchrl.envs.transforms import VecNorm
|
78 | 80 | from torchrl.envs.transforms.r3m import _R3MNet
|
79 |
| -from torchrl.envs.transforms.transforms import _has_tv, InitTracker |
| 81 | +from torchrl.envs.transforms.transforms import _has_tv |
80 | 82 | from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
|
81 | 83 | from torchrl.envs.utils import check_env_specs, step_mdp
|
82 | 84 |
|
@@ -6315,6 +6317,292 @@ def test_crop_mask(self, mask_key):
|
6315 | 6317 | assert tensordict_crop[mask_key].all()
|
6316 | 6318 |
|
6317 | 6319 |
|
| 6320 | +@pytest.mark.parametrize("create_copy", [True, False]) |
| 6321 | +class TestRenameTransform(TransformBase): |
| 6322 | + def test_single_trans_env_check(self, create_copy): |
| 6323 | + env = TransformedEnv( |
| 6324 | + ContinuousActionVecMockEnv(), |
| 6325 | + RenameTransform( |
| 6326 | + [ |
| 6327 | + "observation", |
| 6328 | + ], |
| 6329 | + [ |
| 6330 | + "stuff", |
| 6331 | + ], |
| 6332 | + create_copy=create_copy, |
| 6333 | + ), |
| 6334 | + ) |
| 6335 | + check_env_specs(env) |
| 6336 | + env = TransformedEnv( |
| 6337 | + ContinuousActionVecMockEnv(), |
| 6338 | + RenameTransform( |
| 6339 | + ["observation_orig"], |
| 6340 | + ["stuff"], |
| 6341 | + ["observation_orig"], |
| 6342 | + [ |
| 6343 | + "stuff", |
| 6344 | + ], |
| 6345 | + create_copy=create_copy, |
| 6346 | + ), |
| 6347 | + ) |
| 6348 | + check_env_specs(env) |
| 6349 | + |
| 6350 | + def test_serial_trans_env_check(self, create_copy): |
| 6351 | + def make_env(): |
| 6352 | + return TransformedEnv( |
| 6353 | + ContinuousActionVecMockEnv(), |
| 6354 | + RenameTransform( |
| 6355 | + [ |
| 6356 | + "observation", |
| 6357 | + ], |
| 6358 | + [ |
| 6359 | + "stuff", |
| 6360 | + ], |
| 6361 | + create_copy=create_copy, |
| 6362 | + ), |
| 6363 | + ) |
| 6364 | + |
| 6365 | + env = SerialEnv(2, make_env) |
| 6366 | + check_env_specs(env) |
| 6367 | + |
| 6368 | + def make_env(): |
| 6369 | + return TransformedEnv( |
| 6370 | + ContinuousActionVecMockEnv(), |
| 6371 | + RenameTransform( |
| 6372 | + ["observation_orig"], |
| 6373 | + ["stuff"], |
| 6374 | + ["observation_orig"], |
| 6375 | + [ |
| 6376 | + "stuff", |
| 6377 | + ], |
| 6378 | + create_copy=create_copy, |
| 6379 | + ), |
| 6380 | + ) |
| 6381 | + |
| 6382 | + env = SerialEnv(2, make_env) |
| 6383 | + check_env_specs(env) |
| 6384 | + |
| 6385 | + def test_parallel_trans_env_check(self, create_copy): |
| 6386 | + def make_env(): |
| 6387 | + return TransformedEnv( |
| 6388 | + ContinuousActionVecMockEnv(), |
| 6389 | + RenameTransform( |
| 6390 | + [ |
| 6391 | + "observation", |
| 6392 | + ], |
| 6393 | + [ |
| 6394 | + "stuff", |
| 6395 | + ], |
| 6396 | + create_copy=create_copy, |
| 6397 | + ), |
| 6398 | + ) |
| 6399 | + |
| 6400 | + env = ParallelEnv(2, make_env) |
| 6401 | + check_env_specs(env) |
| 6402 | + |
| 6403 | + def make_env(): |
| 6404 | + return TransformedEnv( |
| 6405 | + ContinuousActionVecMockEnv(), |
| 6406 | + RenameTransform( |
| 6407 | + ["observation_orig"], |
| 6408 | + ["stuff"], |
| 6409 | + ["observation_orig"], |
| 6410 | + [ |
| 6411 | + "stuff", |
| 6412 | + ], |
| 6413 | + create_copy=create_copy, |
| 6414 | + ), |
| 6415 | + ) |
| 6416 | + |
| 6417 | + env = ParallelEnv(2, make_env) |
| 6418 | + check_env_specs(env) |
| 6419 | + |
| 6420 | + def test_trans_serial_env_check(self, create_copy): |
| 6421 | + def make_env(): |
| 6422 | + return ContinuousActionVecMockEnv() |
| 6423 | + |
| 6424 | + env = TransformedEnv( |
| 6425 | + SerialEnv(2, make_env), |
| 6426 | + RenameTransform( |
| 6427 | + [ |
| 6428 | + "observation", |
| 6429 | + ], |
| 6430 | + [ |
| 6431 | + "stuff", |
| 6432 | + ], |
| 6433 | + create_copy=create_copy, |
| 6434 | + ), |
| 6435 | + ) |
| 6436 | + check_env_specs(env) |
| 6437 | + env = TransformedEnv( |
| 6438 | + SerialEnv(2, make_env), |
| 6439 | + RenameTransform( |
| 6440 | + ["observation_orig"], |
| 6441 | + ["stuff"], |
| 6442 | + ["observation_orig"], |
| 6443 | + [ |
| 6444 | + "stuff", |
| 6445 | + ], |
| 6446 | + create_copy=create_copy, |
| 6447 | + ), |
| 6448 | + ) |
| 6449 | + check_env_specs(env) |
| 6450 | + |
| 6451 | + def test_trans_parallel_env_check(self, create_copy): |
| 6452 | + def make_env(): |
| 6453 | + return ContinuousActionVecMockEnv() |
| 6454 | + |
| 6455 | + env = TransformedEnv( |
| 6456 | + ParallelEnv(2, make_env), |
| 6457 | + RenameTransform( |
| 6458 | + [ |
| 6459 | + "observation", |
| 6460 | + ], |
| 6461 | + [ |
| 6462 | + "stuff", |
| 6463 | + ], |
| 6464 | + create_copy=create_copy, |
| 6465 | + ), |
| 6466 | + ) |
| 6467 | + check_env_specs(env) |
| 6468 | + env = TransformedEnv( |
| 6469 | + ParallelEnv(2, make_env), |
| 6470 | + RenameTransform( |
| 6471 | + ["observation_orig"], |
| 6472 | + ["stuff"], |
| 6473 | + ["observation_orig"], |
| 6474 | + [ |
| 6475 | + "stuff", |
| 6476 | + ], |
| 6477 | + create_copy=create_copy, |
| 6478 | + ), |
| 6479 | + ) |
| 6480 | + check_env_specs(env) |
| 6481 | + |
| 6482 | + @pytest.mark.parametrize("mode", ["forward", "_call"]) |
| 6483 | + def test_transform_no_env(self, create_copy, mode): |
| 6484 | + t = RenameTransform(["a"], ["b"], create_copy=create_copy) |
| 6485 | + tensordict = TensorDict({"a": torch.randn(())}, []) |
| 6486 | + if mode == "forward": |
| 6487 | + t(tensordict) |
| 6488 | + elif mode == "_call": |
| 6489 | + t._call(tensordict) |
| 6490 | + else: |
| 6491 | + raise NotImplementedError |
| 6492 | + assert "b" in tensordict.keys() |
| 6493 | + if create_copy: |
| 6494 | + assert "a" in tensordict.keys() |
| 6495 | + else: |
| 6496 | + assert "a" not in tensordict.keys() |
| 6497 | + |
| 6498 | + @pytest.mark.parametrize("mode", ["forward", "_call"]) |
| 6499 | + def test_transform_compose(self, create_copy, mode): |
| 6500 | + t = Compose(RenameTransform(["a"], ["b"], create_copy=create_copy)) |
| 6501 | + tensordict = TensorDict({"a": torch.randn(())}, []) |
| 6502 | + if mode == "forward": |
| 6503 | + t(tensordict) |
| 6504 | + elif mode == "_call": |
| 6505 | + t._call(tensordict) |
| 6506 | + else: |
| 6507 | + raise NotImplementedError |
| 6508 | + assert "b" in tensordict.keys() |
| 6509 | + if create_copy: |
| 6510 | + assert "a" in tensordict.keys() |
| 6511 | + else: |
| 6512 | + assert "a" not in tensordict.keys() |
| 6513 | + |
| 6514 | + def test_transform_env(self, create_copy): |
| 6515 | + env = TransformedEnv( |
| 6516 | + ContinuousActionVecMockEnv(), |
| 6517 | + RenameTransform( |
| 6518 | + [ |
| 6519 | + "observation", |
| 6520 | + ], |
| 6521 | + [ |
| 6522 | + "stuff", |
| 6523 | + ], |
| 6524 | + create_copy=create_copy, |
| 6525 | + ), |
| 6526 | + ) |
| 6527 | + r = env.rollout(3) |
| 6528 | + if create_copy: |
| 6529 | + assert "observation" in r.keys() |
| 6530 | + assert ("next", "observation") in r.keys(True) |
| 6531 | + else: |
| 6532 | + assert "observation" not in r.keys() |
| 6533 | + assert ("next", "observation") not in r.keys(True) |
| 6534 | + assert "stuff" in r.keys() |
| 6535 | + assert ("next", "stuff") in r.keys(True) |
| 6536 | + |
| 6537 | + env = TransformedEnv( |
| 6538 | + ContinuousActionVecMockEnv(), |
| 6539 | + RenameTransform( |
| 6540 | + ["observation_orig"], |
| 6541 | + ["stuff"], |
| 6542 | + ["observation_orig"], |
| 6543 | + [ |
| 6544 | + "stuff", |
| 6545 | + ], |
| 6546 | + create_copy=create_copy, |
| 6547 | + ), |
| 6548 | + ) |
| 6549 | + r = env.rollout(3) |
| 6550 | + if create_copy: |
| 6551 | + assert "observation_orig" in r.keys() |
| 6552 | + assert ("next", "observation_orig") in r.keys(True) |
| 6553 | + else: |
| 6554 | + assert "observation_orig" not in r.keys() |
| 6555 | + assert ("next", "observation_orig") not in r.keys(True) |
| 6556 | + assert "stuff" in r.keys() |
| 6557 | + assert ("next", "stuff") in r.keys(True) |
| 6558 | + |
| 6559 | + def test_transform_model(self, create_copy): |
| 6560 | + t = RenameTransform(["a"], ["b"], create_copy=create_copy) |
| 6561 | + tensordict = TensorDict({"a": torch.randn(())}, []) |
| 6562 | + model = nn.Sequential(t) |
| 6563 | + model(tensordict) |
| 6564 | + assert "b" in tensordict.keys() |
| 6565 | + if create_copy: |
| 6566 | + assert "a" in tensordict.keys() |
| 6567 | + else: |
| 6568 | + assert "a" not in tensordict.keys() |
| 6569 | + |
| 6570 | + @pytest.mark.parametrize( |
| 6571 | + "inverse", |
| 6572 | + [ |
| 6573 | + False, |
| 6574 | + True, |
| 6575 | + ], |
| 6576 | + ) |
| 6577 | + def test_transform_rb(self, create_copy, inverse): |
| 6578 | + if not inverse: |
| 6579 | + t = RenameTransform(["a"], ["b"], create_copy=create_copy) |
| 6580 | + tensordict = TensorDict({"a": torch.randn(())}, []).expand(10) |
| 6581 | + else: |
| 6582 | + t = RenameTransform(["a"], ["b"], ["a"], ["b"], create_copy=create_copy) |
| 6583 | + tensordict = TensorDict({"b": torch.randn(())}, []).expand(10) |
| 6584 | + rb = ReplayBuffer(LazyTensorStorage(20)) |
| 6585 | + rb.append_transform(t) |
| 6586 | + rb.extend(tensordict) |
| 6587 | + assert "a" in rb._storage._storage.keys() |
| 6588 | + sample = rb.sample(2) |
| 6589 | + if create_copy: |
| 6590 | + assert "a" in sample.keys() |
| 6591 | + else: |
| 6592 | + assert "a" not in sample.keys() |
| 6593 | + assert "b" in sample.keys() |
| 6594 | + |
| 6595 | + def test_transform_inverse(self, create_copy): |
| 6596 | + t = RenameTransform(["a"], ["b"], ["a"], ["b"], create_copy=create_copy) |
| 6597 | + tensordict = TensorDict({"b": torch.randn(())}, []).expand(10) |
| 6598 | + tensordict = t.inv(tensordict) |
| 6599 | + assert "a" in tensordict.keys() |
| 6600 | + if create_copy: |
| 6601 | + assert "b" in tensordict.keys() |
| 6602 | + else: |
| 6603 | + assert "b" not in tensordict.keys() |
| 6604 | + |
| 6605 | + |
6318 | 6606 | class TestInitTracker(TransformBase):
|
6319 | 6607 | def test_single_trans_env_check(self):
|
6320 | 6608 | env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])
|
|
0 commit comments