@@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec:
        input_spec = self.__dict__.get("_input_spec", None)
        return input_spec

+    def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
+        if self.base_env.rand_action is not EnvBase.rand_action:
+            # TODO: this will fail if the transform modifies the input.
+            # For instance, if PendulumEnv overrides rand_action and we build an
+            # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
+            # env.rand_action will NOT have a discrete action!
+            # Getting a discrete action would require coding the inverse transform of an
+            # action within ActionDiscretizer (i.e., float->int, not int->float).
+            return self.base_env.rand_action(tensordict)
+        return super().rand_action(tensordict)
+
    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # No need to clone here because inv does it already
        # tensordict = tensordict.clone(False)
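
To illustrate the caveat in the TODO above, here is a hypothetical sketch (it assumes PendulumEnv overrides rand_action and that both classes are importable from torchrl.envs as shown; this is not part of the diff):

    from torchrl.envs import PendulumEnv
    from torchrl.envs.transforms import ActionDiscretizer

    env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
    # If the base env's rand_action is delegated to, the sampled action is still
    # continuous: ActionDiscretizer's float->int inverse is never applied.
    td = env.rand_action()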
@@ -4415,10 +4426,12 @@ class UnaryTransform(Transform):
    Args:
        in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
        out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        fn (Callable): the function to use as the unary operation. If it accepts
-            a non-tensor input, it must also accept ``None``.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during the inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during the inverse call.

    Keyword Args:
+        fn (Callable): the function to use as the unary operation. If it accepts
+            a non-tensor input, it must also accept ``None``.
        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
            on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4489,11 +4502,18 @@ def __init__(
        self,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey],
-        fn: Callable,
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
        *,
+        fn: Callable,
        use_raw_nontensor: bool = False,
    ):
-        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+        )
        self._fn = fn
        self._use_raw_nontensor = use_raw_nontensor

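For orientation, a minimal sketch of the extended signature; the key names and the lambda below are illustrative, not taken from the PR:

    from torchrl.envs.transforms import UnaryTransform

    # Apply fn to the action on the way into the base env (inverse direction only);
    # "action" is an illustrative key name.
    t = UnaryTransform(
        in_keys=[],
        out_keys=[],
        in_keys_inv=["action"],
        out_keys_inv=["action"],
        fn=lambda x: x.clamp(-1.0, 1.0),
    )
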
@@ -4508,13 +4528,49 @@ def _apply_transform(self, value):
                value = value.tolist()
        return self._fn(value)

+    def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
+        if not self._use_raw_nontensor:
+            if isinstance(state, NonTensorData):
+                if state.dim() == 0:
+                    state = state.get("data")
+                else:
+                    state = state.tolist()
+            elif isinstance(state, NonTensorStack):
+                state = state.tolist()
+        return self._fn(state)
+
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

+    def transform_input_spec(self, input_spec: Composite) -> Composite:
+        input_spec = input_spec.clone()
+
+        # Make a generic input from the spec, call the transform with that
+        # input, and then generate the output spec from the output.
+        zero_input_ = input_spec.zero()
+        test_input = zero_input_["full_action_spec"].update(
+            zero_input_["full_state_spec"]
+        )
+        test_output = self.inv(test_input)
+        test_input_spec = make_composite_from_td(
+            test_output, unsqueeze_null_shapes=False
+        )
+
+        input_spec["full_action_spec"] = self.transform_action_spec(
+            input_spec["full_action_spec"],
+            test_input_spec,
+        )
+        if "full_state_spec" in input_spec.keys():
+            input_spec["full_state_spec"] = self.transform_state_spec(
+                input_spec["full_state_spec"],
+                test_input_spec,
+            )
+        return input_spec
+
    def transform_output_spec(self, output_spec: Composite) -> Composite:
        output_spec = output_spec.clone()

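The new transform_input_spec infers the post-inverse specs by probing: it zeroes the current input spec, runs the result through inv(), and rebuilds a spec from the output. A rough sketch of that last step (the import path of make_composite_from_td is assumed to be torchrl.envs.utils):

    import torch
    from tensordict import TensorDict
    from torchrl.envs.utils import make_composite_from_td  # assumed location

    td = TensorDict({"action": torch.zeros(3)}, batch_size=[])
    # Builds a Composite spec whose keys, shapes and dtypes mirror the tensordict.
    spec = make_composite_from_td(td)
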
@@ -4575,19 +4631,31 @@ def transform_done_spec(
    ) -> TensorSpec:
        return self._transform_spec(done_spec, test_output_spec)

+    def transform_action_spec(
+        self, action_spec: TensorSpec, test_input_spec: TensorSpec
+    ) -> TensorSpec:
+        return self._transform_spec(action_spec, test_input_spec)
+
+    def transform_state_spec(
+        self, state_spec: TensorSpec, test_input_spec: TensorSpec
+    ) -> TensorSpec:
+        return self._transform_spec(state_spec, test_input_spec)
+

class Hash(UnaryTransform):
    r"""Adds a hash value to a tensordict.

    Args:
        in_keys (sequence of NestedKey): the keys of the values to hash.
        out_keys (sequence of NestedKey): the keys of the resulting hashes.
+        in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during the inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during the inverse call.
+
+    Keyword Args:
        hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
            the hash function must accept it as its second argument. Default is
            ``Hash.reproducible_hash``.
        seed (optional): seed to use for the hash function, if it requires one.
-
-    Keyword Args:
        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
            on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4673,9 +4741,11 @@ def __init__(
        self,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
        hash_fn: Callable = None,
        seed: Any | None = None,
-        *,
        use_raw_nontensor: bool = False,
    ):
        if hash_fn is None:
@@ -4686,6 +4756,8 @@ def __init__(
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
            fn=self.call_hash_fn,
            use_raw_nontensor=use_raw_nontensor,
        )
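
A rough usage sketch of Hash with the default reproducible_hash (the "text"/"hash" key names are illustrative, and applying the transform directly to a TensorDict outside an env is assumed to work here):

    from tensordict import TensorDict
    from torchrl.envs.transforms import Hash

    t = Hash(in_keys=["text"], out_keys=["hash"])  # illustrative key names
    td = t(TensorDict({"text": "hello"}, batch_size=[]))
    # td["hash"] is a 32-element uint8 tensor (the SHA-256 digest of the string).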
@@ -4714,7 +4786,7 @@ def reproducible_hash(cls, string, seed=None):
        if seed is not None:
            seeded_string = seed + string
        else:
-            seeded_string = string
+            seeded_string = str(string)

        # Create a new SHA-256 hash object
        hash_object = hashlib.sha256()
@@ -4728,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
        return torch.frombuffer(hash_bytes, dtype=torch.uint8)


+class Tokenizer(UnaryTransform):
+    r"""Applies a tokenization operation on the specified inputs.
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
+        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during the inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during the inverse call.
+
+    Keyword Args:
+        tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+            pre-trained tokenizer.
+        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
+            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
+            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
+            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
+        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
+        tokenizer: "transformers.PretrainedTokenizerBase" = None,  # noqa: F821
+        use_raw_nontensor: bool = False,
+        additional_tokens: List[str] | None = None,
+    ):
+        if tokenizer is None:
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        elif isinstance(tokenizer, str):
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+        self.tokenizer = tokenizer
+        if additional_tokens:
+            self.tokenizer.add_tokens(additional_tokens)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+            fn=self.call_tokenizer_fn,
+            use_raw_nontensor=use_raw_nontensor,
+        )
+
+    @property
+    def device(self):
+        if "_device" in self.__dict__:
+            return self._device
+        parent = self.parent
+        if parent is None:
+            return None
+        device = parent.device
+        self._device = device
+        return device
+
+    def call_tokenizer_fn(self, value: str | List[str]):
+        device = self.device
+        out = self.tokenizer.encode(value, return_tensors="pt")
+        if device is not None and out.device != device:
+            out = out.to(device)
+        return out
+
+
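A rough usage sketch of the new Tokenizer transform (requires the transformers package; the key names are illustrative and the export path is assumed to match the other transforms):

    from tensordict import TensorDict
    from torchrl.envs.transforms import Tokenizer  # assumed export path

    t = Tokenizer(in_keys=["text"], out_keys=["tokens"])  # defaults to bert-base-uncased
    td = t(TensorDict({"text": "hello world"}, batch_size=[]))  # illustrative key names
    # td["tokens"] holds the ids from tokenizer.encode(..., return_tensors="pt").
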
class Stack(Transform):
    """Stacks tensors and tensordicts.
