From e32c4a8e40a633754343c62e663bf417a63fb186 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 1 Jun 2023 21:11:55 +0100 Subject: [PATCH 1/2] init --- docs/source/reference/data.rst | 3 +- docs/source/reference/envs.rst | 6 +- test/smoke_test.py | 2 +- torchrl/collectors/collectors.py | 10 +- torchrl/data/__init__.py | 1 + torchrl/data/tensor_specs.py | 59 ++++---- torchrl/envs/common.py | 30 ++-- torchrl/envs/gym_like.py | 8 +- torchrl/envs/libs/dm_control.py | 4 +- torchrl/envs/libs/gym.py | 4 +- torchrl/envs/libs/jax_utils.py | 6 +- torchrl/envs/libs/jumanji.py | 12 +- torchrl/envs/model_based/common.py | 4 +- torchrl/envs/transforms/r3m.py | 6 +- torchrl/envs/transforms/transforms.py | 136 +++++++++++------- torchrl/envs/transforms/vip.py | 6 +- torchrl/envs/vec_env.py | 14 +- torchrl/modules/models/model_based.py | 2 +- torchrl/modules/tensordict_module/actors.py | 38 ++--- torchrl/modules/tensordict_module/common.py | 12 +- .../modules/tensordict_module/exploration.py | 18 +-- .../tensordict_module/probabilistic.py | 14 +- torchrl/modules/utils/utils.py | 4 +- torchrl/objectives/dqn.py | 6 +- tutorials/sphinx-tutorials/pendulum.py | 2 +- 25 files changed, 230 insertions(+), 177 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 349d4022b1b..38aaa77ae10 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -191,7 +191,7 @@ Here's an example: TensorSpec ---------- -The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such +The `TensorSpecBase` parent class and subclasses define the basic properties of observations and actions in TorchRL, such as shape, device, dtype and domain. It is important that your environment specs match the input and output that it sends and receives, as :obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. @@ -203,6 +203,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. :toctree: generated/ :template: rl_template.rst + TensorSpecBase TensorSpec BinaryDiscreteTensorSpec BoundedTensorSpec diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index b7103d422e0..1d76f60b79c 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -31,11 +31,11 @@ Each env will have the following attributes: - :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object containing all the input key-spec pairs (except action). For most stateful environments, this container will be empty. -- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object +- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpecBase` object representing the action spec. -- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing +- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpecBase` object representing the reward spec. -- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing +- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpecBase` object representing the done-flag spec. - :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing all the input keys (:obj:`"_action_spec"` and :obj:`"_state_spec"`). 
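A minimal sketch of how the renamed parent class is expected to surface once this patch is applied (the ``BoundedTensorSpec`` bounds and shape below are illustrative, not taken from the diff):

>>> import torch
>>> from torchrl.data import BoundedTensorSpec, TensorSpecBase
>>> # concrete specs remain instances of the renamed parent class
>>> action_spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> assert isinstance(action_spec, TensorSpecBase)
>>> # sampling and containment checks are unchanged by the rename
>>> action = action_spec.rand()
>>> assert action_spec.is_in(action)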
diff --git a/test/smoke_test.py b/test/smoke_test.py index 313c786088c..f3ff7683bfa 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -8,7 +8,7 @@ def test_imports(): from torchrl.data import ( PrioritizedReplayBuffer, ReplayBuffer, - TensorSpec, + TensorSpecBase, ) # noqa: F401 from torchrl.envs import Transform, TransformedEnv # noqa: F401 from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 6c8d1cc114b..a43982ab005 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -35,7 +35,7 @@ VERBOSE, ) from torchrl.collectors.utils import split_trajectories -from torchrl.data.tensor_specs import TensorSpec +from torchrl.data.tensor_specs import TensorSpecBase from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.transforms import StepCounter, TransformedEnv @@ -62,7 +62,7 @@ class RandomPolicy: This is a wrapper around the action_spec.rand method. Args: - action_spec: TensorSpec object describing the action specs + action_spec: TensorSpecBase object describing the action specs Examples: >>> from tensordict import TensorDict @@ -72,7 +72,7 @@ class RandomPolicy: >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1] """ - def __init__(self, action_spec: TensorSpec): + def __init__(self, action_spec: TensorSpecBase): self.action_spec = action_spec def __call__(self, td: TensorDictBase) -> TensorDictBase: @@ -185,7 +185,7 @@ def _get_policy_and_device( ] ] = None, device: Optional[DEVICE_TYPING] = None, - observation_spec: TensorSpec = None, + observation_spec: TensorSpecBase = None, ) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. @@ -200,7 +200,7 @@ def _get_policy_and_device( policy (TensorDictModule, optional): a policy to be used device (int, str or torch.device, optional): device where to place the policy - observation_spec (TensorSpec, optional): spec of the observations + observation_spec (TensorSpecBase, optional): spec of the observations """ if policy is None: diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 14b29501f76..15bc7498b5c 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -31,6 +31,7 @@ MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 2bceec43fcc..dd79fc27c9f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -457,7 +457,7 @@ def __repr__(self): @dataclass(repr=False) -class TensorSpec: +class TensorSpecBase: """Parent class of the tensor meta-data containers for observation, actions and rewards. 
Args: @@ -700,11 +700,11 @@ def zero(self, shape=None) -> torch.Tensor: return torch.zeros((*shape, *self.shape), dtype=self.dtype, device=self.device) @abc.abstractmethod - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpecBase": raise NotImplementedError @abc.abstractmethod - def clone(self) -> "TensorSpec": + def clone(self) -> "TensorSpecBase": raise NotImplementedError def __repr__(self): @@ -730,7 +730,7 @@ def __torch_function__( if kwargs is None: kwargs = {} if func not in cls.SPEC_HANDLED_FUNCTIONS or not all( - issubclass(t, (TensorSpec,)) for t in types + issubclass(t, (TensorSpecBase,)) for t in types ): return NotImplemented( f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}" @@ -792,7 +792,7 @@ def __getitem__(self, item): f"Indexing occured along dimension {dim_idx} but stacking was done along dim {self.dim}." ) out = self._specs[item] - if isinstance(out, TensorSpec): + if isinstance(out, TensorSpecBase): return out return torch.stack(list(out), 0) else: @@ -814,7 +814,7 @@ def __getitem__(self, item): for i, _item in enumerate(item): if i == self.dim: out = self._specs[_item] - if isinstance(out, TensorSpec): + if isinstance(out, TensorSpecBase): return out return torch.stack(list(out), 0) elif isinstance(_item, slice): @@ -831,7 +831,7 @@ def __getitem__(self, item): f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}." ) out = self._specs[item] - if isinstance(out, TensorSpec): + if isinstance(out, TensorSpecBase): return out return torch.stack(list(out), 0) @@ -894,7 +894,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: return torch.stack([spec.to(dest) for spec in self._specs], self.dim) -class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): +class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpecBase], TensorSpecBase): """A lazy representation of a stack of tensor specs. Stacks tensor-specs together along one dimension. @@ -975,7 +975,7 @@ def set(self, name, spec): @dataclass(repr=False) -class OneHotDiscreteTensorSpec(TensorSpec): +class OneHotDiscreteTensorSpec(TensorSpecBase): """A unidimensional, one-hot discrete tensor spec. By default, TorchRL assumes that categorical variables are encoded as @@ -1237,7 +1237,7 @@ def to_categorical_spec(self) -> DiscreteTensorSpec: @dataclass(repr=False) -class BoundedTensorSpec(TensorSpec): +class BoundedTensorSpec(TensorSpecBase): """A bounded continuous tensor spec. Args: @@ -1483,7 +1483,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): @dataclass(repr=False) -class UnboundedContinuousTensorSpec(TensorSpec): +class UnboundedContinuousTensorSpec(TensorSpecBase): """An unbounded continuous tensor spec. Args: @@ -1560,8 +1560,15 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) +TensorSpec = type( + "TensorSpec", + UnboundedContinuousTensorSpec.__bases__, + dict(UnboundedContinuousTensorSpec.__dict__), +) + + @dataclass(repr=False) -class UnboundedDiscreteTensorSpec(TensorSpec): +class UnboundedDiscreteTensorSpec(TensorSpecBase): """An unbounded discrete tensor spec. Args: @@ -1891,7 +1898,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) -class DiscreteTensorSpec(TensorSpec): +class DiscreteTensorSpec(TensorSpecBase): """A discrete tensor spec. An alternative to OneHotTensorSpec for categorical variables in TorchRL. 
Instead of @@ -2380,7 +2387,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) -class CompositeSpec(TensorSpec): +class CompositeSpec(TensorSpecBase): """A composition of TensorSpecs. Args: @@ -2891,13 +2898,15 @@ def __eq__(self, other): and self._specs == other._specs ) - def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: + def update( + self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpecBase]] + ) -> None: for key, item in dict_or_spec.items(): if key in self.keys(True) and isinstance(self[key], CompositeSpec): self[key].update(item) continue try: - if isinstance(item, TensorSpec) and item.device != self.device: + if isinstance(item, TensorSpecBase) and item.device != self.device: item = deepcopy(item) if self.device is not None: item = item.to(self.device) @@ -3068,7 +3077,9 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): """ - def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: + def update( + self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpecBase]] + ) -> None: pass def __eq__(self, other): @@ -3167,7 +3178,7 @@ def set(self, name, spec): # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: -@TensorSpec.implements_for_spec(torch.stack) +@TensorSpecBase.implements_for_spec(torch.stack) def _stack_specs(list_of_spec, dim, out=None): if out is not None: raise NotImplementedError( @@ -3177,11 +3188,11 @@ def _stack_specs(list_of_spec, dim, out=None): if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") spec0 = list_of_spec[0] - if isinstance(spec0, TensorSpec): + if isinstance(spec0, TensorSpecBase): device = spec0.device all_equal = True for spec in list_of_spec[1:]: - if not isinstance(spec, TensorSpec): + if not isinstance(spec, TensorSpecBase): raise RuntimeError( "Stacking specs cannot occur: Found more than one type of specs in the list." ) @@ -3232,8 +3243,8 @@ def _stack_composite_specs(list_of_spec, dim, out=None): raise NotImplementedError -@TensorSpec.implements_for_spec(torch.squeeze) -def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: +@TensorSpecBase.implements_for_spec(torch.squeeze) +def _squeeze_spec(spec: TensorSpecBase, *args, **kwargs) -> TensorSpecBase: return spec.squeeze(*args, **kwargs) @@ -3242,8 +3253,8 @@ def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSp return spec.squeeze(*args, **kwargs) -@TensorSpec.implements_for_spec(torch.unsqueeze) -def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: +@TensorSpecBase.implements_for_spec(torch.unsqueeze) +def _unsqueeze_spec(spec: TensorSpecBase, *args, **kwargs) -> TensorSpecBase: return spec.unsqueeze(*args, **kwargs) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index cb26cedc38c..284ecd4868d 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -19,7 +19,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING @@ -130,7 +130,7 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): observation_spec. Therefore, "observation_spec" should be thought as a generic data container for environment outputs that are not done or reward data. 
- reward_spec (TensorSpecBase): the (leaf) spec of the reward. If the reward is nested within a tensordict, its location can be accessed via the ``reward_key`` attribute: >>> # accessing reward: >>> reward = env.fake_tensordict()[('next', *env.reward_key)] - done_spec (TensorSpecBase): the (leaf) spec of the done. If the done is nested within a tensordict, its location can be accessed via the ``done_key`` attribute. >>> # accessing done: >>> done = env.fake_tensordict()[('next', *env.done_key)] - action_spec (TensorSpecBase): the sampling spec of the actions. This attribute is contained in input_spec. >>> # accessing action spec: @@ -397,7 +397,7 @@ def ndim(self): # Parent specs: input and output spec. @property - def input_spec(self) -> TensorSpec: + def input_spec(self) -> TensorSpecBase: input_spec = self.__dict__.get("_input_spec", None) if input_spec is None: input_spec = CompositeSpec( @@ -409,11 +409,11 @@ def input_spec(self) -> TensorSpec: return input_spec @input_spec.setter - def input_spec(self, value: TensorSpec) -> None: + def input_spec(self, value: TensorSpecBase) -> None: raise RuntimeError("input_spec is protected.") @property - def output_spec(self) -> TensorSpec: + def output_spec(self) -> TensorSpecBase: output_spec = self.__dict__.get("_output_spec", None) if output_spec is None: output_spec = CompositeSpec( @@ -424,7 +424,7 @@ def output_spec(self) -> TensorSpec: return output_spec @output_spec.setter - def output_spec(self, value: TensorSpec) -> None: + def output_spec(self, value: TensorSpecBase) -> None: raise RuntimeError("output_spec is protected.") # Action spec @@ -456,7 +456,7 @@ def action_key(self): # Action spec: action specs belong to input_spec @property - def action_spec(self) -> TensorSpec: + def action_spec(self) -> TensorSpecBase: """The ``action`` leaf spec. This property will always return the leaf spec of the action attribute, @@ -486,7 +486,7 @@ def action_spec(self) -> TensorSpec: return out @action_spec.setter - def action_spec(self, value: TensorSpec) -> None: + def action_spec(self, value: TensorSpecBase) -> None: try: self.input_spec.unlock_() device = self.input_spec.device @@ -542,7 +542,7 @@ def reward_key(self): # Done spec: reward specs belong to output_spec @property - def reward_spec(self) -> TensorSpec: + def reward_spec(self) -> TensorSpecBase: """The ``reward`` leaf spec. This property will always return the leaf spec of the reward attribute, @@ -581,7 +581,7 @@ def reward_spec(self) -> TensorSpec: return out @reward_spec.setter - def reward_spec(self, value: TensorSpec) -> None: + def reward_spec(self, value: TensorSpecBase) -> None: try: self.output_spec.unlock_() device = self.output_spec.device @@ -654,7 +654,7 @@ def done_key(self): # Done spec: done specs belong to output_spec @property - def done_spec(self) -> TensorSpec: + def done_spec(self) -> TensorSpecBase: """The ``done`` leaf spec. 
This property will always return the leaf spec of the done attribute, @@ -692,7 +692,7 @@ def done_spec(self) -> TensorSpec: return out @done_spec.setter - def done_spec(self, value: TensorSpec) -> None: + def done_spec(self, value: TensorSpecBase) -> None: try: self.output_spec.unlock_() device = self.output_spec.device @@ -748,7 +748,7 @@ def observation_spec(self) -> CompositeSpec: return observation_spec @observation_spec.setter - def observation_spec(self, value: TensorSpec) -> None: + def observation_spec(self, value: TensorSpecBase) -> None: try: self.output_spec.unlock_() device = self.output_spec.device diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 41467abd600..cd96c7f603d 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -15,7 +15,7 @@ from tensordict import TensorDict from tensordict.tensordict import TensorDictBase -from torchrl.data.tensor_specs import TensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import TensorSpecBase, UnboundedContinuousTensorSpec from torchrl.envs.common import _EnvWrapper @@ -29,7 +29,7 @@ def __call__( raise NotImplementedError @abc.abstractproperty - def info_spec(self) -> Dict[str, TensorSpec]: + def info_spec(self) -> Dict[str, TensorSpecBase]: raise NotImplementedError @@ -56,7 +56,7 @@ class default_info_dict_reader(BaseInfoDictReader): def __init__( self, keys: List[str] = None, - spec: Union[Sequence[TensorSpec], Dict[str, TensorSpec]] = None, + spec: Union[Sequence[TensorSpecBase], Dict[str, TensorSpecBase]] = None, ): if keys is None: keys = [] @@ -91,7 +91,7 @@ def __call__( return tensordict @property - def info_spec(self) -> Dict[str, TensorSpec]: + def info_spec(self) -> Dict[str, TensorSpecBase]: return self._info_spec diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index af35540d22b..5d6257469f1 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -14,7 +14,7 @@ from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) @@ -52,7 +52,7 @@ def _dmcontrol_to_torchrl_spec_transform( spec, dtype: Optional[torch.dtype] = None, device: DEVICE_TYPING = None, -) -> TensorSpec: +) -> TensorSpecBase: if isinstance(spec, collections.OrderedDict): spec = { k: _dmcontrol_to_torchrl_spec_transform(item, device=device) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 3a15d865602..22b80751c63 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -25,7 +25,7 @@ MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict @@ -180,7 +180,7 @@ def gym_backend(submodule=None): def _gym_to_torchrl_spec_transform( spec, dtype=None, device="cpu", categorical_action_encoding=False -) -> TensorSpec: +) -> TensorSpecBase: """Maps the gym specs to the TorchRL specs. 
By convention, 'state' keys of Dict specs will be renamed "observation" to match the diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 266225ce77b..41492c7dc07 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -14,7 +14,7 @@ from torch.utils import dlpack as torch_dlpack from torchrl.data.tensor_specs import ( CompositeSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) @@ -105,7 +105,9 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example): return type(object_example)(**t) -def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> TensorSpec: +def _extract_spec( + data: Union[torch.Tensor, TensorDictBase], key=None +) -> TensorSpecBase: if isinstance(data, torch.Tensor): shape = data.shape if key in ("reward", "done"): diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 690b81f2c47..6246bbe3290 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -15,7 +15,7 @@ DEVICE_TYPING, DiscreteTensorSpec, OneHotDiscreteTensorSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) @@ -53,7 +53,7 @@ def _jumanji_to_torchrl_spec_transform( dtype: Optional[torch.dtype] = None, device: DEVICE_TYPING = None, categorical_action_encoding: bool = True, -) -> TensorSpec: +) -> TensorSpecBase: if isinstance(spec, jumanji.specs.DiscreteArray): action_space_cls = ( DiscreteTensorSpec @@ -172,21 +172,21 @@ def _make_state_example(self, env): state = _tree_reshape(state, self.batch_size) return state - def _make_state_spec(self, env) -> TensorSpec: + def _make_state_spec(self, env) -> TensorSpecBase: key = jax.random.PRNGKey(0) state, _ = env.reset(key) state_dict = _object_to_tensordict(state, self.device, batch_size=()) state_spec = _extract_spec(state_dict) return state_spec - def _make_action_spec(self, env) -> TensorSpec: + def _make_action_spec(self, env) -> TensorSpecBase: action_spec = _jumanji_to_torchrl_spec_transform( env.action_spec(), device=self.device ) action_spec = action_spec.expand(*self.batch_size, *action_spec.shape) return action_spec - def _make_observation_spec(self, env) -> TensorSpec: + def _make_observation_spec(self, env) -> TensorSpecBase: spec = env.observation_spec() new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device) if isinstance(spec, jumanji.specs.Array): @@ -198,7 +198,7 @@ def _make_observation_spec(self, env) -> TensorSpec: else: raise TypeError(f"Unsupported spec type {type(spec)}") - def _make_reward_spec(self, env) -> TensorSpec: + def _make_reward_spec(self, env) -> TensorSpecBase: reward_spec = _jumanji_to_torchrl_spec_transform( env.reward_spec(), device=self.device ) diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 1a63b0f5c45..c14f0ede280 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -86,8 +86,8 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): Properties: - observation_spec (CompositeSpec): sampling spec of the observations; - - action_spec (TensorSpec): sampling spec of the actions; - - reward_spec (TensorSpec): sampling spec of the rewards; + - action_spec (TensorSpecBase): sampling spec of the actions; + - reward_spec (TensorSpecBase): sampling spec of the rewards; - input_spec (CompositeSpec): sampling spec of the inputs; - batch_size (torch.Size): batch_size to be used by the env. 
If not set, the env accept tensordicts of all batch sizes. - device (torch.device): device where the env input and output are expected to live diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index f59e247dd48..4d81e561d39 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -12,7 +12,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING @@ -103,7 +103,9 @@ def _apply_transform(self, obs: torch.Tensor) -> None: out = out.view(*shape, *out.shape[1:]) return out - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: if not isinstance(observation_spec, CompositeSpec): raise ValueError("_R3MNet can only infer CompositeSpec") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f31b73ce928..199604fc554 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -27,7 +27,7 @@ DEVICE_TYPING, DiscreteTensorSpec, OneHotDiscreteTensorSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, ) @@ -275,7 +275,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: This method should generally be left untouched. Changes should be implemented using :meth:`~.transform_observation_spec`, :meth:`~.transform_reward_spec` and :meth:`~.transform_done_spec`. Args: - output_spec (TensorSpec): spec before the transform + output_spec (TensorSpecBase): spec before the transform Returns: expected spec after the transform @@ -295,11 +295,11 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: ) return output_spec - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + def transform_input_spec(self, input_spec: TensorSpecBase) -> TensorSpecBase: """Transforms the input spec such that the resulting spec matches transform mapping. Args: - input_spec (TensorSpec): spec before the transform + input_spec (TensorSpecBase): spec before the transform Returns: expected spec after the transform @@ -307,11 +307,13 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: """ return input_spec - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: """Transforms the observation spec such that the resulting spec matches transform mapping. Args: - observation_spec (TensorSpec): spec before the transform + observation_spec (TensorSpecBase): spec before the transform Returns: expected spec after the transform @@ -319,11 +321,11 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec """ return observation_spec - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: """Transforms the reward spec such that the resulting spec matches transform mapping. 
Args: - reward_spec (TensorSpec): spec before the transform + reward_spec (TensorSpecBase): spec before the transform Returns: expected spec after the transform @@ -331,11 +333,11 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: """ return reward_spec - def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: + def transform_done_spec(self, done_spec: TensorSpecBase) -> TensorSpecBase: """Transforms the done spec such that the resulting spec matches transform mapping. Args: - done_spec (TensorSpec): spec before the transform + done_spec (TensorSpecBase): spec before the transform Returns: expected spec after the transform @@ -562,7 +564,7 @@ def _inplace_update(self): return self.base_env._inplace_update @property - def output_spec(self) -> TensorSpec: + def output_spec(self) -> TensorSpecBase: """Observation spec of the transformed environment.""" if self.__dict__.get("_output_spec", None) is None or not self.cache_specs: output_spec = self.base_env.output_spec.clone() @@ -576,12 +578,12 @@ def output_spec(self) -> TensorSpec: return output_spec @property - def action_spec(self) -> TensorSpec: + def action_spec(self) -> TensorSpecBase: """Action spec of the transformed environment.""" return self.input_spec[("_action_spec", *self.action_key)] @property - def input_spec(self) -> TensorSpec: + def input_spec(self) -> TensorSpecBase: """Action spec of the transformed environment.""" if self.__dict__.get("_input_spec", None) is None or not self.cache_specs: input_spec = self.base_env.input_spec.clone() @@ -595,12 +597,12 @@ def input_spec(self) -> TensorSpec: return input_spec @property - def reward_spec(self) -> TensorSpec: + def reward_spec(self) -> TensorSpecBase: """Reward spec of the transformed environment.""" return self.output_spec[("_reward_spec", *self.reward_key)] @property - def observation_spec(self) -> TensorSpec: + def observation_spec(self) -> TensorSpecBase: """Observation spec of the transformed environment.""" observation_spec = self.output_spec["_observation_spec"] if observation_spec is None: @@ -608,7 +610,7 @@ def observation_spec(self) -> TensorSpec: return observation_spec @property - def state_spec(self) -> TensorSpec: + def state_spec(self) -> TensorSpecBase: """State spec of the transformed environment.""" state_spec = self.input_spec["_state_spec"] if state_spec is None: @@ -616,7 +618,7 @@ def state_spec(self) -> TensorSpec: return state_spec @property - def done_spec(self) -> TensorSpec: + def done_spec(self) -> TensorSpecBase: """Done spec of the transformed environment.""" return self.output_spec[("_done_spec", *self.done_key)] @@ -839,22 +841,24 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = t._inv_call(tensordict) return tensordict - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + def transform_input_spec(self, input_spec: TensorSpecBase) -> TensorSpecBase: for t in self.transforms[::-1]: input_spec = t.transform_input_spec(input_spec) return input_spec - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: for t in self.transforms: observation_spec = t.transform_observation_spec(observation_spec) return observation_spec - def transform_output_spec(self, output_spec: TensorSpec) -> TensorSpec: + def transform_output_spec(self, output_spec: TensorSpecBase) -> TensorSpecBase: for t in self.transforms: output_spec = 
t.transform_output_spec(output_spec) return output_spec - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: for t in self.transforms: reward_spec = t.transform_reward_spec(reward_spec) return reward_spec @@ -1011,7 +1015,9 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: return observation @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: observation_spec = self._pixel_observation(observation_spec) unsqueeze_dim = [1] if self._should_unsqueeze(observation_spec) else [] observation_spec.shape = torch.Size( @@ -1026,7 +1032,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec observation_spec.dtype = self.dtype return observation_spec - def _should_unsqueeze(self, observation_like: torch.FloatTensor | TensorSpec): + def _should_unsqueeze(self, observation_like: torch.FloatTensor | TensorSpecBase): has_3_dimensions = False if isinstance(observation_like, torch.FloatTensor): has_3_dimensions = observation_like.ndimension() == 3 @@ -1034,7 +1040,7 @@ def _should_unsqueeze(self, observation_like: torch.FloatTensor | TensorSpec): has_3_dimensions = len(observation_like.shape) == 3 return has_3_dimensions and self.unsqueeze - def _pixel_observation(self, spec: TensorSpec) -> None: + def _pixel_observation(self, spec: TensorSpecBase) -> None: if isinstance(spec.space, ContinuousBox): spec.space.maximum = self._apply_transform(spec.space.maximum) spec.space.minimum = self._apply_transform(spec.space.minimum) @@ -1237,7 +1243,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: return reward @_apply_to_composite - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: if isinstance(reward_spec, UnboundedContinuousTensorSpec): return BoundedTensorSpec( self.clamp_min, @@ -1281,7 +1287,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: return (reward > 0.0).to(torch.long) @_apply_to_composite - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: return BinaryDiscreteTensorSpec( n=1, device=reward_spec.device, shape=reward_spec.shape ) @@ -1337,7 +1343,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: return observation @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: space = observation_spec.space if isinstance(space, ContinuousBox): space.minimum = self._apply_transform(space.minimum) @@ -1389,7 +1397,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: return observation @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: space = observation_spec.space if isinstance(space, ContinuousBox): space.minimum = self._apply_transform(space.minimum) @@ -1467,7 +1477,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: forward = ObservationTransform._call @_apply_to_composite - def 
transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: space = observation_spec.space if isinstance(space, ContinuousBox): @@ -1549,7 +1561,7 @@ def _inv_apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = observation.squeeze(self.unsqueeze_dim) return observation - def _transform_spec(self, spec: TensorSpec) -> None: + def _transform_spec(self, spec: TensorSpecBase) -> None: space = spec.space if isinstance(space, ContinuousBox): space.minimum = self._apply_transform(space.minimum) @@ -1559,7 +1571,7 @@ def _transform_spec(self, spec: TensorSpec) -> None: spec.shape = self._apply_transform(torch.zeros(spec.shape)).shape return spec - def _inv_transform_spec(self, spec: TensorSpec) -> None: + def _inv_transform_spec(self, spec: TensorSpecBase) -> None: space = spec.space if isinstance(space, ContinuousBox): space.minimum = self._inv_apply_transform(space.minimum) @@ -1570,17 +1582,19 @@ def _inv_transform_spec(self, spec: TensorSpec) -> None: return spec @_apply_to_composite_inv - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + def transform_input_spec(self, input_spec: TensorSpecBase) -> TensorSpecBase: return self._inv_transform_spec(input_spec) @_apply_to_composite - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: if "reward" in self.in_keys: reward_spec = self._transform_spec(reward_spec) return reward_spec @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: return self._transform_spec(observation_spec) def __repr__(self) -> str: @@ -1645,7 +1659,9 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: return observation @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: space = observation_spec.space if isinstance(space, ContinuousBox): space.minimum = self._apply_transform(space.minimum) @@ -1902,7 +1918,9 @@ def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: return obs * scale + loc @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: space = observation_spec.space if isinstance(space, ContinuousBox): space.minimum = self._apply_transform(space.minimum) @@ -1910,7 +1928,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec @_apply_to_composite_inv - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + def transform_input_spec(self, input_spec: TensorSpecBase) -> TensorSpecBase: space = input_spec.space if isinstance(space, ContinuousBox): space.minimum = self._apply_transform(space.minimum) @@ -2126,7 +2144,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: space = observation_spec.space if isinstance(space, ContinuousBox): 
space.minimum = torch.cat([space.minimum] * self.N, self.dim) @@ -2269,7 +2289,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: return reward @_apply_to_composite - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: if isinstance(reward_spec, UnboundedContinuousTensorSpec): return reward_spec else: @@ -2334,7 +2354,7 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: return obs.to(torch.double) - def _transform_spec(self, spec: TensorSpec) -> None: + def _transform_spec(self, spec: TensorSpecBase) -> None: if isinstance(spec, CompositeSpec): for key in spec: self._transform_spec(spec[key]) @@ -2345,7 +2365,7 @@ def _transform_spec(self, spec: TensorSpec) -> None: space.minimum = space.minimum.to(torch.float) space.maximum = space.maximum.to(torch.float) - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + def transform_input_spec(self, input_spec: TensorSpecBase) -> TensorSpecBase: action_spec = input_spec["_action_spec"] state_spec = input_spec["_state_spec"] for key in self.in_keys_inv: @@ -2363,7 +2383,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec @_apply_to_composite - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + def transform_reward_spec(self, reward_spec: TensorSpecBase) -> TensorSpecBase: if "reward" in self.in_keys: if reward_spec.dtype is not torch.double: raise TypeError("reward_spec.dtype is not double") @@ -2372,7 +2392,9 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: return reward_spec @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: self._transform_spec(observation_spec) return observation_spec @@ -2498,7 +2520,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: forward = _call - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: # check that all keys are in observation_spec if len(self.in_keys) > 1 and not isinstance(observation_spec, CompositeSpec): raise ValueError( @@ -2797,7 +2821,7 @@ class TensorDictPrimer(Transform): primers (dict, optional): a dictionary containing key-spec pairs which will be used to populate the input tensordict. random (bool, optional): if ``True``, the values will be drawn randomly from - the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. + the TensorSpecBase domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. default_value (float, optional): if non-random filling is chosen, this value will be used to populate the tensors. Defaults to `0.0`. @@ -2865,7 +2889,7 @@ def __init__(self, primers: dict = None, random=False, default_value=0.0, **kwar # sanity check for spec in self.primers.values(): - if not isinstance(spec, TensorSpec): + if not isinstance(spec, TensorSpecBase): raise ValueError( "The values of the primers must be a subtype of the TensorSpec class. " f"Got {type(spec)} instead." 
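To make the renamed ``transform_*_spec`` signatures concrete, a rough sketch of a custom transform written against this branch follows; the ``DoubleObs`` class and its doubling logic are invented for illustration, and only the annotations and the ``ContinuousBox`` handling mirror the code above:

>>> import torch
>>> from torchrl.data.tensor_specs import ContinuousBox, TensorSpecBase
>>> from torchrl.envs.transforms import Transform
>>> class DoubleObs(Transform):
...     def __init__(self):
...         super().__init__(in_keys=["observation"])
...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
...         return obs * 2
...     def transform_observation_spec(
...         self, observation_spec: TensorSpecBase
...     ) -> TensorSpecBase:
...         # keep the spec bounds consistent with the data mapping
...         space = observation_spec.space
...         if isinstance(space, ContinuousBox):
...             space.minimum = self._apply_transform(space.minimum)
...             space.maximum = self._apply_transform(space.maximum)
...         return observation_spec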
@@ -3407,7 +3431,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set("next", next_tensordict) return tensordict - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: """Transforms the observation spec, adding the new keys generated by RewardSum.""" # Retrieve parent reward spec reward_spec = self.parent.reward_spec @@ -3613,7 +3639,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def reset(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict.exclude(*self.excluded_keys) - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: if any(key in observation_spec.keys(True, True) for key in self.excluded_keys): return CompositeSpec( **{ @@ -3669,7 +3697,9 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: *self.selected_keys, "reward", "done", *input_keys, strict=False ) - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: return CompositeSpec( **{ key: value @@ -3783,7 +3813,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: return observation_spec def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -3944,7 +3976,9 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(self.out_keys[0], _reset.clone()) return tensordict - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: observation_spec[self.out_keys[0]] = DiscreteTensorSpec( 2, dtype=torch.bool, diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 0726b349378..cc71b0d0630 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -12,7 +12,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING @@ -92,7 +92,9 @@ def _apply_transform(self, obs: torch.Tensor) -> None: out = out.view(*shape, *out.shape[1:]) return out - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + def transform_observation_spec( + self, observation_spec: TensorSpecBase + ) -> TensorSpecBase: if not isinstance(observation_spec, CompositeSpec): raise ValueError("_VIPNet can only infer CompositeSpec") diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 30cc756e540..6621f2c341c 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -26,7 +26,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, ) from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING @@ -1172,7 +1172,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = self._transform_step_output(step_output) return tensordict_out.select().set("next", tensordict_out) - def _get_action_spec(self) -> TensorSpec: + 
def _get_action_spec(self) -> TensorSpecBase: # local import to avoid importing gym in the script from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform @@ -1188,7 +1188,7 @@ def _get_action_spec(self) -> TensorSpec: action_spec = self._add_shape_to_spec(action_spec) return action_spec - def _get_output_spec(self) -> TensorSpec: + def _get_output_spec(self) -> TensorSpecBase: return CompositeSpec( _observation_spec=self._get_observation_spec(), _reward_spec=self._get_reward_spec(), @@ -1197,7 +1197,7 @@ def _get_output_spec(self) -> TensorSpec: device=self.device, ) - def _get_observation_spec(self) -> TensorSpec: + def _get_observation_spec(self) -> TensorSpecBase: # local import to avoid importing gym in the script from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform @@ -1216,16 +1216,16 @@ def _get_observation_spec(self) -> TensorSpec: device=self.device, ) - def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec: + def _add_shape_to_spec(self, spec: TensorSpecBase) -> TensorSpecBase: return spec.expand((self.num_workers, *spec.shape)) - def _get_reward_spec(self) -> TensorSpec: + def _get_reward_spec(self) -> TensorSpecBase: return UnboundedContinuousTensorSpec( device=self.device, shape=self.batch_size, ) - def _get_done_spec(self) -> TensorSpec: + def _get_done_spec(self) -> TensorSpecBase: return DiscreteTensorSpec( 2, device=self.device, diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 6965516fde2..b32ec00aea0 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -217,7 +217,7 @@ class RSSMPrior(nn.Module): Reference: https://arxiv.org/abs/1811.04551 Args: - action_spec (TensorSpec): Action spec. + action_spec (TensorSpecBase): Action spec. hidden_dim (int, optional): Number of hidden units in the linear network. Input size of the recurrent network. Defaults to 200. rnn_hidden_dim (int, optional): Number of hidden units in the recurrent network. Also size of the belief. diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index d1aacf19853..b5c6cad1ab0 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -15,7 +15,7 @@ ) from torch import nn -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import CompositeSpec, TensorSpecBase from torchrl.modules.models.models import DistributionalDQNnet from torchrl.modules.tensordict_module.common import SafeModule from torchrl.modules.tensordict_module.probabilistic import ( @@ -46,7 +46,7 @@ class Actor(SafeModule): number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. Defaults to ``["action"]``. - spec (TensorSpec, optional): Keyword-only argument. + spec (TensorSpecBase, optional): Keyword-only argument. Specs of the output tensor. If the module outputs multiple output tensors, spec characterize the space of the first output tensor. @@ -55,7 +55,7 @@ class Actor(SafeModule): input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :obj:`TensorSpecBase.project` method. Default is ``False``. 
Examples: @@ -92,7 +92,7 @@ def __init__( in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, *, - spec: Optional[TensorSpec] = None, + spec: Optional[TensorSpecBase] = None, **kwargs, ): if in_keys is None: @@ -135,14 +135,14 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): out_keys (str or iterable of str): keys where the sampled values will be written. Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped. - spec (TensorSpec, optional): keyword-only argument containing the specs + spec (TensorSpecBase, optional): keyword-only argument containing the specs of the output tensor. If the module outputs multiple output tensors, spec characterize the space of the first output tensor. safe (bool): keyword-only argument. if ``True``, the value of the output is checked against the input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :obj:`TensorSpecBase.project` method. Default is ``False``. default_interaction_type=InteractionType.RANDOM (str, optional): keyword-only argument. Default method to be used to retrieve @@ -214,7 +214,7 @@ def __init__( in_keys: Union[str, Sequence[str]], out_keys: Optional[Sequence[str]] = None, *, - spec: Optional[TensorSpec] = None, + spec: Optional[TensorSpecBase] = None, **kwargs, ): if out_keys is None: @@ -318,7 +318,7 @@ class QValueModule(TensorDictModuleBase): It works with both tensordict and regular tensors. Args: - action_space (str or TensorSpec, optional): Action space. Must be one of + action_space (str or TensorSpecBase, optional): Action space. Must be one of ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``, or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, @@ -333,14 +333,14 @@ class QValueModule(TensorDictModuleBase): var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each action component. - spec (TensorSpec, optional): if provided, the specs of the action (and/or + spec (TensorSpecBase, optional): if provided, the specs of the action (and/or other outputs). This is exclusive with ``action_space``, as the spec conditions the action space. safe (bool): if ``True``, the value of the output is checked against the input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :obj:`TensorSpecBase.project` method. Default is ``False``. Returns: @@ -376,11 +376,11 @@ class QValueModule(TensorDictModuleBase): def __init__( self, - action_space: Optional[Union[str, TensorSpec]], + action_space: Optional[Union[str, TensorSpecBase]], action_value_key: Union[List[str], List[Tuple[str]]] = None, out_keys: Union[List[str], List[Tuple[str]]] = None, var_nums: Optional[int] = None, - spec: Optional[TensorSpec] = None, + spec: Optional[TensorSpecBase] = None, safe: bool = False, ): action_space, spec = _process_action_space_spec(action_space, spec) @@ -509,7 +509,7 @@ class DistributionalQValueModule(QValueModule): https://arxiv.org/pdf/1707.06887.pdf Args: - action_space (str or TensorSpec, optional): Action space. 
Must be one of + action_space (str or TensorSpecBase, optional): Action space. Must be one of ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``, or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, @@ -525,14 +525,14 @@ class DistributionalQValueModule(QValueModule): var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each action component. - spec (TensorSpec, optional): if provided, the specs of the action (and/or + spec (TensorSpecBase, optional): if provided, the specs of the action (and/or other outputs). This is exclusive with ``action_space``, as the spec conditions the action space. safe (bool): if ``True``, the value of the output is checked against the input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :obj:`TensorSpecBase.project` method. Default is ``False``. Examples: @@ -576,7 +576,7 @@ def __init__( action_value_key: Union[List[str], List[Tuple[str]]] = None, out_keys: Union[List[str], List[Tuple[str]]] = None, var_nums: Optional[int] = None, - spec: TensorSpec = None, + spec: TensorSpecBase = None, safe: bool = False, ): if action_value_key is None: @@ -690,13 +690,13 @@ def _process_action_space_spec(action_space, spec): raise ValueError("action_space cannot be of type CompositeSpec.") if ( spec is not None - and isinstance(action_space, TensorSpec) + and isinstance(action_space, TensorSpecBase) and action_space is not spec ): raise ValueError( "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." ) - if isinstance(action_space, TensorSpec): + if isinstance(action_space, TensorSpecBase): spec = action_space action_space = _find_action_space(action_space) # check that the spec and action_space match @@ -887,7 +887,7 @@ class QValueActor(SafeSequential): list of keys indicates what observations need to be passed to the wrapped module to get the action values. Defaults to ``["observation"]``. - spec (TensorSpec, optional): Keyword-only argument. + spec (TensorSpecBase, optional): Keyword-only argument. Specs of the output tensor. If the module outputs multiple output tensors, spec characterize the space of the first output tensor. diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 53e285e58f2..057183ed224 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -16,7 +16,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import CompositeSpec, TensorSpecBase from torchrl.data.utils import DEVICE_TYPING @@ -102,7 +102,7 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): class SafeModule(TensorDictModule): - """:class:`tensordict.nn.TensorDictModule` subclass that accepts a :class:`~torchrl.data.TensorSpec` as argument to control the output domain. + """:class:`tensordict.nn.TensorDictModule` subclass that accepts a :class:`~torchrl.data.TensorSpecBase` as argument to control the output domain. 
Args: module (nn.Module): a nn.Module used to map the input to the output @@ -118,14 +118,14 @@ class SafeModule(TensorDictModule): The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. - spec (TensorSpec, optional): specs of the output tensor. If the module + spec (TensorSpecBase, optional): specs of the output tensor. If the module outputs multiple output tensors, spec characterize the space of the first output tensor. safe (bool): if ``True``, the value of the output is checked against the input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :obj:`TensorSpecBase.project` method. Default is ``False``. Embedding a neural network in a TensorDictModule only requires to specify the input and output keys. The domain spec can @@ -202,14 +202,14 @@ def __init__( ], in_keys: Iterable[str], out_keys: Iterable[str], - spec: Optional[TensorSpec] = None, + spec: Optional[TensorSpecBase] = None, safe: bool = False, ): super().__init__(module, in_keys, out_keys) self.register_spec(safe=safe, spec=spec) def register_spec(self, safe, spec): - if spec is not None and not isinstance(spec, TensorSpec): + if spec is not None and not isinstance(spec, TensorSpecBase): raise TypeError("spec must be a TensorSpec subclass") elif spec is not None and not isinstance(spec, CompositeSpec): if len(self.out_keys) > 1: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 86cf8b8bbc6..12643c525a2 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -13,7 +13,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - TensorSpec, + TensorSpecBase, UnboundedContinuousTensorSpec, ) from torchrl.envs.utils import exploration_type, ExplorationType @@ -42,7 +42,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): its output spec will be of type CompositeSpec. One needs to know where to find the action spec. Default is "action". - spec (TensorSpec, optional): if provided, the sampled action will be + spec (TensorSpecBase, optional): if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy. @@ -86,7 +86,7 @@ def __init__( eps_end: float = 0.1, annealing_num_steps: int = 1000, action_key: str = "action", - spec: Optional[TensorSpec] = None, + spec: Optional[TensorSpecBase] = None, ): super().__init__(policy) self.register_buffer("eps_init", torch.tensor([eps_init])) @@ -173,10 +173,10 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): its output spec will be of type CompositeSpec. One needs to know where to find the action spec. Default is "action". - spec (TensorSpec, optional): if provided, the sampled action will be + spec (TensorSpecBase, optional): if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy. - safe (boolean, optional): if False, the TensorSpec can be None. If it + safe (boolean, optional): if False, the TensorSpecBase can be None. If it is set to False but the spec is passed, the projection will still happen. Default is True. 
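A rough usage sketch for the ``spec`` argument of these exploration wrappers under the new naming (the linear policy, shapes and bounds are made up for illustration):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.envs.utils import ExplorationType, set_exploration_type
>>> from torchrl.modules import Actor, AdditiveGaussianWrapper
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> policy = Actor(module=torch.nn.Linear(4, 4), spec=spec)
>>> noisy = AdditiveGaussianWrapper(policy, spec=spec)
>>> td = TensorDict({"observation": torch.zeros(8, 4)}, batch_size=[8])
>>> with set_exploration_type(ExplorationType.RANDOM):
...     td = noisy(td)
>>> # with the default safe=True, the exploratory action should be
>>> # projected back into the spec's bounds
>>> assert spec.is_in(td.get("action"))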
@@ -201,7 +201,7 @@ def __init__(
         mean: float = 0.0,
         std: float = 1.0,
         action_key: str = "action",
-        spec: Optional[TensorSpec] = None,
+        spec: Optional[TensorSpecBase] = None,
         safe: Optional[bool] = True,
     ):
         super().__init__(policy)
@@ -344,11 +344,11 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
             default: 1000
         action_key (str): key of the action to be modified.
             default: "action"
-        spec (TensorSpec, optional): if provided, the sampled action will be
+        spec (TensorSpecBase, optional): if provided, the sampled action will be
             projected onto the valid action space once explored. If not provided,
             the exploration wrapper will attempt to recover it from the policy.
-        safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
+        safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected into the space
-            given the :obj:`TensorSpec.project` heuristic.
+            given the :obj:`TensorSpecBase.project` heuristic.
             default: True

     Examples:
@@ -389,7 +389,7 @@ def __init__(
         sigma_min: Optional[float] = None,
         n_steps_annealing: int = 1000,
         action_key: str = "action",
-        spec: TensorSpec = None,
+        spec: TensorSpecBase = None,
         safe: bool = True,
         key: str = None,
     ):
diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py
index 2a61ce7078c..379367c78fb 100644
--- a/torchrl/modules/tensordict_module/probabilistic.py
+++ b/torchrl/modules/tensordict_module/probabilistic.py
@@ -14,14 +14,14 @@
 )
 from tensordict.tensordict import TensorDictBase

-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import CompositeSpec, TensorSpecBase
 from torchrl.modules.distributions import Delta
 from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
 from torchrl.modules.tensordict_module.sequence import SafeSequential

 class SafeProbabilisticModule(ProbabilisticTensorDictModule):
-    """:class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as argument to control the output domain.
+    """:class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.data.TensorSpecBase` as argument to control the output domain.

     `SafeProbabilisticModule` is a non-parametric module representing a
     probability distribution. It reads the distribution parameters from an input
@@ -58,14 +58,14 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule):
         out_keys (str or iterable of str): keys where the sampled values will be
             written. Importantly, if these keys are found in the input TensorDict, the
             sampling step will be skipped.
-        spec (TensorSpec): specs of the first output tensor. Used when calling
+        spec (TensorSpecBase): specs of the first output tensor. Used when calling
             td_module.random() to generate random values in the target space.
         safe (bool, optional): if ``True``, the value of the sample is
             checked against the input spec. Out-of-domain sampling can
             occur because of exploration policies or numerical under/overflow
             issues. As for the :obj:`spec` argument, this check will only occur
             for the distribution sample, but not the other tensors returned by
             the input module. If the sample is out of bounds, it is
-            projected back onto the desired space using the `TensorSpec.project` method.
+            projected back onto the desired space using the `TensorSpecBase.project` method.
             Default is ``False``.
         default_interaction_type (str, optional): default method to be used to retrieve
             the output value.
Should be one of: 'mode', 'median', 'mean' or 'random'
@@ -99,7 +99,7 @@ def __init__(
         self,
         in_keys: Union[str, Sequence[str], dict],
         out_keys: Union[str, Sequence[str]],
-        spec: Optional[TensorSpec] = None,
+        spec: Optional[TensorSpecBase] = None,
         safe: bool = False,
         default_interaction_mode: str = None,
         default_interaction_type: str = InteractionType.MODE,
@@ -121,7 +121,7 @@ def __init__(
             n_empirical_estimate=n_empirical_estimate,
         )

-        if spec is not None and not isinstance(spec, TensorSpec):
+        if spec is not None and not isinstance(spec, TensorSpecBase):
-            raise TypeError("spec must be a TensorSpec subclass")
+            raise TypeError("spec must be a TensorSpecBase subclass")
         elif spec is not None and not isinstance(spec, CompositeSpec):
             if len(self.out_keys) > 1:
@@ -196,7 +196,7 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase:
 class SafeProbabilisticTensorDictSequential(
     ProbabilisticTensorDictSequential, SafeSequential
 ):
-    """:class:`tensordict.nn.ProbabilisticTensorDictSequential` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as argument to control the output domain.
+    """:class:`tensordict.nn.ProbabilisticTensorDictSequential` subclass that accepts a :class:`~torchrl.data.TensorSpecBase` as argument to control the output domain.

     Similarly to :obj:`TensorDictSequential`, but enforces that the final module in the
-    sequence is an :obj:`ProbabilisticTensorDictModule` and also exposes ``get_dist``
+    sequence is a :obj:`ProbabilisticTensorDictModule` and also exposes ``get_dist``
diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py
index 12f226acc62..0f6b31587d7 100644
--- a/torchrl/modules/utils/utils.py
+++ b/torchrl/modules/utils/utils.py
@@ -4,7 +4,7 @@
     DiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
-    TensorSpec,
+    TensorSpecBase,
 )

 ACTION_SPACE_MAP = {}
@@ -23,7 +23,7 @@ def _find_action_space(action_space):
-    if isinstance(action_space, TensorSpec):
+    if isinstance(action_space, TensorSpecBase):
         if isinstance(action_space, CompositeSpec):
             action_space = action_space["action"]
         action_space = type(action_space)
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 1344bd28713..a4ec4797061 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -11,7 +11,7 @@
 from tensordict.nn import dispatch
 from tensordict.utils import NestedKey
 from torch import nn
-from torchrl.data.tensor_specs import TensorSpec
+from torchrl.data.tensor_specs import TensorSpecBase
 from torchrl.envs.utils import step_mdp

 from torchrl.modules.tensordict_module.actors import (
@@ -44,7 +44,7 @@ class DQNLoss(LossModule):
         delay_value (bool, optional): whether to duplicate the value network
             into a new target value network to
             create a double DQN. Default is ``False``.
-        action_space (str or TensorSpec, optional): Action space. Must be one of
+        action_space (str or TensorSpecBase, optional): Action space.
Must be one of
            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
            or an instance of the corresponding specs
            (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
@@ -141,7 +141,7 @@ def __init__(
         loss_function: str = "l2",
         delay_value: bool = False,
         gamma: float = None,
-        action_space: Union[str, TensorSpec] = None,
+        action_space: Union[str, TensorSpecBase] = None,
         priority_key: str = None,
     ) -> None:
diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py
index 17f41430217..058d51c0747 100644
--- a/tutorials/sphinx-tutorials/pendulum.py
+++ b/tutorials/sphinx-tutorials/pendulum.py
@@ -367,7 +367,7 @@ def _reset(self, tensordict):
 # In other words, the ``observation_spec`` and related properties are
 # convenient shortcuts to the content of the output and input spec containers.
 #
-# TorchRL offers multiple :class:`~torchrl.data.TensorSpec`
+# TorchRL offers multiple :class:`~torchrl.data.TensorSpecBase`
 # `subclasses `_ to
 # encode the environment's input and output characteristics.
 #

From 716f0ac39ef514f5b3315579450a4506dc595cf4 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Tue, 13 Jun 2023 21:49:52 +0100
Subject: [PATCH 2/2] amend

---
 torchrl/modules/tensordict_module/exploration.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index 99c9d86e9fc..f6add53fae7 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -11,12 +11,11 @@
 from tensordict.tensordict import TensorDictBase
 from tensordict.utils import expand_as_right

 from torchrl.data.tensor_specs import (
     CompositeSpec,
     TensorSpecBase,
     UnboundedContinuousTensorSpec,
 )
-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
 from torchrl.envs.utils import exploration_type, ExplorationType
 from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
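
For reviewers, an untested sketch of how the rename surfaces in downstream code once both patches are applied. The toy value network (``nn.Linear(8, 4)``), the observation size and the four-action spec are made up for illustration; everything else leans only on APIs touched in the diffs above (``QValueActor``, ``DQNLoss`` and ``_find_action_space`` now dispatching on ``TensorSpecBase``) and on spec methods (``rand``, ``is_in``, ``project``) whose behavior this series does not change:

    import torch
    from torch import nn

    from torchrl.data import (
        BoundedTensorSpec,
        OneHotDiscreteTensorSpec,
        TensorSpecBase,
    )
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss

    action_spec = OneHotDiscreteTensorSpec(4)

    # isinstance checks against the abstract parent must now target
    # TensorSpecBase, mirroring _process_action_space_spec and
    # _find_action_space above.
    assert isinstance(action_spec, TensorSpecBase)

    # Passing the spec itself as `action_space` stays equivalent to passing
    # the matching string (here "one-hot"): the spec is recovered from the
    # TensorSpecBase instance and the action space inferred from its type.
    value_net = QValueActor(
        nn.Linear(8, 4), in_keys=["observation"], spec=action_spec
    )
    loss = DQNLoss(value_net, action_space=action_spec)

    # `safe=True` modules keep projecting out-of-domain values through the
    # (renamed) TensorSpecBase.project method; here project simply clamps
    # into the [-1, 1] box.
    bounded = BoundedTensorSpec(-torch.ones(3), torch.ones(3))
    out_of_domain = torch.tensor([2.0, -5.0, 0.3])
    assert bounded.is_in(bounded.project(out_of_domain))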