From 764e2e06420ec2f917a43b7fa90efa7bffba55f0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 17 Jan 2025 13:29:42 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/envs/transforms/transforms.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index aa8c64bacc3..93f935a49fd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec: input_spec = self.__dict__.get("_input_spec", None) return input_spec + def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict: + if self.base_env.rand_action is not EnvBase.rand_action: + # TODO: this will fail if the transform modifies the input. + # For instance, if PendulumEnv overrides rand_action and we build a + # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4)) + # env.rand_action will NOT have a discrete action! + # Getting a discrete action would require coding the inverse transform of an action within + # ActionDiscretizer (ie, float->int, not int->float). + return self.base_env.rand_action(tensordict) + return super().rand_action(tensordict) + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # No need to clone here because inv does it already # tensordict = tensordict.clone(False)