Skip to content

Commit 58c3847

Browse files
author
Vincent Moens
committed
[Feature] single_<attr>_spec
ghstack-source-id: 27e247e Pull Request resolved: #2549
1 parent 19dbeeb commit 58c3847

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

test/test_env.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3510,6 +3510,22 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
35103510
assert (td[3].get("next") != 0).any()
35113511

35123512

3513+
def test_single_env_spec():
3514+
env = NestedCountingEnv(batch_size=[3, 1, 7])
3515+
assert not env.single_full_action_spec.shape
3516+
assert not env.single_full_done_spec.shape
3517+
assert not env.single_input_spec.shape
3518+
assert not env.single_full_observation_spec.shape
3519+
assert not env.single_output_spec.shape
3520+
assert not env.single_full_reward_spec.shape
3521+
3522+
assert env.single_action_spec.shape
3523+
assert env.single_reward_spec.shape
3524+
3525+
assert env.output_spec.is_in(env.single_output_spec.zeros(env.shape))
3526+
assert env.input_spec.is_in(env.single_input_spec.zeros(env.shape))
3527+
3528+
35133529
if __name__ == "__main__":
35143530
args, unknown = argparse.ArgumentParser().parse_known_args()
35153531
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/common.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,77 @@ def full_state_spec(self) -> Composite:
14801480
def full_state_spec(self, spec: Composite) -> None:
14811481
self.state_spec = spec
14821482

1483+
# Single-env specs can be used to remove the batch size from the spec
1484+
@property
1485+
def batch_dims(self):
1486+
return len(self.batch_size)
1487+
1488+
def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
1489+
if not self.batch_dims:
1490+
return spec
1491+
idx = tuple(0 for _ in range(self.batch_dims))
1492+
return spec[idx]
1493+
1494+
@property
1495+
def single_full_action_spec(self) -> Composite:
1496+
"""Returns the action spec of the env as if it had no batch dimensions."""
1497+
return self._make_single_env_spec(self.full_action_spec)
1498+
1499+
@property
1500+
def single_action_spec(self) -> TensorSpec:
1501+
"""Returns the action spec of the env as if it had no batch dimensions."""
1502+
return self._make_single_env_spec(self.action_spec)
1503+
1504+
@property
1505+
def single_full_observation_spec(self) -> Composite:
1506+
"""Returns the observation spec of the env as if it had no batch dimensions."""
1507+
return self._make_single_env_spec(self.full_action_spec)
1508+
1509+
@property
1510+
def single_observation_spec(self) -> Composite:
1511+
"""Returns the observation spec of the env as if it had no batch dimensions."""
1512+
return self._make_single_env_spec(self.observation_spec)
1513+
1514+
@property
1515+
def single_full_reward_spec(self) -> Composite:
1516+
"""Returns the reward spec of the env as if it had no batch dimensions."""
1517+
return self._make_single_env_spec(self.full_action_spec)
1518+
1519+
@property
1520+
def single_reward_spec(self) -> TensorSpec:
1521+
"""Returns the reward spec of the env as if it had no batch dimensions."""
1522+
return self._make_single_env_spec(self.reward_spec)
1523+
1524+
@property
1525+
def single_full_done_spec(self) -> Composite:
1526+
"""Returns the done spec of the env as if it had no batch dimensions."""
1527+
return self._make_single_env_spec(self.full_action_spec)
1528+
1529+
@property
1530+
def single_done_spec(self) -> TensorSpec:
1531+
"""Returns the done spec of the env as if it had no batch dimensions."""
1532+
return self._make_single_env_spec(self.done_spec)
1533+
1534+
@property
1535+
def single_output_spec(self) -> Composite:
1536+
"""Returns the output spec of the env as if it had no batch dimensions."""
1537+
return self._make_single_env_spec(self.output_spec)
1538+
1539+
@property
1540+
def single_input_spec(self) -> Composite:
1541+
"""Returns the input spec of the env as if it had no batch dimensions."""
1542+
return self._make_single_env_spec(self.input_spec)
1543+
1544+
@property
1545+
def single_full_state_spec(self) -> Composite:
1546+
"""Returns the state spec of the env as if it had no batch dimensions."""
1547+
return self._make_single_env_spec(self.full_state_spec)
1548+
1549+
@property
1550+
def single_state_spec(self) -> TensorSpec:
1551+
"""Returns the state spec of the env as if it had no batch dimensions."""
1552+
return self._make_single_env_spec(self.state_spec)
1553+
14831554
def step(self, tensordict: TensorDictBase) -> TensorDictBase:
14841555
"""Makes a step in the environment.
14851556

0 commit comments

Comments
 (0)