Skip to content

Commit 32c4623

Browse files
committed
[Feature] Enable Hash.inv
ghstack-source-id: 9567081 Pull Request resolved: #2757
1 parent 4c06ce2 commit 32c4623

File tree

2 files changed

+141
-61
lines changed

2 files changed

+141
-61
lines changed

test/test_transforms.py

Lines changed: 80 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from functools import partial
1717
from sys import platform
1818

19+
import numpy as np
20+
1921
import pytest
2022

2123
import tensordict.tensordict
@@ -2288,7 +2290,7 @@ class TestHash(TransformBase):
22882290
def test_transform_no_env(self, datatype):
22892291
if datatype == "tensor":
22902292
obs = torch.tensor(10)
2291-
hash_fn = hash
2293+
hash_fn = lambda x: torch.tensor(hash(x))
22922294
elif datatype == "str":
22932295
obs = "abcdefg"
22942296
hash_fn = Hash.reproducible_hash
@@ -2302,6 +2304,7 @@ def test_transform_no_env(self, datatype):
23022304
)
23032305

23042306
def fn0(x):
2307+
# return tuple([tuple(Hash.reproducible_hash(x_).tolist()) for x_ in x])
23052308
return torch.stack([Hash.reproducible_hash(x_) for x_ in x])
23062309

23072310
hash_fn = fn0
@@ -2334,7 +2337,7 @@ def test_single_trans_env_check(self, datatype):
23342337
t = Hash(
23352338
in_keys=["observation"],
23362339
out_keys=["hashing"],
2337-
hash_fn=hash,
2340+
hash_fn=lambda x: torch.tensor(hash(x)),
23382341
)
23392342
base_env = CountingEnv()
23402343
elif datatype == "str":
@@ -2353,7 +2356,7 @@ def make_env():
23532356
t = Hash(
23542357
in_keys=["observation"],
23552358
out_keys=["hashing"],
2356-
hash_fn=hash,
2359+
hash_fn=lambda x: torch.tensor(hash(x)),
23572360
)
23582361
base_env = CountingEnv()
23592362

