
Commit 67c3e9a

kurtamohler authored and Vincent Moens committed
[Feature] Add EnvBase.all_actions
ghstack-source-id: 7abf9d4 Pull Request resolved: #2780
1 parent 1ed5d29 commit 67c3e9a

File tree

4 files changed: +123 -18 lines changed

    test/test_env.py
    torchrl/data/tensor_specs.py
    torchrl/envs/common.py
    torchrl/envs/custom/chess.py


test/test_env.py

Lines changed: 57 additions & 0 deletions
@@ -4087,6 +4087,63 @@ def test_env_reset_with_hash(self, stateful, include_san):
         td_check = env.reset(td.select("fen_hash"))
         assert (td_check == td).all()
 
+    @pytest.mark.parametrize("include_fen", [False, True])
+    @pytest.mark.parametrize("include_pgn", [False, True])
+    @pytest.mark.parametrize("stateful", [False, True])
+    @pytest.mark.parametrize("mask_actions", [False, True])
+    def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
+        if not stateful and not include_fen and not include_pgn:
+            pytest.skip("fen or pgn must be included if not stateful")
+
+        env = ChessEnv(
+            include_fen=include_fen,
+            include_pgn=include_pgn,
+            stateful=stateful,
+            mask_actions=mask_actions,
+        )
+        td = env.reset()
+
+        if not mask_actions:
+            with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
+                env.all_actions()
+            return
+
+        # Choose random actions from the output of `all_actions`
+        for _ in range(100):
+            if stateful:
+                all_actions = env.all_actions()
+            else:
+                # Reset to the initial state first, just to make sure
+                # `all_actions` knows how to get the board state from the input.
+                env.reset()
+                all_actions = env.all_actions(td.clone())
+
+            # Choose some random actions and make sure they match exactly one of
+            # the actions from `all_actions`. This part is not tested when
+            # `mask_actions == False`, because `rand_action` can pick illegal
+            # actions in that case.
+            if mask_actions:
+                # TODO: Something is wrong in `ChessEnv.rand_action` which makes
+                # it fail to work properly for stateless mode. It doesn't know
+                # how to correctly reset the board state to what is given in the
+                # tensordict before picking an action. When this is fixed, we
+                # can get rid of the two `reset`s below.
+                if not stateful:
+                    env.reset(td.clone())
+                td_act = td.clone()
+                for _ in range(10):
+                    rand_action = env.rand_action(td_act)
+                    assert (rand_action["action"] == all_actions["action"]).sum() == 1
+                if not stateful:
+                    env.reset()
+
+            action_idx = torch.randint(0, all_actions.shape[0], ()).item()
+            chosen_action = all_actions[action_idx]
+            td = env.step(td.update(chosen_action))["next"]
+
+            if td["done"]:
+                td = env.reset()
+
 
 class TestCustomEnvs:
     def test_tictactoe_env(self):

torchrl/data/tensor_specs.py

Lines changed: 31 additions & 17 deletions
@@ -869,12 +869,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         return self.is_in(item)
 
     @abc.abstractmethod
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         """Returns all the samples that can be obtained from the TensorSpec.
 
         The samples will be stacked along the first dimension.
 
         This method is only implemented for discrete specs.
+
+        Args:
+            use_mask (bool, optional): If ``True`` and the spec has a mask,
+                samples that are masked are excluded. Default is ``False``.
         """
         ...
 
@@ -1315,9 +1319,9 @@ def __eq__(self, other):
             return False
         return True
 
-    def enumerate(self) -> torch.Tensor | TensorDictBase:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor | TensorDictBase:
         return torch.stack(
-            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+            [spec.enumerate(use_mask) for spec in self._specs], dim=self.stack_dim + 1
         )
 
     def __len__(self):
@@ -1810,7 +1814,9 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
+        if use_mask:
+            raise NotImplementedError
         return (
             torch.eye(self.n, dtype=self.dtype, device=self.device)
             .expand(*self.shape, self.n)
@@ -2142,7 +2148,7 @@ def __init__(
             domain=domain,
         )
 
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
         )
@@ -2481,7 +2487,7 @@ def __eq__(self, other):
     def cardinality(self) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
 
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         raise RuntimeError("Cannot enumerate a NonTensorSpec.")
 
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
@@ -2779,7 +2785,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
                 val.shape[: -self.ndim] + self.shape
             )
 
-    def enumerate(self) -> Any:
+    def enumerate(self, use_mask: bool = False) -> Any:
         raise NotImplementedError("enumerate cannot be called with continuous specs.")
 
     def expand(self, *shape):
@@ -2951,9 +2957,9 @@ def __init__(
     def cardinality(self) -> int:
         return torch.as_tensor(self.nvec).prod()
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
         nvec = self.nvec
-        enum_disc = self.to_categorical_spec().enumerate()
+        enum_disc = self.to_categorical_spec().enumerate(use_mask)
         enums = torch.cat(
             [
                 torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
@@ -3417,14 +3423,18 @@ def __init__(
     def _undefined_n(self):
         return self.space.n < 0
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
         dtype = self.dtype
         if dtype is torch.bool:
             dtype = torch.uint8
-        arange = torch.arange(self.n, dtype=dtype, device=self.device)
+        n = self.n
+        arange = torch.arange(n, dtype=dtype, device=self.device)
+        if use_mask and self.mask is not None:
+            arange = arange[self.mask]
+            n = arange.shape[0]
         if self.ndim:
             arange = arange.view(-1, *(1,) * self.ndim)
-        return arange.expand(self.n, *self.shape)
+        return arange.expand(n, *self.shape)
 
     @property
     def n(self):
@@ -4088,7 +4098,9 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
-    def enumerate(self) -> torch.Tensor:
+    def enumerate(self, use_mask: bool = False) -> torch.Tensor:
+        if use_mask:
+            raise NotImplementedError()
         if self.mask is not None:
             raise RuntimeError(
                 "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
@@ -5136,13 +5148,15 @@ def cardinality(self) -> int:
             n = 0
         return n
 
-    def enumerate(self) -> TensorDictBase:
+    def enumerate(self, use_mask: bool = False) -> TensorDictBase:
         # We are going to use meshgrid to create samples of all the subspecs in here
         # but first let's get rid of the batch size, we'll put it back later
         self_without_batch = self
         while self_without_batch.ndim:
             self_without_batch = self_without_batch[0]
-        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        samples = {
+            key: spec.enumerate(use_mask) for key, spec in self_without_batch.items()
+        }
         if self.data_cls is not None:
             cls = self.data_cls
         else:
@@ -5566,10 +5580,10 @@ def update(self, dict) -> None:
             self[key] = item
         return self
 
-    def enumerate(self) -> TensorDictBase:
+    def enumerate(self, use_mask: bool = False) -> TensorDictBase:
         dim = self.stack_dim
         return LazyStackedTensorDict.maybe_dense_stack(
-            [spec.enumerate() for spec in self._specs], dim + 1
+            [spec.enumerate(use_mask) for spec in self._specs], dim + 1
         )
 
     def __eq__(self, other):
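
For context on this change, here is a minimal sketch (not part of the diff) of what the new `use_mask` flag does on a maskable discrete spec. It assumes torchrl's `Categorical` spec (named `DiscreteTensorSpec` in older releases) and its `update_mask` method:

import torch
from torchrl.data import Categorical

# A four-action categorical spec; mark action 1 as illegal through the spec's mask.
spec = Categorical(4)
spec.update_mask(torch.tensor([True, False, True, True]))

print(spec.enumerate())               # tensor([0, 1, 2, 3]) -- every action
print(spec.enumerate(use_mask=True))  # tensor([0, 2, 3])    -- masked-out action dropped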

torchrl/envs/common.py

Lines changed: 21 additions & 0 deletions
@@ -2831,6 +2831,27 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                 f"got {tensordict.batch_size} and {self.batch_size}"
             )
 
+    def all_actions(
+        self, tensordict: Optional[TensorDictBase] = None
+    ) -> TensorDictBase:
+        """Generates all possible actions from the action spec.
+
+        This only works in environments with fully discrete actions.
+
+        Args:
+            tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
+                is called with this tensordict.
+
+        Returns:
+            a tensordict object with the "action" entry updated with a batch of
+            all possible actions. The actions are stacked together in the
+            leading dimension.
+        """
+        if tensordict is not None:
+            self.reset(tensordict)
+
+        return self.full_action_spec.enumerate(use_mask=True)
+
     def rand_action(self, tensordict: Optional[TensorDictBase] = None):
         """Performs a random action given the action_spec attribute.

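As the docstring above states, `all_actions` optionally resets the env with the given tensordict and then enumerates `full_action_spec` with `use_mask=True`. A usage sketch, not part of the diff, assuming an environment whose action spec is fully discrete and maskable, such as `ChessEnv(mask_actions=True)`:

import torch
from torchrl.envs import ChessEnv  # requires the `chess` package

env = ChessEnv(include_fen=True, stateful=True, mask_actions=True)
td = env.reset()
all_acts = env.all_actions()  # one row per currently legal action, stacked on dim 0
idx = torch.randint(0, all_acts.shape[0], ()).item()
td = env.step(td.update(all_acts[idx]))["next"]  # play one of the enumerated actions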
torchrl/envs/custom/chess.py

Lines changed: 14 additions & 1 deletion
@@ -7,7 +7,7 @@
 import importlib.util
 import io
 import pathlib
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 from tensordict import TensorDict, TensorDictBase
@@ -357,6 +357,19 @@ def __init__(
     def _is_done(self, board):
         return board.is_game_over() | board.is_fifty_moves()
 
+    def all_actions(
+        self, tensordict: Optional[TensorDictBase] = None
+    ) -> TensorDictBase:
+        if not self.mask_actions:
+            raise RuntimeError(
+                (
+                    "Cannot generate legal actions since 'mask_actions=False' was "
+                    "set. If you really want to generate all actions, not just "
+                    "legal ones, call 'env.full_action_spec.enumerate()'."
+                )
+            )
+        return super().all_actions(tensordict)
+
     def _reset(self, tensordict=None):
         fen = None
         pgn = None

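A sketch of the error path added here, not part of the diff: with `mask_actions=False` the override refuses to enumerate, and the message points at `full_action_spec.enumerate()` as the unmasked fallback (legal and illegal moves alike):

from torchrl.envs import ChessEnv  # requires the `chess` package

env = ChessEnv(include_fen=True, stateful=True, mask_actions=False)
env.reset()
try:
    env.all_actions()
except RuntimeError as err:
    print(err)  # "Cannot generate legal actions since 'mask_actions=False' was set. ..."

# Unmasked fallback suggested by the error message: every action index, legal or not.
all_moves = env.full_action_spec.enumerate()
print(all_moves["action"].shape)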