
Commit e8f54eb

[Feature] VIPRewardTransform (#658)
1 parent 213ae5b commit e8f54eb

File tree: 3 files changed (+114, -2 lines)

test/test_transforms.py

Lines changed: 66 additions & 1 deletion
@@ -57,7 +57,7 @@
     TensorDictPrimer,
     UnsqueezeTransform,
 )
-from torchrl.envs.transforms.vip import _VIPNet
+from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform

 if _has_gym:
     import gym
@@ -1687,6 +1687,71 @@ def test_vip_parallel(self, model, device):
         transformed_env.close()
         del transformed_env

+    def test_vip_parallel_reward(self, model, device):
+        keys_in = ["next_pixels"]
+        keys_out = ["next_vec"]
+        tensor_pixels_key = None
+        vip = VIPRewardTransform(
+            model,
+            keys_in=keys_in,
+            keys_out=keys_out,
+            tensor_pixels_keys=tensor_pixels_key,
+        )
+        base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device))
+        transformed_env = TransformedEnv(base_env, vip)
+        tensordict_reset = TensorDict(
+            {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)},
+            [4],
+            device=device,
+        )
+        with pytest.raises(
+            KeyError,
+            match=r"VIPRewardTransform.* requires .* key to be present in the input tensordict",
+        ):
+            _ = transformed_env.reset()
+        with pytest.raises(
+            KeyError,
+            match=r"VIPRewardTransform.* requires .* key to be present in the input tensordict",
+        ):
+            _ = transformed_env.reset(tensordict_reset.select())
+
+        td = transformed_env.reset(tensordict_reset)
+        assert td.device == device
+        assert td.batch_size == torch.Size([4])
+        exp_keys = {"vec", "done", "pixels_orig", "goal_embedding", "goal_image"}
+        if tensor_pixels_key:
+            exp_keys.add(tensor_pixels_key)
+        assert set(td.keys()) == exp_keys
+
+        td = transformed_env.rand_step(td)
+        exp_keys = exp_keys.union({"next_vec", "next_pixels_orig", "action", "reward"})
+        assert set(td.keys()) == exp_keys, td
+
+        tensordict_reset = TensorDict(
+            {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)},
+            [4],
+            device=device,
+        )
+        td = transformed_env.rollout(
+            3, auto_reset=False, tensordict=transformed_env.reset(tensordict_reset)
+        )
+        assert set(td.keys()) == exp_keys, td
+        # test that we do compute the reward we want
+        cur_embedding = td["next_vec"]
+        goal_embedding = td["goal_embedding"]
+        last_embedding = td["vec"]
+        explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
+            -torch.norm(last_embedding - goal_embedding, dim=-1)
+        )
+        torch.testing.assert_close(explicit_reward, td["reward"].squeeze())
+        # test that there is only one goal embedding
+        goal = td["goal_embedding"]
+        goal_expand = td["goal_embedding"][:, :1].expand_as(td["goal_embedding"])
+        torch.testing.assert_close(goal, goal_expand)
+
+        transformed_env.close()
+        del transformed_env
+
     @pytest.mark.parametrize("del_keys", [True, False])
     @pytest.mark.parametrize(
         "in_keys",

torchrl/envs/transforms/transforms.py

Lines changed: 5 additions & 1 deletion
@@ -431,7 +431,11 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
         return self.base_env.set_seed(seed, static_seed=static_seed)

     def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
-        out_tensordict = self.base_env.reset(execute_step=False, **kwargs)
+        if tensordict is not None:
+            tensordict = tensordict.clone(recurse=False)
+        out_tensordict = self.base_env.reset(
+            tensordict=tensordict, execute_step=False, **kwargs
+        )
         out_tensordict = self.transform.reset(out_tensordict)
         out_tensordict = self.transform(out_tensordict)
         return out_tensordict
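
Context for this change: before the patch, TransformedEnv._reset dropped any tensordict handed to reset(), so keys like "goal_image" could never reach the transform's reset(). The shallow clone deserves a note; here is a hedged micro-demo of the clone(recurse=False) semantics the patch relies on (values are made up):

import torch
from torchrl.data.tensordict.tensordict import TensorDict

# clone(recurse=False) copies the key -> tensor mapping without copying the
# tensors, so keys added downstream (e.g. "pixels", "done" from the base env)
# do not leak back into the caller's tensordict, and large image tensors are
# not deep-copied.
caller_td = TensorDict({"goal_image": torch.zeros(7, 7, 3)}, [])
shallow = caller_td.clone(recurse=False)
shallow.set("done", torch.zeros(1, dtype=torch.bool))
assert "done" not in caller_td.keys()  # caller's tensordict is untouched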

torchrl/envs/transforms/vip.py

Lines changed: 43 additions & 0 deletions
@@ -14,6 +14,7 @@
     CompositeSpec,
     NdUnboundedContinuousTensorSpec,
 )
+from torchrl.data.tensordict.tensordict import TensorDictBase
 from torchrl.envs.transforms import (
     ToTensorImage,
     Compose,
@@ -306,3 +307,45 @@ def dtype(self):
     transform_reward_spec = _init_first(Compose.transform_reward_spec)
     reset = _init_first(Compose.reset)
     init = _init_first(Compose.init)
+
+
+class VIPRewardTransform(VIPTransform):
+    """A VIP transform to compute rewards based on embedded similarity.
+
+    This class will update the reward computation based on the distance between current and goal embeddings.
+    """
+
+    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if "goal_embedding" not in tensordict.keys():
+            tensordict = self._embed_goal(tensordict)
+        return super().reset(tensordict)
+
+    def _embed_goal(self, tensordict):
+        if "goal_image" not in tensordict.keys():
+            raise KeyError(
+                f"{self.__class__.__name__}.reset() requires a `'goal_image'` key to be "
+                f"present in the input tensordict."
+            )
+        tensordict_in = tensordict.select("goal_image").rename_key(
+            "goal_image", self.keys_in[0]
+        )
+        tensordict_in = super(VIPRewardTransform, self).forward(tensordict_in)
+        tensordict = tensordict.update(
+            tensordict_in.rename_key(self.keys_out[0], "goal_embedding")
+        )
+        return tensordict
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if "goal_embedding" not in tensordict.keys():
+            tensordict = self._embed_goal(tensordict)
+        tensordict = super().forward(tensordict)
+        cur_embedding = tensordict.get(self.keys_out[0])
+        last_embedding_key = self.keys_out[0].split("next_")[1]
+        last_embedding = tensordict.get(last_embedding_key, None)
+        if last_embedding is not None:
+            goal_embedding = tensordict["goal_embedding"]
+            reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
+                -torch.norm(last_embedding - goal_embedding, dim=-1)
+            )
+            tensordict.set("reward", reward)
+        return tensordict
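
The reward rule in forward() reads most naturally as the decrease in embedding distance to the goal: reward = ||e_last - e_goal|| - ||e_cur - e_goal||, positive whenever a step moves the embedding closer to the goal. A tiny numeric check of the exact expression from the diff (the 2-D "embeddings" are made up):

import torch

goal = torch.tensor([1.0, 0.0])
last = torch.tensor([0.0, 0.0])  # ||last - goal|| = 1.0
cur = torch.tensor([0.5, 0.0])   # ||cur - goal||  = 0.5

# Same expression as in VIPRewardTransform.forward():
reward = -torch.norm(cur - goal, dim=-1) - (-torch.norm(last - goal, dim=-1))
print(reward)  # tensor(0.5000): positive, since the step halved the distance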
