@@ -57,7 +57,7 @@
     TensorDictPrimer,
     UnsqueezeTransform,
 )
-from torchrl.envs.transforms.vip import _VIPNet
+from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 
 if _has_gym:
     import gym
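For orientation, here is a minimal usage sketch of the newly imported `VIPRewardTransform`, mirroring the test added in the hunk below. Everything not shown in this diff is an assumption: the `make_env` factory is hypothetical, `"resnet50"` is one plausible model name, and the import paths reflect one possible package layout.

```python
import torch
from tensordict import TensorDict  # assumption: standalone tensordict package
from torchrl.envs import ParallelEnv, TransformedEnv
from torchrl.envs.transforms.vip import VIPRewardTransform

make_env = ...  # hypothetical factory returning a pixel-based env

# Embed pixel observations with a VIP model; the transform then derives the
# reward from the distance between those embeddings and a goal embedding.
vip = VIPRewardTransform("resnet50", keys_in=["next_pixels"], keys_out=["next_vec"])
env = TransformedEnv(ParallelEnv(4, make_env), vip)

# A "goal_image" entry must be supplied at reset; the test below checks that
# omitting it raises a KeyError.
reset_td = TensorDict(
    {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)},
    [4],
)
td = env.reset(reset_td)  # now carries "vec" and "goal_embedding" entries
```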
@@ -1687,6 +1687,71 @@ def test_vip_parallel(self, model, device):
         transformed_env.close()
         del transformed_env
 
+    def test_vip_parallel_reward(self, model, device):
+        keys_in = ["next_pixels"]
+        keys_out = ["next_vec"]
+        tensor_pixels_key = None
+        vip = VIPRewardTransform(
+            model,
+            keys_in=keys_in,
+            keys_out=keys_out,
+            tensor_pixels_keys=tensor_pixels_key,
+        )
+        base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device))
+        transformed_env = TransformedEnv(base_env, vip)
+        tensordict_reset = TensorDict(
+            {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)},
+            [4],
+            device=device,
+        )
+        with pytest.raises(
+            KeyError,
+            match=r"VIPRewardTransform.* requires .* key to be present in the input tensordict",
+        ):
+            _ = transformed_env.reset()
+        with pytest.raises(
+            KeyError,
+            match=r"VIPRewardTransform.* requires .* key to be present in the input tensordict",
+        ):
+            _ = transformed_env.reset(tensordict_reset.select())
+
+        td = transformed_env.reset(tensordict_reset)
+        assert td.device == device
+        assert td.batch_size == torch.Size([4])
+        exp_keys = {"vec", "done", "pixels_orig", "goal_embedding", "goal_image"}
+        if tensor_pixels_key:
+            exp_keys.add(tensor_pixels_key)
+        assert set(td.keys()) == exp_keys
+
+        td = transformed_env.rand_step(td)
+        exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"})
+        assert set(td.keys()) == exp_keys, td
+
+        tensordict_reset = TensorDict(
+            {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)},
+            [4],
+            device=device,
+        )
+        td = transformed_env.rollout(
+            3, auto_reset=False, tensordict=transformed_env.reset(tensordict_reset)
+        )
+        assert set(td.keys()) == exp_keys, td
+        # test that we do compute the reward we want
+        cur_embedding = td["next_vec"]
+        goal_embedding = td["goal_embedding"]
+        last_embedding = td["vec"]
+        explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
+            -torch.norm(last_embedding - goal_embedding, dim=-1)
+        )
+        torch.testing.assert_close(explicit_reward, td["reward"].squeeze())
+        # test that there is only one goal embedding
+        goal = td["goal_embedding"]
+        goal_expand = td["goal_embedding"][:, :1].expand_as(td["goal_embedding"])
+        torch.testing.assert_close(goal, goal_expand)
+
+        transformed_env.close()
+        del transformed_env
+
     @pytest.mark.parametrize("del_keys", [True, False])
     @pytest.mark.parametrize(
         "in_keys",
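The core assertion of the new test is that the transform's reward is potential-based: with phi the VIP embedding and g the goal image, each step is rewarded by the increase in negative L2 distance to the goal. A minimal sketch of that computation, factored out of the test's `explicit_reward` lines (the `vip_reward` helper name is illustrative, not torchrl API):

```python
import torch

def vip_reward(cur_embedding, last_embedding, goal_embedding):
    # Negative L2 distance to the goal embedding acts as a potential; the
    # reward is the change in potential across one step:
    #   r_t = -||phi(o_{t+1}) - phi(g)|| + ||phi(o_t) - phi(g)||
    cur_potential = -torch.norm(cur_embedding - goal_embedding, dim=-1)
    last_potential = -torch.norm(last_embedding - goal_embedding, dim=-1)
    return cur_potential - last_potential
```

The reward is positive exactly when a step moves the embedding closer to the goal, which is why the test compares it elementwise against `td["reward"].squeeze()`.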