@@ -4426,10 +4426,12 @@ class UnaryTransform(Transform):
4426
4426
Args:
4427
4427
in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
4428
4428
out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
4429
- fn (Callable ): the function to use as the unary operation. If it accepts
4430
- a non-tensor input, it must also accept ``None`` .
4429
+ in_keys_inv (sequence of NestedKey ): the keys of inputs to the unary operation during inverse call.
4430
+ out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call .
4431
4431
4432
4432
Keyword Args:
4433
+ fn (Callable): the function to use as the unary operation. If it accepts
4434
+ a non-tensor input, it must also accept ``None``.
4433
4435
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
4434
4436
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
4435
4437
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4500,11 +4502,18 @@ def __init__(
4500
4502
self ,
4501
4503
in_keys : Sequence [NestedKey ],
4502
4504
out_keys : Sequence [NestedKey ],
4503
- fn : Callable ,
4505
+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4506
+ out_keys_inv : Sequence [NestedKey ] | None = None ,
4504
4507
* ,
4508
+ fn : Callable ,
4505
4509
use_raw_nontensor : bool = False ,
4506
4510
):
4507
- super ().__init__ (in_keys = in_keys , out_keys = out_keys )
4511
+ super ().__init__ (
4512
+ in_keys = in_keys ,
4513
+ out_keys = out_keys ,
4514
+ in_keys_inv = in_keys_inv ,
4515
+ out_keys_inv = out_keys_inv ,
4516
+ )
4508
4517
self ._fn = fn
4509
4518
self ._use_raw_nontensor = use_raw_nontensor
4510
4519
@@ -4519,13 +4528,50 @@ def _apply_transform(self, value):
4519
4528
value = value .tolist ()
4520
4529
return self ._fn (value )
4521
4530
4531
+ def _inv_apply_transform (self , state : torch .Tensor ) -> torch .Tensor :
4532
+ if not self ._use_raw_nontensor :
4533
+ if isinstance (state , NonTensorData ):
4534
+ if state .dim () == 0 :
4535
+ state = state .get ("data" )
4536
+ else :
4537
+ state = state .tolist ()
4538
+ elif isinstance (state , NonTensorStack ):
4539
+ state = state .tolist ()
4540
+ return self ._fn (state )
4541
+
4522
4542
def _reset (
4523
4543
self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
4524
4544
) -> TensorDictBase :
4525
4545
with _set_missing_tolerance (self , True ):
4526
4546
tensordict_reset = self ._call (tensordict_reset )
4527
4547
return tensordict_reset
4528
4548
4549
+ def transform_input_spec (self , input_spec : Composite ) -> Composite :
4550
+ input_spec = input_spec .clone ()
4551
+
4552
+ # Make a generic input from the spec, call the transform with that
4553
+ # input, and then generate the output spec from the output.
4554
+ zero_input_ = input_spec .zero ()
4555
+ test_input = zero_input_ ["full_action_spec" ].update (
4556
+ zero_input_ ["full_state_spec" ]
4557
+ )
4558
+ test_output = self .inv (test_input )
4559
+ test_input_spec = make_composite_from_td (
4560
+ test_output , unsqueeze_null_shapes = False
4561
+ )
4562
+
4563
+ input_spec ["full_action_spec" ] = self .transform_action_spec (
4564
+ input_spec ["full_action_spec" ],
4565
+ test_input_spec ,
4566
+ )
4567
+ if "full_state_spec" in input_spec .keys ():
4568
+ input_spec ["full_state_spec" ] = self .transform_state_spec (
4569
+ input_spec ["full_state_spec" ],
4570
+ test_input_spec ,
4571
+ )
4572
+ print (input_spec )
4573
+ return input_spec
4574
+
4529
4575
def transform_output_spec (self , output_spec : Composite ) -> Composite :
4530
4576
output_spec = output_spec .clone ()
4531
4577
@@ -4586,19 +4632,31 @@ def transform_done_spec(
4586
4632
) -> TensorSpec :
4587
4633
return self ._transform_spec (done_spec , test_output_spec )
4588
4634
4635
+ def transform_action_spec (
4636
+ self , action_spec : TensorSpec , test_input_spec : TensorSpec
4637
+ ) -> TensorSpec :
4638
+ return self ._transform_spec (action_spec , test_input_spec )
4639
+
4640
+ def transform_state_spec (
4641
+ self , state_spec : TensorSpec , test_input_spec : TensorSpec
4642
+ ) -> TensorSpec :
4643
+ return self ._transform_spec (state_spec , test_input_spec )
4644
+
4589
4645
4590
4646
class Hash (UnaryTransform ):
4591
4647
r"""Adds a hash value to a tensordict.
4592
4648
4593
4649
Args:
4594
4650
in_keys (sequence of NestedKey): the keys of the values to hash.
4595
4651
out_keys (sequence of NestedKey): the keys of the resulting hashes.
4652
+ in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
4653
+ out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
4654
+
4655
+ Keyword Args:
4596
4656
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
4597
4657
the hash function must accept it as its second argument. Default is
4598
4658
``Hash.reproducible_hash``.
4599
4659
seed (optional): seed to use for the hash function, if it requires one.
4600
-
4601
- Keyword Args:
4602
4660
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
4603
4661
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
4604
4662
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4684,9 +4742,11 @@ def __init__(
4684
4742
self ,
4685
4743
in_keys : Sequence [NestedKey ],
4686
4744
out_keys : Sequence [NestedKey ],
4745
+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4746
+ out_keys_inv : Sequence [NestedKey ] | None = None ,
4747
+ * ,
4687
4748
hash_fn : Callable = None ,
4688
4749
seed : Any | None = None ,
4689
- * ,
4690
4750
use_raw_nontensor : bool = False ,
4691
4751
):
4692
4752
if hash_fn is None :
@@ -4697,6 +4757,8 @@ def __init__(
4697
4757
super ().__init__ (
4698
4758
in_keys = in_keys ,
4699
4759
out_keys = out_keys ,
4760
+ in_keys_inv = in_keys_inv ,
4761
+ out_keys_inv = out_keys_inv ,
4700
4762
fn = self .call_hash_fn ,
4701
4763
use_raw_nontensor = use_raw_nontensor ,
4702
4764
)
@@ -4725,7 +4787,7 @@ def reproducible_hash(cls, string, seed=None):
4725
4787
if seed is not None :
4726
4788
seeded_string = seed + string
4727
4789
else :
4728
- seeded_string = string
4790
+ seeded_string = str ( string )
4729
4791
4730
4792
# Create a new SHA-256 hash object
4731
4793
hash_object = hashlib .sha256 ()
0 commit comments