@@ -2376,7 +2379,7 @@ def make_env():
23762379
t = Hash(
23772380
in_keys=["observation"],
23782381
out_keys=["hashing"],
2379-
hash_fn=hash,
2382+
hash_fn=lambda x: torch.tensor(hash(x)),
23802383
)
23812384
base_env = CountingEnv()
23822385
elif datatype == "str":
@@ -2402,7 +2405,7 @@ def test_trans_serial_env_check(self, datatype):
24022405
t = Hash(
24032406
in_keys=["observation"],
24042407
out_keys=["hashing"],
2405-
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
2408+
hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]),
24062409
)
24072410
base_env = CountingEnv
24082411
elif datatype == "str":
@@ -2422,7 +2425,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
24222425
t = Hash(
24232426
in_keys=["observation"],
24242427
out_keys=["hashing"],
2425-
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
2428+
hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]),
24262429
)
24272430
base_env = CountingEnv
24282431
elif datatype == "str":
@@ -2457,7 +2460,7 @@ def test_transform_compose(self, datatype):
24572460
t = Hash(
24582461
in_keys=["observation"],
24592462
out_keys=["hashing"],
2460-
hash_fn=hash,
2463+
hash_fn=lambda x: torch.tensor(hash(x)),
24612464
)
24622465
t = Compose(t)
24632466
td_hashed = t(td)
@@ -2469,7 +2472,7 @@ def test_transform_model(self):
24692472
t = Hash(
24702473
in_keys=[("next", "observation"), ("observation",)],
24712474
out_keys=[("next", "hashing"), ("hashing",)],
2472-
hash_fn=hash,
2475+
hash_fn=lambda x: torch.tensor(hash(x)),
24732476
)
24742477
model = nn.Sequential(t, nn.Identity())
24752478
td = TensorDict(
@@ -2486,7 +2489,7 @@ def test_transform_env(self):
24862489
t = Hash(
24872490
in_keys=["observation"],
24882491
out_keys=["hashing"],
2489-
hash_fn=hash,
2492+
hash_fn=lambda x: torch.tensor(hash(x)),
24902493
)
24912494
env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
24922495
assert env.observation_spec["hashing"]
@@ -2499,7 +2502,7 @@ def test_transform_rb(self, rbclass):
24992502
t = Hash(
25002503
in_keys=[("next", "observation"), ("observation",)],
25012504
out_keys=[("next", "hashing"), ("hashing",)],
2502-
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
2505+
hash_fn=lambda x: torch.tensor([hash(x[0]), hash(x[1])]),
25032506
)
25042507
rb = rbclass(storage=LazyTensorStorage(10))
25052508
rb.append_transform(t)
@@ -2519,18 +2522,73 @@ def test_transform_rb(self, rbclass):
25192522
assert "observation" in td.keys()
25202523
assert ("next", "observation") in td.keys(True)
25212524

2522-
def test_transform_inverse(self):
2523-
return
2524-
env = CountingEnv()
2525-
with pytest.raises(TypeError):
2526-
env = env.append_transform(
2527-
Hash(
2528-
in_keys=[],
2529-
out_keys=[],
2530-
in_keys_inv=["action"],
2531-
out_keys_inv=["action_hash"],
2532-
)
2533-
)
2525+
@pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}])
2526+
def test_transform_inverse(self, repertoire_gen):
2527+
repertoire = repertoire_gen()
2528+
t = Hash(
2529+
in_keys=["observation"],
2530+
out_keys=["hashing"],
2531+
in_keys_inv=["observation"],
2532+
out_keys_inv=["hashing"],
2533+
repertoire=repertoire,
2534+
)
2535+
inputs = [
2536+
TensorDict({"observation": "test string"}),
2537+
TensorDict({"observation": torch.randn(10)}),
2538+
TensorDict({"observation": "another string"}),
2539+
TensorDict({"observation": torch.randn(3, 2, 1, 8)}),
2540+
]
2541+
outputs = [t(input.clone()).exclude("observation") for input in inputs]
2542+
2543+
# Run the inputs through again, just to make sure that using the same
2544+
# inputs doesn't overwrite the repertoire.
2545+
for input in inputs:
2546+
t(input.clone())
2547+
2548+
assert len(t._repertoire) == 4
2549+
2550+
inv_inputs = [t.inv(output.clone()) for output in outputs]
2551+
2552+
for input, inv_input in zip(inputs, inv_inputs):
2553+
if torch.is_tensor(input["observation"]):
2554+
assert (input["observation"] == inv_input["observation"]).all()
2555+
else:
2556+
assert input["observation"] == inv_input["observation"]
2557+
2558+
@pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}])
2559+
def test_repertoire(self, repertoire_gen):
2560+
repertoire = repertoire_gen()
2561+
t = Hash(in_keys=["observation"], out_keys=["hashing"], repertoire=repertoire)
2562+
inputs = [
2563+
"string",
2564+
["a", "b"],
2565+
torch.randn(3, 4, 1),
2566+
torch.randn(()),
2567+
torch.randn(0),
2568+
1234,
2569+
[1, 2, 3, 4],
2570+
]
2571+
outputs = []
2572+
2573+
for input in inputs:
2574+
td = TensorDict({"observation": input})
2575+
outputs.append(t(td.clone()).clone()["hashing"])
2576+
2577+
for output, input in zip(outputs, inputs):
2578+
if repertoire is not None:
2579+
stored_input = repertoire[t.hash_to_repertoire_key(output)]
2580+
assert stored_input is t.get_input_from_hash(output)
2581+
2582+
if torch.is_tensor(stored_input):
2583+
assert (stored_input == torch.as_tensor(input)).all()
2584+
elif isinstance(stored_input, np.ndarray):
2585+
assert (stored_input == np.asarray(input)).all()
2586+
2587+
else:
2588+
assert stored_input == input
2589+
else:
2590+
with pytest.raises(RuntimeError):
2591+
stored_input = t.get_input_from_hash(output)
25342592

25352593

