
Commit 6633611

Author: Vincent Moens (committed)
Update (base update)
[ghstack-poisoned]
1 parent 256a700 commit 6633611

File tree

4 files changed: +193 −16 lines


torchrl/data/tensor_specs.py

Lines changed: 40 additions & 8 deletions
@@ -2457,6 +2457,7 @@ def __init__(
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
         device: Optional[DEVICE_TYPING] = None,
         dtype: torch.dtype | None = None,
+        example_data: Any = None,
         **kwargs,
     ):
         if isinstance(shape, int):
@@ -2467,6 +2468,7 @@ def __init__(
         super().__init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
+        self.example_data = example_data

     def cardinality(self) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
@@ -2485,30 +2487,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         dest_device = torch.device(dest)
         if dest_device == self.device and dest_dtype == self.dtype:
             return self
-        return self.__class__(shape=self.shape, device=dest_device, dtype=None)
+        return self.__class__(
+            shape=self.shape,
+            device=dest_device,
+            dtype=None,
+            example_data=self.example_data,
+        )

     def clone(self) -> NonTensor:
-        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
+        return self.__class__(
+            shape=self.shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )

     def rand(self, shape=None):
         if shape is None:
             shape = ()
         return NonTensorData(
-            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
+            data=self.example_data,
+            batch_size=(*shape, *self._safe_shape),
+            device=self.device,
         )

     def zero(self, shape=None):
         if shape is None:
             shape = ()
         return NonTensorData(
-            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
+            data=self.example_data,
+            batch_size=(*shape, *self._safe_shape),
+            device=self.device,
         )

     def one(self, shape=None):
         if shape is None:
             shape = ()
         return NonTensorData(
-            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
+            data=self.example_data,
+            batch_size=(*shape, *self._safe_shape),
+            device=self.device,
         )

     def is_in(self, val: Any) -> bool:
@@ -2533,23 +2551,36 @@ def expand(self, *shape):
             raise ValueError(
                 f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
             )
-        return self.__class__(shape=shape, device=self.device, dtype=None)
+        return self.__class__(
+            shape=shape, device=self.device, dtype=None, example_data=self.example_data
+        )

     def _reshape(self, shape):
-        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
+        return self.__class__(
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )

     def _unflatten(self, dim, sizes):
         shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
         return self.__class__(
             shape=shape,
             device=self.device,
             dtype=self.dtype,
+            example_data=self.example_data,
         )

     def __getitem__(self, idx: SHAPE_INDEX_TYPING):
         """Indexes the current TensorSpec based on the provided index."""
         indexed_shape = _size(_shape_indexing(self.shape, idx))
-        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
+        return self.__class__(
+            shape=indexed_shape,
+            device=self.device,
+            dtype=self.dtype,
+            example_data=self.example_data,
+        )

     def unbind(self, dim: int = 0):
         orig_dim = dim
@@ -2565,6 +2596,7 @@ def unbind(self, dim: int = 0):
                 shape=shape,
                 device=self.device,
                 dtype=self.dtype,
+                example_data=self.example_data,
             )
             for i in range(self.shape[dim])
         )
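For context, a minimal sketch of what the new example_data field enables, assuming the spec class is exposed as torchrl.data.NonTensor (the "a string" payload is purely illustrative):

from torchrl.data import NonTensor

# example_data rides along with the spec so that rand()/zero()/one() can
# return a meaningful payload instead of None.
spec = NonTensor(shape=(3,), example_data="a string")
sample = spec.rand()  # NonTensorData(data="a string", batch_size=[3])
assert sample.data == "a string"

# The diff above threads example_data through every method that rebuilds the spec:
assert spec.clone().example_data == "a string"
assert spec.expand(2, 3).example_data == "a string"
assert spec[0].example_data == "a string"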

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@
    TargetReturn,
    TensorDictPrimer,
    TimeMaxPool,
+    Tokenizer,
    ToTensorImage,
    TrajCounter,
    Transform,
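With these two re-exports in place, the new transform should resolve from either namespace; a quick sanity check, assuming this commit is installed:

from torchrl.envs import Tokenizer
from torchrl.envs.transforms import Tokenizer as TokenizerAlias

assert Tokenizer is TokenizerAlias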

torchrl/envs/transforms/transforms.py

Lines changed: 151 additions & 8 deletions
@@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec:
         input_spec = self.__dict__.get("_input_spec", None)
         return input_spec

+    def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
+        if self.base_env.rand_action is not EnvBase.rand_action:
+            # TODO: this will fail if the transform modifies the input.
+            # For instance, if PendulumEnv overrides rand_action and we build an
+            # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4)),
+            # env.rand_action will NOT produce a discrete action!
+            # Getting a discrete action would require coding the inverse transform of an action within
+            # ActionDiscretizer (i.e., float->int, not int->float).
+            return self.base_env.rand_action(tensordict)
+        return super().rand_action(tensordict)
+
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # No need to clone here because inv does it already
         # tensordict = tensordict.clone(False)
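A hedged companion sketch of the dispatch rule above: the override is only meaningful when the wrapped env's class actually redefines rand_action. Note that comparing through the class, as below, avoids pitting a bound method against a plain function (the helper name is hypothetical):

from torchrl.envs import EnvBase

def has_custom_rand_action(env: EnvBase) -> bool:
    # True when env's class overrides EnvBase.rand_action, i.e. when a
    # TransformedEnv should defer to the base env's sampler.
    return type(env).rand_action is not EnvBase.rand_action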
@@ -4415,10 +4426,12 @@ class UnaryTransform(Transform):
     Args:
         in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
         out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        fn (Callable): the function to use as the unary operation. If it accepts
-            a non-tensor input, it must also accept ``None``.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during the inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during the inverse call.

     Keyword Args:
+        fn (Callable): the function to use as the unary operation. If it accepts
+            a non-tensor input, it must also accept ``None``.
         use_raw_nontensor (bool, optional): if ``False``, data is extracted from
             :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
             on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4489,11 +4502,18 @@ def __init__(
         self,
         in_keys: Sequence[NestedKey],
         out_keys: Sequence[NestedKey],
-        fn: Callable,
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
         *,
+        fn: Callable,
         use_raw_nontensor: bool = False,
     ):
-        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+        )
         self._fn = fn
         self._use_raw_nontensor = use_raw_nontensor

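A small usage sketch of the reworked signature, where fn is now keyword-only and inverse keys are supported (the key names are hypothetical; assumes UnaryTransform is exported from torchrl.envs.transforms):

from torchrl.envs.transforms import UnaryTransform

t = UnaryTransform(
    in_keys=["observation"],        # forward path: applied on step/reset outputs
    out_keys=["observation_str"],
    in_keys_inv=["action_str"],     # inverse path: applied on inputs via inv
    out_keys_inv=["action"],
    fn=str,                         # accepts non-tensor input, and also None
)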
@@ -4508,13 +4528,49 @@ def _apply_transform(self, value):
             value = value.tolist()
         return self._fn(value)

+    def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
+        if not self._use_raw_nontensor:
+            if isinstance(state, NonTensorData):
+                if state.dim() == 0:
+                    state = state.get("data")
+                else:
+                    state = state.tolist()
+            elif isinstance(state, NonTensorStack):
+                state = state.tolist()
+        return self._fn(state)
+
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
         with _set_missing_tolerance(self, True):
             tensordict_reset = self._call(tensordict_reset)
         return tensordict_reset

+    def transform_input_spec(self, input_spec: Composite) -> Composite:
+        input_spec = input_spec.clone()
+
+        # Make a generic input from the spec, call the transform with that
+        # input, and then generate the output spec from the output.
+        zero_input_ = input_spec.zero()
+        test_input = zero_input_["full_action_spec"].update(
+            zero_input_["full_state_spec"]
+        )
+        test_output = self.inv(test_input)
+        test_input_spec = make_composite_from_td(
+            test_output, unsqueeze_null_shapes=False
+        )
+
+        input_spec["full_action_spec"] = self.transform_action_spec(
+            input_spec["full_action_spec"],
+            test_input_spec,
+        )
+        if "full_state_spec" in input_spec.keys():
+            input_spec["full_state_spec"] = self.transform_state_spec(
+                input_spec["full_state_spec"],
+                test_input_spec,
+            )
+        return input_spec
+
     def transform_output_spec(self, output_spec: Composite) -> Composite:
         output_spec = output_spec.clone()

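The probing above relies on make_composite_from_td to turn the tensordict produced by self.inv back into a spec. A hand-rolled, simplified equivalent of that inference step (flat keys only; Composite and Unbounded are assumed to be the spec aliases in torchrl.data):

import torch
from tensordict import TensorDict
from torchrl.data import Composite, Unbounded

def composite_from_td(td: TensorDict) -> Composite:
    # Infer a spec from a probe tensordict's dtypes and shapes. The real
    # helper also handles nesting, devices, and non-tensor entries.
    return Composite(
        {key: Unbounded(shape=val.shape, dtype=val.dtype) for key, val in td.items()},
        shape=td.batch_size,
    )

probe = TensorDict({"action": torch.zeros(4, dtype=torch.int64)}, batch_size=[])
spec = composite_from_td(probe)  # Composite with an Unbounded "action" entry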
@@ -4575,19 +4631,31 @@ def transform_done_spec(
     ) -> TensorSpec:
         return self._transform_spec(done_spec, test_output_spec)

+    def transform_action_spec(
+        self, action_spec: TensorSpec, test_input_spec: TensorSpec
+    ) -> TensorSpec:
+        return self._transform_spec(action_spec, test_input_spec)
+
+    def transform_state_spec(
+        self, state_spec: TensorSpec, test_input_spec: TensorSpec
+    ) -> TensorSpec:
+        return self._transform_spec(state_spec, test_input_spec)
+

 class Hash(UnaryTransform):
     r"""Adds a hash value to a tensordict.

     Args:
         in_keys (sequence of NestedKey): the keys of the values to hash.
         out_keys (sequence of NestedKey): the keys of the resulting hashes.
+        in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during the inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during the inverse call.
+
+    Keyword Args:
         hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
             the hash function must accept it as its second argument. Default is
             ``Hash.reproducible_hash``.
         seed (optional): seed to use for the hash function, if it requires one.
-
-    Keyword Args:
         use_raw_nontensor (bool, optional): if ``False``, data is extracted from
             :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
             on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4673,9 +4741,11 @@ def __init__(
         self,
         in_keys: Sequence[NestedKey],
         out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
         hash_fn: Callable = None,
         seed: Any | None = None,
-        *,
         use_raw_nontensor: bool = False,
     ):
         if hash_fn is None:
@@ -4686,6 +4756,8 @@ def __init__(
         super().__init__(
             in_keys=in_keys,
             out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
             fn=self.call_hash_fn,
             use_raw_nontensor=use_raw_nontensor,
         )
@@ -4714,7 +4786,7 @@ def reproducible_hash(cls, string, seed=None):
         if seed is not None:
             seeded_string = seed + string
         else:
-            seeded_string = string
+            seeded_string = str(string)

         # Create a new SHA-256 hash object
         hash_object = hashlib.sha256()
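The switch to str(string) means non-string payloads (e.g. a list coming off a NonTensorStack) no longer break the unseeded path. A standalone re-implementation of the classmethod's logic, for illustration only:

import hashlib
import torch

def reproducible_hash(string, seed=None):
    # Mirrors Hash.reproducible_hash as patched above: SHA-256 of the
    # (optionally seeded) string, returned as a 32-byte uint8 tensor.
    seeded_string = seed + string if seed is not None else str(string)
    hash_object = hashlib.sha256()
    hash_object.update(seeded_string.encode("utf-8"))
    return torch.frombuffer(bytearray(hash_object.digest()), dtype=torch.uint8)

h = reproducible_hash("hello")
assert h.shape == (32,)
assert torch.equal(h, reproducible_hash("hello"))  # deterministic across calls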
@@ -4728,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
         return torch.frombuffer(hash_bytes, dtype=torch.uint8)


+class Tokenizer(UnaryTransform):
+    r"""Applies a tokenization operation on the specified inputs.
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
+        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during the inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during the inverse call.
+
+    Keyword Args:
+        tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+            pre-trained tokenizer.
+        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
+            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
+            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
+            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
+        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
+        tokenizer: "transformers.PretrainedTokenizerBase" = None,  # noqa: F821
+        use_raw_nontensor: bool = False,
+        additional_tokens: List[str] | None = None,
+    ):
+        if tokenizer is None:
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        elif isinstance(tokenizer, str):
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+        self.tokenizer = tokenizer
+        if additional_tokens:
+            self.tokenizer.add_tokens(additional_tokens)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+            fn=self.call_tokenizer_fn,
+            use_raw_nontensor=use_raw_nontensor,
+        )
+
+    @property
+    def device(self):
+        if "_device" in self.__dict__:
+            return self._device
+        parent = self.parent
+        if parent is None:
+            return None
+        device = parent.device
+        self._device = device
+        return device
+
+    def call_tokenizer_fn(self, value: str | List[str]):
+        device = self.device
+        out = self.tokenizer.encode(value, return_tensors="pt")
+        if device is not None and out.device != device:
+            out = out.to(device)
+        return out
+
+
 class Stack(Transform):
     """Stacks tensors and tensordicts.

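Finally, a hedged usage sketch of the new transform (requires the transformers package and downloads bert-base-uncased on first use; the key names are illustrative):

from torchrl.envs import Tokenizer

tok = Tokenizer(in_keys=["text"], out_keys=["tokens"])

# call_tokenizer_fn is the fn wired into UnaryTransform above; unattached to
# an env, the device property returns None and the ids stay on CPU.
ids = tok.call_tokenizer_fn("the quick brown fox")
print(ids.shape)  # torch.Size([1, seq_len]) of token ids

# Typical use is appending it to an env so string entries get tokenized:
# env = base_env.append_transform(Tokenizer(in_keys=["text"], out_keys=["tokens"]))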