From 151612aad8ac08716c7b0e2c1673283c71e8e7a4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 5 Mar 2025 11:15:24 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/modules/tensordict_module/actors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index b1db6c6712a..358ef2006d2 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -2371,7 +2371,10 @@ def forward( action_entry = parent_td.get(action_key_orig[-1], None) if action_entry is None: raise self._NO_INIT_ERR - if self.n_steps is not None and action_entry.shape[parent_td.ndim] != self.n_steps: + if ( + self.n_steps is not None + and action_entry.shape[parent_td.ndim] != self.n_steps + ): raise RuntimeError( f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). " f"The action shape was {action_entry.shape}."