25362594
@pytest.mark.skipif(

torchrl/envs/transforms/transforms.py

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4826,25 +4826,25 @@ class Hash(UnaryTransform):
48264826
in_keys (sequence of NestedKey): the keys of the values to hash.
48274827
out_keys (sequence of NestedKey): the keys of the resulting hashes.
48284828
in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
4829-
4830-
.. note:: If an inverse map is required, a repertoire ``Dict[Tuple[int], Any]`` of hash to value should be
4831-
passed alongside the list of keys to let the ``Hash`` transform know how to recover a value from a
4832-
given hash. This repertoire isn't copied, so it can be modified in the same workspace after the
4833-
transform instantiation and these modifications will be reflected in the map. Missing hashes will be
4834-
mapped to ``None``.
4835-
48364829
out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.
48374830
48384831
Keyword Args:
4839-
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
4840-
the hash function must accept it as its second argument. Default is
4841-
``Hash.reproducible_hash``.
4832+
hash_fn (Callable, optional): the hash function to use. The function
4833+
signature must be
4834+
``(input: Any, seed: Any | None) -> torch.Tensor``.
4835+
``seed`` is only used if this transform is initialized with the
4836+
``seed`` argument. Default is ``Hash.reproducible_hash``.
48424837
seed (optional): seed to use for the hash function, if it requires one.
48434838
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
48444839
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
48454840
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
48464841
inputs are given directly to ``fn``, which must support those
48474842
inputs. Default is ``False``.
4843+
repertoire (Dict[Tuple[int], Any], optional): If given, this dict stores
4844+
the inverse mappings from hashes to inputs. This repertoire isn't
4845+
copied, so it can be modified in the same workspace after the
4846+
transform instantiation and these modifications will be reflected in
4847+
the map. Missing hashes will be mapped to ``None``. Default: ``None``
48484848
48494849
>>> from torchrl.envs import GymEnv, UnaryTransform, Hash
48504850
>>> env = GymEnv("Pendulum-v1")
@@ -4925,57 +4925,79 @@ def __init__(
49254925
self,
49264926
in_keys: Sequence[NestedKey],
49274927
out_keys: Sequence[NestedKey],
4928+
in_keys_inv: Sequence[NestedKey] = None,
4929+
out_keys_inv: Sequence[NestedKey] = None,
49284930
*,
49294931
hash_fn: Callable = None,
49304932
seed: Any | None = None,
49314933
use_raw_nontensor: bool = False,
4934+
repertoire: Tuple[Tuple[int], Any] = None,
49324935
):
49334936
if hash_fn is None:
49344937
hash_fn = Hash.reproducible_hash
49354938

4939+
if repertoire is None and in_keys_inv is not None and len(in_keys_inv) > 0:
4940+
self._repertoire = {}
4941+
else:
4942+
self._repertoire = repertoire
4943+
49364944
self._seed = seed
49374945
self._hash_fn = hash_fn
49384946
super().__init__(
49394947
in_keys=in_keys,
49404948
out_keys=out_keys,
4949+
in_keys_inv=in_keys_inv,
4950+
out_keys_inv=out_keys_inv,
49414951
fn=self.call_hash_fn,
4952+
inv_fn=self.get_input_from_hash,
49424953
use_raw_nontensor=use_raw_nontensor,
49434954
)
49444955

4945-
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
4946-
inputs = tensordict.select(*self.in_keys_inv).detach().cpu()
4947-
tensordict = super()._inv_call(tensordict)
4948-
4949-
def register_outcome(td):
4950-
# We need to treat each hash independently
4951-
if td.ndim:
4952-
if td.ndim > 1:
4953-
td_r = td.reshape(-1)
4954-
elif td.ndim == 1:
4955-
td_r = td
4956-
result = torch.stack([register_outcome(_td) for _td in td_r.unbind(0)])
4957-
if td_r is not td:
4958-
return result.reshape(td.shape)
4959-
return result
4960-
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
4961-
inp = inputs.get(in_key)
4962-
inp = tuple(inp.tolist())
4963-
outp = self._repertoire.get(inp)
4964-
td[out_key] = outp
4965-
return td
4966-
4967-
return register_outcome(tensordict)
4968-
49694956
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
4970-
if self.in_keys_inv is not None:
4971-
return {"_repertoire": self._repertoire}
4972-
return {}
4957+
return {"_repertoire": self._repertoire}
4958+
4959+
@classmethod
4960+
def hash_to_repertoire_key(cls, hash_tensor):
4961+
if isinstance(hash_tensor, torch.Tensor):
4962+
if hash_tensor.dim() == 0:
4963+
return hash_tensor.tolist()
4964+
return tuple(cls.hash_to_repertoire_key(t) for t in hash_tensor.tolist())
4965+
elif isinstance(hash_tensor, list):
4966+
return tuple(cls.hash_to_repertoire_key(t) for t in hash_tensor)
4967+
else:
4968+
return hash_tensor
4969+
4970+
def get_input_from_hash(self, hash_tensor):
4971+
"""Look up the input that was given for a particular hash output.
4972+
4973+
This feature is only available if, during initialization, either the
4974+
:arg:`repertoire` argument was given or both the :arg:`in_keys_inv` and
4975+
:arg:`out_keys_inv` arguments were given.
4976+
4977+
Args:
4978+
hash_tensor (Tensor): The hash output.
4979+
4980+
Returns:
4981+
Any: The input that the hash was generated from.
4982+
"""
4983+
if self._repertoire is None:
4984+
raise RuntimeError(
4985+
"An inverse transform was queried but the repertoire is None."
4986+
)
4987+
return self._repertoire[self.hash_to_repertoire_key(hash_tensor)]
49734988

49744989
def call_hash_fn(self, value):
49754990
if self._seed is None:
4976-
return self._hash_fn(value)
4991+
hash_tensor = self._hash_fn(value)
49774992
else:
4978-
return self._hash_fn(value, self._seed)
4993+
hash_tensor = self._hash_fn(value, self._seed)
4994+
if not torch.is_tensor(hash_tensor):
4995+
raise ValueError(
4996+
f"Hash function must return a tensor, but got {type(hash_tensor)}"
4997+
)
4998+
if self._repertoire is not None:
4999+
self._repertoire[self.hash_to_repertoire_key(hash_tensor)] = copy(value)
5000+
return hash_tensor
49795001

49805002
@classmethod
49815003
def reproducible_hash(cls, string, seed=None):

0 commit comments

Comments
 (0)