Skip to content

Commit b70f558

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 6f1d43d + cad2efd commit b70f558

File tree

3 files changed

+29440
-45
lines changed

3 files changed

+29440
-45
lines changed

torchrl/envs/custom/chess.py

Lines changed: 147 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
import importlib.util
88
import io
9+
import pathlib
910
from typing import Dict, Optional
1011

1112
import torch
1213
from PIL import Image
1314
from tensordict import TensorDict, TensorDictBase
14-
from torchrl.data import Categorical, Composite, NonTensor, Unbounded
15+
from torchrl.data import Bounded, Categorical, Composite, NonTensor, Unbounded
1516

1617
from torchrl.envs import EnvBase
1718
from torchrl.envs.common import _EnvPostInit
@@ -43,39 +44,65 @@ def __call__(cls, *args, **kwargs):
4344
class ChessEnv(EnvBase, metaclass=_HashMeta):
4445
"""A chess environment that follows the TorchRL API.
4546
47+
This environment simulates a chess game using the `chess` library. It supports various state representations
48+
and can be configured to include different types of observations such as SAN, FEN, PGN, and legal moves.
49+
4650
Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
4751
4852
Args:
4953
stateful (bool): Whether to keep track of the internal state of the board.
5054
If False, the state will be stored in the observation and passed back
5155
to the environment on each call. Default: ``True``.
56+
include_san (bool): Whether to include SAN (Standard Algebraic Notation) in the observations. Default: ``False``.
57+
include_fen (bool): Whether to include FEN (Forsyth-Edwards Notation) in the observations. Default: ``False``.
58+
include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
59+
include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
60+
include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
61+
pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
5262
53-
.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
54-
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
55-
valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for
56-
this behavior.
63+
.. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
64+
The action space is structured as a categorical distribution over all possible SAN moves, with the legal moves
65+
being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
5766
5867
Examples:
59-
>>> env = ChessEnv()
68+
>>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
6069
>>> r = env.reset()
6170
>>> env.rand_step(r)
6271
TensorDict(
6372
fields={
6473
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
6574
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
6675
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
67-
hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
76+
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
6877
next: TensorDict(
6978
fields={
7079
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
71-
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None),
72-
hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
73-
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False),
80+
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
81+
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
82+
pgn: NonTensorData(data=[Event "?"]
83+
[Site "?"]
84+
[Date "????.??.??"]
85+
[Round "?"]
86+
[White "?"]
87+
[Black "?"]
88+
[Result "*"]
89+
1. b3 *, batch_size=torch.Size([]), device=None),
90+
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
91+
san: NonTensorData(data=b3, batch_size=torch.Size([]), device=None),
7492
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
7593
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
7694
batch_size=torch.Size([]),
7795
device=None,
7896
is_shared=False),
97+
pgn: NonTensorData(data=[Event "?"]
98+
[Site "?"]
99+
[Date "????.??.??"]
100+
[Round "?"]
101+
[White "?"]
102+
[Black "?"]
103+
[Result "*"]
104+
*, batch_size=torch.Size([]), device=None),
105+
san: NonTensorData(data=[SAN][START], batch_size=torch.Size([]), device=None),
79106
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
80107
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
81108
batch_size=torch.Size([]),
@@ -84,30 +111,46 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
84111
>>> env.rollout(1000)
85112
TensorDict(
86113
fields={
87-
action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
88-
done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
114+
action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
115+
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
89116
fen: NonTensorStack(
90117
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
91-
batch_size=torch.Size([322]),
118+
batch_size=torch.Size([352]),
92119
device=None),
93-
hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
120+
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
94121
next: TensorDict(
95122
fields={
96-
done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
123+
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
97124
fen: NonTensorStack(
98-
['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ...,
99-
batch_size=torch.Size([322]),
125+
['rnbqkbnr/pppppppp/8/8/8/N7/PPPPPPPP/R1BQKBNR b K...,
126+
batch_size=torch.Size([352]),
127+
device=None),
128+
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
129+
pgn: NonTensorStack(
130+
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
131+
batch_size=torch.Size([352]),
100132
device=None),
101-
hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
102-
reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False),
103-
terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
104-
turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
105-
batch_size=torch.Size([322]),
133+
reward: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.float32, is_shared=False),
134+
san: NonTensorStack(
135+
['Na3', 'a5', 'Nb1', 'Nc6', 'a3', 'g6', 'd4', 'd6'...,
136+
batch_size=torch.Size([352]),
137+
device=None),
138+
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
139+
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
140+
batch_size=torch.Size([352]),
106141
device=None,
107142
is_shared=False),
108-
terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
109-
turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
110-
batch_size=torch.Size([322]),
143+
pgn: NonTensorStack(
144+
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
145+
batch_size=torch.Size([352]),
146+
device=None),
147+
san: NonTensorStack(
148+
['[SAN][START]', 'Na3', 'a5', 'Nb1', 'Nc6', 'a3', ...,
149+
batch_size=torch.Size([352]),
150+
device=None),
151+
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
152+
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
153+
batch_size=torch.Size([352]),
111154
device=None,
112155
is_shared=False)
113156
@@ -136,13 +179,50 @@ def lib(cls):
136179
)
137180
return chess
138181

182+
# Class-level cache of SAN move strings; filled on first access to ``san_moves``.
_san_moves = []

@_classproperty
def san_moves(cls):
    """Return the list of SAN move strings, loading it on first access.

    The moves are read from ``san_moves.txt`` located next to this module
    (one entry per ``"\\n"``-separated line) and cached in ``cls._san_moves``,
    so the file is read only while the cache is empty.

    Returns:
        list: the cached list of SAN move strings (the same list object on
        every call).
    """
    if not cls._san_moves:
        moves_file = pathlib.Path(__file__).parent / "san_moves.txt"
        # Open read-only: the file is never written to, and an update mode
        # ("r+") would fail when the package lives in a read-only location.
        with open(moves_file, encoding="utf-8") as f:
            cls._san_moves.extend(f.read().split("\n"))
    return cls._san_moves
190+
191+
def _legal_moves_to_index(
    self,
    tensordict: TensorDictBase | None = None,
    board: "chess.Board" | None = None,  # noqa: F821
    return_mask: bool = False,
    pad: bool = False,
) -> torch.Tensor:
    """Map the legal moves of a board to their indices in ``san_moves``.

    Args:
        tensordict: when the env is not stateful and this is provided, the
            FEN obtained through ``self._get_fen(tensordict)`` is loaded into
            ``self.board`` and that board is used (overriding ``board``).
        board: board to read the legal moves from. Falls back to
            ``self.board`` when ``None``.
        return_mask: if ``True``, return a boolean tensor of length
            ``len(self.san_moves)`` that is ``True`` at the legal-move
            indices. Takes precedence over ``pad``.
        pad: if ``True``, right-pad the index tensor with
            ``len(self.san_moves)`` so that it holds 219 entries.

    Returns:
        torch.Tensor: ``int64`` indices of the legal moves (optionally
        padded), or a ``bool`` mask when ``return_mask`` is set.
    """
    if not self.stateful and tensordict is not None:
        fen = self._get_fen(tensordict).data
        self.board.set_fen(fen)
        board = self.board
    elif board is None:
        board = self.board

    # Go through the lazily-loading property rather than the raw
    # ``_san_moves`` cache so the lookup also works if the cache has not
    # been populated yet.
    san_moves = self.san_moves
    indices = torch.tensor(
        [san_moves.index(board.san(m)) for m in board.legal_moves],
        dtype=torch.int64,
    )
    if return_mask:
        return torch.zeros(len(san_moves), dtype=torch.bool).index_fill_(
            0, indices, True
        )
    if pad:
        # ``len(san_moves)`` is one past the last valid index, so it can
        # never collide with a real move.
        indices = torch.nn.functional.pad(
            indices, [0, 218 - indices.numel() + 1], value=len(san_moves)
        )
    return indices
217+
139218
def __init__(
140219
self,
141220
*,
142221
stateful: bool = True,
143222
include_san: bool = False,
144223
include_fen: bool = False,
145224
include_pgn: bool = False,
225+
include_legal_moves: bool = False,
146226
include_hash: bool = False,
147227
pixels: bool = False,
148228
):
@@ -154,6 +234,14 @@ def __init__(
154234
self.include_san = include_san
155235
self.include_fen = include_fen
156236
self.include_pgn = include_pgn
237+
self.include_legal_moves = include_legal_moves
238+
if include_legal_moves:
239+
# 218 max possible legal moves per chess board position
240+
# https://www.stmintz.com/ccc/index.php?id=424966
241+
# len(self.san_moves) (one past the last valid move index) is the padding value
242+
self.full_observation_spec["legal_moves"] = Bounded(
243+
0, 1 + len(self.san_moves), shape=(218,), dtype=torch.int64
244+
)
157245
if include_san:
158246
self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6")
159247
if include_pgn:
@@ -186,18 +274,19 @@ def __init__(
186274
self.full_observation_spec["pixels"] = Unbounded(shape=())
187275

188276
self.full_action_spec = Composite(
189-
action=Categorical(n=-1, shape=(), dtype=torch.int64)
277+
action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)
190278
)
191279
self.full_reward_spec = Composite(
192-
reward=Unbounded(shape=(1,), dtype=torch.int32)
280+
reward=Unbounded(shape=(1,), dtype=torch.float32)
193281
)
194282
# done spec generated automatically
195283
self.board = chess.Board()
196284
if self.stateful:
197285
self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))
198286

199287
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
    """Sample a random action restricted to the currently legal moves.

    A boolean mask of the legal moves is computed with
    ``_legal_moves_to_index(..., return_mask=True)`` and installed on the
    action spec before delegating to the parent implementation.

    Args:
        tensordict: optional input forwarded both to the mask computation
            and to the parent ``rand_action``.

    Returns:
        Whatever the parent class' ``rand_action`` returns.
    """
    legal_mask = self._legal_moves_to_index(tensordict, return_mask=True)
    # Restrict sampling to legal moves before the parent draws the action.
    self.action_spec.update_mask(legal_mask)
    return super().rand_action(tensordict)
202291

203292
def _is_done(self, board):
@@ -208,10 +297,14 @@ def _reset(self, tensordict=None):
208297
pgn = None
209298
if tensordict is not None:
210299
if self.include_fen:
211-
fen = self._get_fen(tensordict).data
300+
fen = self._get_fen(tensordict)
301+
if fen is not None:
302+
fen = fen.data
212303
dest = tensordict.empty()
213304
if self.include_pgn:
214-
fen = self._get_pgn(tensordict).data
305+
pgn = self._get_pgn(tensordict)
306+
if pgn is not None:
307+
pgn = pgn.data
215308
dest = tensordict.empty()
216309
else:
217310
dest = TensorDict()
@@ -245,6 +338,9 @@ def _reset(self, tensordict=None):
245338
pgn = self._board_to_pgn(self.board)
246339
dest.set("pgn", pgn)
247340
dest.set("turn", turn)
341+
if self.include_legal_moves:
342+
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
343+
dest.set("legal_moves", moves_idx)
248344
if self.pixels:
249345
dest.set("pixels", self._get_tensor_image(board=self.board))
250346
return dest
@@ -296,8 +392,8 @@ def _set_action_space(self, tensordict: TensorDict | None = None):
296392

297393
@classmethod
298394
def _pgn_to_board(
299-
cls, pgn_string: str, board: "chess.Board" | None = None
300-
) -> "chess.Board":
395+
cls, pgn_string: str, board: "chess.Board" | None = None # noqa: F821
396+
) -> "chess.Board": # noqa: F821
301397
pgn_io = io.StringIO(pgn_string)
302398
game = cls.lib.pgn.read_game(pgn_io)
303399
if board is None:
@@ -309,7 +405,7 @@ def _pgn_to_board(
309405
return board
310406

311407
@classmethod
312-
def _board_to_pgn(cls, board: "chess.Board") -> str:
408+
def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
313409
# Create a new Game object
314410
game = cls.lib.pgn.Game()
315411

@@ -376,11 +472,8 @@ def _step(self, tensordict):
376472
"Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
377473
)
378474

379-
action = list(board.legal_moves)[action]
380-
san = None
381-
if self.include_san:
382-
san = board.san(action)
383-
board.push(action)
475+
san = self.san_moves[action]
476+
board.push_san(san)
384477

385478
self._set_action_space()
386479

@@ -398,22 +491,33 @@ def _step(self, tensordict):
398491
if san is not None:
399492
dest.set("san", san)
400493

494+
if self.include_legal_moves:
495+
moves_idx = self._legal_moves_to_index(board=board, pad=True)
496+
dest.set("legal_moves", moves_idx)
497+
401498
turn = torch.tensor(board.turn)
499+
done = self._is_done(board)
402500
if board.is_checkmate():
403501
# turn flips after every move, even if the game is over
404-
winner = not turn
405-
reward_val = 1 if winner == self.lib.WHITE else -1
502+
# winner = not turn
503+
reward_val = 1 # if winner == self.lib.WHITE else 0
504+
elif done:
505+
reward_val = 0.5
406506
else:
407-
reward_val = 0
507+
reward_val = 0.0
408508

409-
reward = torch.tensor([reward_val], dtype=torch.int32)
410-
done = self._is_done(board)
509+
reward = torch.tensor([reward_val], dtype=torch.float32)
411510
dest.set("reward", reward)
412511
dest.set("turn", turn)
413512
dest.set("done", [done])
414513
dest.set("terminated", [done])
415514
if self.pixels:
416515
dest.set("pixels", self._get_tensor_image(board=self.board))
516+
517+
if self.stateful:
518+
# Make sure that rand_action will work next iteration
519+
self._set_action_space()
520+
417521
return dest
418522

419523
def _set_seed(self, *args, **kwargs):

0 commit comments

Comments
 (0)