6
6
7
7
import importlib .util
8
8
import io
9
+ import pathlib
9
10
from typing import Dict , Optional
10
11
11
12
import torch
12
13
from PIL import Image
13
14
from tensordict import TensorDict , TensorDictBase
14
- from torchrl .data import Categorical , Composite , NonTensor , Unbounded
15
+ from torchrl .data import Bounded , Categorical , Composite , NonTensor , Unbounded
15
16
16
17
from torchrl .envs import EnvBase
17
18
from torchrl .envs .common import _EnvPostInit
@@ -43,39 +44,65 @@ def __call__(cls, *args, **kwargs):
43
44
class ChessEnv (EnvBase , metaclass = _HashMeta ):
44
45
"""A chess environment that follows the TorchRL API.
45
46
47
+ This environment simulates a chess game using the `chess` library. It supports various state representations
48
+ and can be configured to include different types of observations such as SAN, FEN, PGN, and legal moves.
49
+
46
50
Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
47
51
48
52
Args:
49
53
stateful (bool): Whether to keep track of the internal state of the board.
50
54
If False, the state will be stored in the observation and passed back
51
55
to the environment on each call. Default: ``True``.
56
+ include_san (bool): Whether to include SAN (Standard Algebraic Notation) in the observations. Default: ``False``.
57
+ include_fen (bool): Whether to include FEN (Forsyth-Edwards Notation) in the observations. Default: ``False``.
58
+ include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
59
+ include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
60
+ include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
61
+ pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
52
62
53
- .. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
54
- Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
55
- valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for
56
- this behavior.
63
+ .. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
64
+ The action space is structured as a categorical distribution over all possible SAN moves, with the legal moves
65
+ being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
57
66
58
67
Examples:
59
- >>> env = ChessEnv()
68
+ >>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True )
60
69
>>> r = env.reset()
61
70
>>> env.rand_step(r)
62
71
TensorDict(
63
72
fields={
64
73
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
65
74
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
66
75
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
67
- hashing : Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
76
+ legal_moves : Tensor(shape=torch.Size([219 ]), device=cpu, dtype=torch.int64, is_shared=False),
68
77
next: TensorDict(
69
78
fields={
70
79
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
71
- fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None),
72
- hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
73
- reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False),
80
+ fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
81
+ legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
82
+ pgn: NonTensorData(data=[Event "?"]
83
+ [Site "?"]
84
+ [Date "????.??.??"]
85
+ [Round "?"]
86
+ [White "?"]
87
+ [Black "?"]
88
+ [Result "*"]
89
+ 1. b3 *, batch_size=torch.Size([]), device=None),
90
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
91
+ san: NonTensorData(data=b3, batch_size=torch.Size([]), device=None),
74
92
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
75
93
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
76
94
batch_size=torch.Size([]),
77
95
device=None,
78
96
is_shared=False),
97
+ pgn: NonTensorData(data=[Event "?"]
98
+ [Site "?"]
99
+ [Date "????.??.??"]
100
+ [Round "?"]
101
+ [White "?"]
102
+ [Black "?"]
103
+ [Result "*"]
104
+ *, batch_size=torch.Size([]), device=None),
105
+ san: NonTensorData(data=[SAN][START], batch_size=torch.Size([]), device=None),
79
106
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
80
107
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
81
108
batch_size=torch.Size([]),
@@ -84,30 +111,46 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
84
111
>>> env.rollout(1000)
85
112
TensorDict(
86
113
fields={
87
- action: Tensor(shape=torch.Size([322 ]), device=cpu, dtype=torch.int64, is_shared=False),
88
- done: Tensor(shape=torch.Size([322 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
114
+ action: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.int64, is_shared=False),
115
+ done: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
89
116
fen: NonTensorStack(
90
117
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
91
- batch_size=torch.Size([322 ]),
118
+ batch_size=torch.Size([352 ]),
92
119
device=None),
93
- hashing : Tensor(shape=torch.Size([322 ]), device=cpu, dtype=torch.int64, is_shared=False),
120
+ legal_moves : Tensor(shape=torch.Size([352, 219 ]), device=cpu, dtype=torch.int64, is_shared=False),
94
121
next: TensorDict(
95
122
fields={
96
- done: Tensor(shape=torch.Size([322 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
123
+ done: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
97
124
fen: NonTensorStack(
98
- ['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ...,
99
- batch_size=torch.Size([322]),
125
+ ['rnbqkbnr/pppppppp/8/8/8/N7/PPPPPPPP/R1BQKBNR b K...,
126
+ batch_size=torch.Size([352]),
127
+ device=None),
128
+ legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
129
+ pgn: NonTensorStack(
130
+ ['[Event "?"]\n [Site "?"]\n [Date "????.??.??"]\n [R...,
131
+ batch_size=torch.Size([352]),
100
132
device=None),
101
- hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
102
- reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False),
103
- terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
104
- turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
105
- batch_size=torch.Size([322]),
133
+ reward: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.float32, is_shared=False),
134
+ san: NonTensorStack(
135
+ ['Na3', 'a5', 'Nb1', 'Nc6', 'a3', 'g6', 'd4', 'd6'...,
136
+ batch_size=torch.Size([352]),
137
+ device=None),
138
+ terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
139
+ turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
140
+ batch_size=torch.Size([352]),
106
141
device=None,
107
142
is_shared=False),
108
- terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
109
- turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
110
- batch_size=torch.Size([322]),
143
+ pgn: NonTensorStack(
144
+ ['[Event "?"]\n [Site "?"]\n [Date "????.??.??"]\n [R...,
145
+ batch_size=torch.Size([352]),
146
+ device=None),
147
+ san: NonTensorStack(
148
+ ['[SAN][START]', 'Na3', 'a5', 'Nb1', 'Nc6', 'a3', ...,
149
+ batch_size=torch.Size([352]),
150
+ device=None),
151
+ terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
152
+ turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
153
+ batch_size=torch.Size([352]),
111
154
device=None,
112
155
is_shared=False)
113
156
@@ -136,13 +179,50 @@ def lib(cls):
136
179
)
137
180
return chess
138
181
182
+ _san_moves = []
183
+
184
+ @_classproperty
185
+ def san_moves (cls ):
186
+ if not cls ._san_moves :
187
+ with open (pathlib .Path (__file__ ).parent / "san_moves.txt" , "r+" ) as f :
188
+ cls ._san_moves .extend (f .read ().split ("\n " ))
189
+ return cls ._san_moves
190
+
191
+ def _legal_moves_to_index (
192
+ self ,
193
+ tensordict : TensorDictBase | None = None ,
194
+ board : "chess.Board" | None = None , # noqa: F821
195
+ return_mask : bool = False ,
196
+ pad : bool = False ,
197
+ ) -> torch .Tensor :
198
+ if not self .stateful and tensordict is not None :
199
+ fen = self ._get_fen (tensordict ).data
200
+ self .board .set_fen (fen )
201
+ board = self .board
202
+ elif board is None :
203
+ board = self .board
204
+ indices = torch .tensor (
205
+ [self ._san_moves .index (board .san (m )) for m in board .legal_moves ],
206
+ dtype = torch .int64 ,
207
+ )
208
+ if return_mask :
209
+ return torch .zeros (len (self .san_moves ), dtype = torch .bool ).index_fill_ (
210
+ 0 , indices , True
211
+ )
212
+ if pad :
213
+ indices = torch .nn .functional .pad (
214
+ indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
215
+ )
216
+ return indices
217
+
139
218
def __init__ (
140
219
self ,
141
220
* ,
142
221
stateful : bool = True ,
143
222
include_san : bool = False ,
144
223
include_fen : bool = False ,
145
224
include_pgn : bool = False ,
225
+ include_legal_moves : bool = False ,
146
226
include_hash : bool = False ,
147
227
pixels : bool = False ,
148
228
):
@@ -154,6 +234,14 @@ def __init__(
154
234
self .include_san = include_san
155
235
self .include_fen = include_fen
156
236
self .include_pgn = include_pgn
237
+ self .include_legal_moves = include_legal_moves
238
+ if include_legal_moves :
239
+ # 218 max possible legal moves per chess board position
240
+ # https://www.stmintz.com/ccc/index.php?id=424966
241
+ # len(self.san_moves)+1 is the padding value
242
+ self .full_observation_spec ["legal_moves" ] = Bounded (
243
+ 0 , 1 + len (self .san_moves ), shape = (218 ,), dtype = torch .int64
244
+ )
157
245
if include_san :
158
246
self .full_observation_spec ["san" ] = NonTensor (shape = (), example_data = "Nc6" )
159
247
if include_pgn :
@@ -186,18 +274,19 @@ def __init__(
186
274
self .full_observation_spec ["pixels" ] = Unbounded (shape = ())
187
275
188
276
self .full_action_spec = Composite (
189
- action = Categorical (n = - 1 , shape = (), dtype = torch .int64 )
277
+ action = Categorical (n = len ( self . san_moves ) , shape = (), dtype = torch .int64 )
190
278
)
191
279
self .full_reward_spec = Composite (
192
- reward = Unbounded (shape = (1 ,), dtype = torch .int32 )
280
+ reward = Unbounded (shape = (1 ,), dtype = torch .float32 )
193
281
)
194
282
# done spec generated automatically
195
283
self .board = chess .Board ()
196
284
if self .stateful :
197
285
self .action_spec .set_provisional_n (len (list (self .board .legal_moves )))
198
286
199
287
def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
200
- self ._set_action_space (tensordict )
288
+ mask = self ._legal_moves_to_index (tensordict , return_mask = True )
289
+ self .action_spec .update_mask (mask )
201
290
return super ().rand_action (tensordict )
202
291
203
292
def _is_done (self , board ):
@@ -208,10 +297,14 @@ def _reset(self, tensordict=None):
208
297
pgn = None
209
298
if tensordict is not None :
210
299
if self .include_fen :
211
- fen = self ._get_fen (tensordict ).data
300
+ fen = self ._get_fen (tensordict )
301
+ if fen is not None :
302
+ fen = fen .data
212
303
dest = tensordict .empty ()
213
304
if self .include_pgn :
214
- fen = self ._get_pgn (tensordict ).data
305
+ pgn = self ._get_pgn (tensordict )
306
+ if pgn is not None :
307
+ pgn = pgn .data
215
308
dest = tensordict .empty ()
216
309
else :
217
310
dest = TensorDict ()
@@ -245,6 +338,9 @@ def _reset(self, tensordict=None):
245
338
pgn = self ._board_to_pgn (self .board )
246
339
dest .set ("pgn" , pgn )
247
340
dest .set ("turn" , turn )
341
+ if self .include_legal_moves :
342
+ moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
343
+ dest .set ("legal_moves" , moves_idx )
248
344
if self .pixels :
249
345
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
250
346
return dest
@@ -296,8 +392,8 @@ def _set_action_space(self, tensordict: TensorDict | None = None):
296
392
297
393
@classmethod
298
394
def _pgn_to_board (
299
- cls , pgn_string : str , board : "chess.Board" | None = None
300
- ) -> "chess.Board" :
395
+ cls , pgn_string : str , board : "chess.Board" | None = None # noqa: F821
396
+ ) -> "chess.Board" : # noqa: F821
301
397
pgn_io = io .StringIO (pgn_string )
302
398
game = cls .lib .pgn .read_game (pgn_io )
303
399
if board is None :
@@ -309,7 +405,7 @@ def _pgn_to_board(
309
405
return board
310
406
311
407
@classmethod
312
- def _board_to_pgn (cls , board : "chess.Board" ) -> str :
408
+ def _board_to_pgn (cls , board : "chess.Board" ) -> str : # noqa: F821
313
409
# Create a new Game object
314
410
game = cls .lib .pgn .Game ()
315
411
@@ -376,11 +472,8 @@ def _step(self, tensordict):
376
472
"Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
377
473
)
378
474
379
- action = list (board .legal_moves )[action ]
380
- san = None
381
- if self .include_san :
382
- san = board .san (action )
383
- board .push (action )
475
+ san = self .san_moves [action ]
476
+ board .push_san (san )
384
477
385
478
self ._set_action_space ()
386
479
@@ -398,22 +491,33 @@ def _step(self, tensordict):
398
491
if san is not None :
399
492
dest .set ("san" , san )
400
493
494
+ if self .include_legal_moves :
495
+ moves_idx = self ._legal_moves_to_index (board = board , pad = True )
496
+ dest .set ("legal_moves" , moves_idx )
497
+
401
498
turn = torch .tensor (board .turn )
499
+ done = self ._is_done (board )
402
500
if board .is_checkmate ():
403
501
# turn flips after every move, even if the game is over
404
- winner = not turn
405
- reward_val = 1 if winner == self .lib .WHITE else - 1
502
+ # winner = not turn
503
+ reward_val = 1 # if winner == self.lib.WHITE else 0
504
+ elif done :
505
+ reward_val = 0.5
406
506
else :
407
- reward_val = 0
507
+ reward_val = 0.0
408
508
409
- reward = torch .tensor ([reward_val ], dtype = torch .int32 )
410
- done = self ._is_done (board )
509
+ reward = torch .tensor ([reward_val ], dtype = torch .float32 )
411
510
dest .set ("reward" , reward )
412
511
dest .set ("turn" , turn )
413
512
dest .set ("done" , [done ])
414
513
dest .set ("terminated" , [done ])
415
514
if self .pixels :
416
515
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
516
+
517
+ if self .stateful :
518
+ # Make sure that rand_action will work next iteration
519
+ self ._set_action_space ()
520
+
417
521
return dest
418
522
419
523
def _set_seed (self , * args , ** kwargs ):
0 commit comments