Skip to content

Commit 9e51c8b

Browse files
author
Vincent Moens
committed
[Feature] UnaryTransform for input entries
ghstack-source-id: 5a41e9c Pull Request resolved: #2700
1 parent 0c74b19 commit 9e51c8b

File tree

1 file changed

+70
-8
lines changed

1 file changed

+70
-8
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4426,10 +4426,12 @@ class UnaryTransform(Transform):
44264426
Args:
44274427
in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
44284428
out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
4429-
fn (Callable): the function to use as the unary operation. If it accepts
4430-
a non-tensor input, it must also accept ``None``.
4429+
in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
4430+
out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call.
44314431
44324432
Keyword Args:
4433+
fn (Callable): the function to use as the unary operation. If it accepts
4434+
a non-tensor input, it must also accept ``None``.
44334435
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
44344436
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
44354437
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4500,11 +4502,18 @@ def __init__(
45004502
self,
45014503
in_keys: Sequence[NestedKey],
45024504
out_keys: Sequence[NestedKey],
4503-
fn: Callable,
4505+
in_keys_inv: Sequence[NestedKey] | None = None,
4506+
out_keys_inv: Sequence[NestedKey] | None = None,
45044507
*,
4508+
fn: Callable,
45054509
use_raw_nontensor: bool = False,
45064510
):
4507-
super().__init__(in_keys=in_keys, out_keys=out_keys)
4511+
super().__init__(
4512+
in_keys=in_keys,
4513+
out_keys=out_keys,
4514+
in_keys_inv=in_keys_inv,
4515+
out_keys_inv=out_keys_inv,
4516+
)
45084517
self._fn = fn
45094518
self._use_raw_nontensor = use_raw_nontensor
45104519

@@ -4519,13 +4528,50 @@ def _apply_transform(self, value):
45194528
value = value.tolist()
45204529
return self._fn(value)
45214530

4531+
def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
4532+
if not self._use_raw_nontensor:
4533+
if isinstance(state, NonTensorData):
4534+
if state.dim() == 0:
4535+
state = state.get("data")
4536+
else:
4537+
state = state.tolist()
4538+
elif isinstance(state, NonTensorStack):
4539+
state = state.tolist()
4540+
return self._fn(state)
4541+
45224542
def _reset(
45234543
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
45244544
) -> TensorDictBase:
45254545
with _set_missing_tolerance(self, True):
45264546
tensordict_reset = self._call(tensordict_reset)
45274547
return tensordict_reset
45284548

4549+
def transform_input_spec(self, input_spec: Composite) -> Composite:
4550+
input_spec = input_spec.clone()
4551+
4552+
# Make a generic input from the spec, call the transform with that
4553+
# input, and then generate the output spec from the output.
4554+
zero_input_ = input_spec.zero()
4555+
test_input = zero_input_["full_action_spec"].update(
4556+
zero_input_["full_state_spec"]
4557+
)
4558+
test_output = self.inv(test_input)
4559+
test_input_spec = make_composite_from_td(
4560+
test_output, unsqueeze_null_shapes=False
4561+
)
4562+
4563+
input_spec["full_action_spec"] = self.transform_action_spec(
4564+
input_spec["full_action_spec"],
4565+
test_input_spec,
4566+
)
4567+
if "full_state_spec" in input_spec.keys():
4568+
input_spec["full_state_spec"] = self.transform_state_spec(
4569+
input_spec["full_state_spec"],
4570+
test_input_spec,
4571+
)
4572+
print(input_spec)
4573+
return input_spec
4574+
45294575
def transform_output_spec(self, output_spec: Composite) -> Composite:
45304576
output_spec = output_spec.clone()
45314577

@@ -4586,19 +4632,31 @@ def transform_done_spec(
45864632
) -> TensorSpec:
45874633
return self._transform_spec(done_spec, test_output_spec)
45884634

4635+
def transform_action_spec(
4636+
self, action_spec: TensorSpec, test_input_spec: TensorSpec
4637+
) -> TensorSpec:
4638+
return self._transform_spec(action_spec, test_input_spec)
4639+
4640+
def transform_state_spec(
4641+
self, state_spec: TensorSpec, test_input_spec: TensorSpec
4642+
) -> TensorSpec:
4643+
return self._transform_spec(state_spec, test_input_spec)
4644+
45894645

45904646
class Hash(UnaryTransform):
45914647
r"""Adds a hash value to a tensordict.
45924648
45934649
Args:
45944650
in_keys (sequence of NestedKey): the keys of the values to hash.
45954651
out_keys (sequence of NestedKey): the keys of the resulting hashes.
4652+
in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
4653+
out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
4654+
4655+
Keyword Args:
45964656
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
45974657
the hash function must accept it as its second argument. Default is
45984658
``Hash.reproducible_hash``.
45994659
seed (optional): seed to use for the hash function, if it requires one.
4600-
4601-
Keyword Args:
46024660
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
46034661
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
46044662
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4684,9 +4742,11 @@ def __init__(
46844742
self,
46854743
in_keys: Sequence[NestedKey],
46864744
out_keys: Sequence[NestedKey],
4745+
in_keys_inv: Sequence[NestedKey] | None = None,
4746+
out_keys_inv: Sequence[NestedKey] | None = None,
4747+
*,
46874748
hash_fn: Callable = None,
46884749
seed: Any | None = None,
4689-
*,
46904750
use_raw_nontensor: bool = False,
46914751
):
46924752
if hash_fn is None:
@@ -4697,6 +4757,8 @@ def __init__(
46974757
super().__init__(
46984758
in_keys=in_keys,
46994759
out_keys=out_keys,
4760+
in_keys_inv=in_keys_inv,
4761+
out_keys_inv=out_keys_inv,
47004762
fn=self.call_hash_fn,
47014763
use_raw_nontensor=use_raw_nontensor,
47024764
)
@@ -4725,7 +4787,7 @@ def reproducible_hash(cls, string, seed=None):
47254787
if seed is not None:
47264788
seeded_string = seed + string
47274789
else:
4728-
seeded_string = string
4790+
seeded_string = str(string)
47294791

47304792
# Create a new SHA-256 hash object
47314793
hash_object = hashlib.sha256()

0 commit comments

Comments
 (0)