# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Categorical, Composite, NonTensor, Unbounded

from torchrl.envs import EnvBase

from torchrl.envs.utils import _classproperty


class ChessEnv(EnvBase):
    """A chess environment that follows the TorchRL API.

    Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.

    Args:
        stateful (bool): Whether to keep track of the internal state of the board.
            If ``False``, the state will be stored in the observation and passed back
            to the environment on each call. Defaults to ``False``.

    .. note:: The action spec is a :class:`~torchrl.data.Categorical` spec built with ``n=-1``, since the number
        of legal moves depends on the current position. Unless :meth:`~torchrl.data.Categorical.set_provisional_n`
        is called with the cardinality of the current legal moves, valid random actions cannot be sampled.
        :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for this behavior.

    Examples:
        >>> env = ChessEnv()
        >>> r = env.reset()
        >>> env.rand_step(r)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
                hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None),
                        hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> env.rollout(1000)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                fen: NonTensorStack(
                    ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
                    batch_size=torch.Size([322]),
                    device=None),
                hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        fen: NonTensorStack(
                            ['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ...,
                            batch_size=torch.Size([322]),
                            device=None),
                        hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
                        reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([322]),
                    device=None,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([322]),
            device=None,
            is_shared=False)
    """

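    # Class-level table mapping hash(fen) back to the fen string; ``_get_fen`` falls back
    # to it when a tensordict only carries the ``hashing`` entry.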
    _hash_table: Dict[int, str] = {}

    @_classproperty
    def lib(cls):
        try:
            import chess
        except ImportError:
            raise ImportError(
                "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
            )
        return chess

    def __init__(self, stateful: bool = False):
        chess = self.lib
        super().__init__()
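        # Observations: the fen string (non-tensor), an integer hash of it, and the side to move.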
        self.full_observation_spec = Composite(
            hashing=Unbounded(shape=(), dtype=torch.int64),
            fen=NonTensor(shape=()),
            turn=Categorical(n=2, dtype=torch.bool, shape=()),
        )
        self.stateful = stateful
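        # In stateless mode, the board state (fen/hashing/turn) must round-trip through the
        # tensordict at every call, so the state spec mirrors the observation spec.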
        if not self.stateful:
            self.full_state_spec = self.full_observation_spec.clone()
        self.full_action_spec = Composite(
            action=Categorical(n=-1, shape=(), dtype=torch.int64)
        )
        self.full_reward_spec = Composite(
            reward=Unbounded(shape=(1,), dtype=torch.int32)
        )
        # done spec generated automatically
        self.board = chess.Board()
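        # In stateful mode the internal board is authoritative, so the provisional action
        # cardinality can be seeded right away from the starting position's legal moves.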
        if self.stateful:
            self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
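        # Refresh the provisional action cardinality from the current legal moves before
        # delegating the sampling to the default implementation.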
        self._set_action_space(tensordict)
        return super().rand_action(tensordict)

    def _reset(self, tensordict=None):
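        # Restart from the fen carried by the input tensordict when one is provided (stateless
        # usage); otherwise reset the internal board to the standard starting position.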
        fen = None
        if tensordict is not None:
            fen = self._get_fen(tensordict)
            dest = tensordict.empty()
        else:
            dest = TensorDict()

        if fen is None:
            self.board.reset()
            fen = self.board.fen()
        else:
            self.board.set_fen(fen.data)

        hashing = hash(fen)

        self._set_action_space()
        turn = self.board.turn
        return dest.set("fen", fen).set("hashing", hashing).set("turn", turn)

    def _set_action_space(self, tensordict: Optional[TensorDict] = None):
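        # In stateless mode, first re-sync the board from the fen carried by the tensordict,
        # then advertise the number of legal moves as the provisional action cardinality.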
        if not self.stateful and tensordict is not None:
            fen = self._get_fen(tensordict).data
            self.board.set_fen(fen)
        self.action_spec.set_provisional_n(self.board.legal_moves.count())

    @classmethod
    def _get_fen(cls, tensordict):
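        # Read the fen from the tensordict, falling back to the hash-table lookup when only
        # the hashing entry is available.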
        fen = tensordict.get("fen", None)
        if fen is None:
            hashing = tensordict.get("hashing", None)
            if hashing is not None:
                fen = cls._hash_table.get(hashing.item())
        return fen

    def _step(self, tensordict):
        # Decode the integer action into the corresponding legal move and play it.
        action = tensordict.get("action")
        board = self.board
        if not self.stateful:
            fen = self._get_fen(tensordict).data
            board.set_fen(fen)
        action = str(list(board.legal_moves)[action])
        # assert chess.Move.from_uci(action) in board.legal_moves
        board.push_san(action)
        self._set_action_space()

        # Collect data
        fen = self.board.fen()
        dest = tensordict.empty()
        hashing = hash(fen)
        dest.set("fen", fen)
        dest.set("hashing", hashing)

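        # End-of-game bookkeeping: a checkmate yields a reward of +1 or -1 depending on the side
        # to move after the final push, while `is_game_over()` also catches stalemates and other
        # draws, which terminate the episode with a zero reward.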
        done = board.is_checkmate()
        turn = torch.tensor(board.turn)
        reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
        done = done | board.is_stalemate() | board.is_game_over()
        dest.set("reward", reward)
        dest.set("turn", turn)
        dest.set("done", [done])
        dest.set("terminated", [done])
        return dest

    def _set_seed(self, *args, **kwargs):
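        # Chess transitions are deterministic; there is no environment-specific RNG to seed.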
        ...

    def cardinality(self, tensordict: Optional[TensorDictBase] = None) -> int:
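        # Number of legal moves in the position described by ``tensordict`` (or in the current
        # internal board when running statefully).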
        self._set_action_space(tensordict)
        return self.action_spec.cardinality()