7
7
import importlib .util
8
8
import io
9
9
import pathlib
10
- from typing import Dict , Optional
10
+ from typing import Dict
11
11
12
12
import torch
13
13
from PIL import Image
14
14
from tensordict import TensorDict , TensorDictBase
15
- from torchrl .data import Bounded , Categorical , Composite , NonTensor , Unbounded
15
+ from torchrl .data import Binary , Bounded , Categorical , Composite , NonTensor , Unbounded
16
16
17
17
from torchrl .envs import EnvBase
18
18
from torchrl .envs .common import _EnvPostInit
19
19
20
20
from torchrl .envs .utils import _classproperty
21
21
22
22
23
- class _HashMeta (_EnvPostInit ):
23
+ class _ChessMeta (_EnvPostInit ):
24
24
def __call__ (cls , * args , ** kwargs ):
25
25
instance = super ().__call__ (* args , ** kwargs )
26
26
if kwargs .get ("include_hash" ):
@@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
37
37
if instance .include_pgn :
38
38
in_keys .append ("pgn" )
39
39
out_keys .append ("pgn_hash" )
40
- return instance .append_transform (Hash (in_keys , out_keys ))
40
+ instance = instance .append_transform (Hash (in_keys , out_keys ))
41
+ if kwargs .get ("mask_actions" , True ):
42
+ from torchrl .envs import ActionMask
43
+
44
+ instance = instance .append_transform (ActionMask ())
41
45
return instance
42
46
43
47
44
- class ChessEnv (EnvBase , metaclass = _HashMeta ):
48
+ class ChessEnv (EnvBase , metaclass = _ChessMeta ):
45
49
r"""A chess environment that follows the TorchRL API.
46
50
47
51
This environment simulates a chess game using the `chess` library. It supports various state representations
@@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
63
67
include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
64
68
include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
65
69
include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
70
+ mask_actions (bool): If ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
71
+ to the env to make sure that the actions are properly masked. Default: ``True``.
66
72
pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
67
73
68
74
.. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
@@ -200,16 +206,15 @@ def _legal_moves_to_index(
200
206
) -> torch .Tensor :
201
207
if not self .stateful :
202
208
if tensordict is None :
203
- raise RuntimeError (
204
- "rand_action requires a tensordict when stateful is False."
205
- )
206
- if self .include_fen :
207
- fen = self ._get_fen (tensordict )
209
+ # trust the board
210
+ pass
211
+ elif self .include_fen :
212
+ fen = tensordict .get ("fen" , None )
208
213
fen = fen .data
209
214
self .board .set_fen (fen )
210
215
board = self .board
211
216
elif self .include_pgn :
212
- pgn = self . _get_pgn ( tensordict )
217
+ pgn = tensordict . get ( "pgn" )
213
218
pgn = pgn .data
214
219
board = self ._pgn_to_board (pgn , self .board )
215
220
@@ -222,15 +227,19 @@ def _legal_moves_to_index(
222
227
)
223
228
224
229
if return_mask :
225
- return torch .zeros (len (self .san_moves ), dtype = torch .bool ).index_fill_ (
226
- 0 , indices , True
227
- )
230
+ return self ._move_index_to_mask (indices )
228
231
if pad :
229
232
indices = torch .nn .functional .pad (
230
233
indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
231
234
)
232
235
return indices
233
236
237
+ @classmethod
238
+ def _move_index_to_mask (cls , indices : torch .Tensor ) -> torch .Tensor :
239
+ return torch .zeros (len (cls .san_moves ), dtype = torch .bool ).index_fill_ (
240
+ 0 , indices , True
241
+ )
242
+
234
243
def __init__ (
235
244
self ,
236
245
* ,
@@ -240,6 +249,7 @@ def __init__(
240
249
include_pgn : bool = False ,
241
250
include_legal_moves : bool = False ,
242
251
include_hash : bool = False ,
252
+ mask_actions : bool = True ,
243
253
pixels : bool = False ,
244
254
):
245
255
chess = self .lib
@@ -250,6 +260,7 @@ def __init__(
250
260
self .include_san = include_san
251
261
self .include_fen = include_fen
252
262
self .include_pgn = include_pgn
263
+ self .mask_actions = mask_actions
253
264
self .include_legal_moves = include_legal_moves
254
265
if include_legal_moves :
255
266
# 218 max possible legal moves per chess board position
@@ -274,8 +285,10 @@ def __init__(
274
285
275
286
self .stateful = stateful
276
287
277
- if not self .stateful :
278
- self .full_state_spec = self .full_observation_spec .clone ()
288
+ # state_spec is loosely defined as such - it's not really an issue that extra keys
289
+ # can go missing but it allows us to reset the env using fen passed to the reset
290
+ # method.
291
+ self .full_state_spec = self .full_observation_spec .clone ()
279
292
280
293
self .pixels = pixels
281
294
if pixels :
@@ -295,16 +308,16 @@ def __init__(
295
308
self .full_reward_spec = Composite (
296
309
reward = Unbounded (shape = (1 ,), dtype = torch .float32 )
297
310
)
311
+ if self .mask_actions :
312
+ self .full_observation_spec ["action_mask" ] = Binary (
313
+ n = len (self .san_moves ), dtype = torch .bool
314
+ )
315
+
298
316
# done spec generated automatically
299
317
self .board = chess .Board ()
300
318
if self .stateful :
301
319
self .action_spec .set_provisional_n (len (list (self .board .legal_moves )))
302
320
303
- def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
304
- mask = self ._legal_moves_to_index (tensordict , return_mask = True )
305
- self .action_spec .update_mask (mask )
306
- return super ().rand_action (tensordict )
307
-
308
321
def _is_done (self , board ):
309
322
return board .is_game_over () | board .is_fifty_moves ()
310
323
@@ -314,11 +327,11 @@ def _reset(self, tensordict=None):
314
327
if tensordict is not None :
315
328
dest = tensordict .empty ()
316
329
if self .include_fen :
317
- fen = self . _get_fen ( tensordict )
330
+ fen = tensordict . get ( "fen" , None )
318
331
if fen is not None :
319
332
fen = fen .data
320
333
elif self .include_pgn :
321
- pgn = self . _get_pgn ( tensordict )
334
+ pgn = tensordict . get ( "pgn" , None )
322
335
if pgn is not None :
323
336
pgn = pgn .data
324
337
else :
@@ -358,13 +371,18 @@ def _reset(self, tensordict=None):
358
371
if self .include_legal_moves :
359
372
moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
360
373
dest .set ("legal_moves" , moves_idx )
374
+ if self .mask_actions :
375
+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
376
+ elif self .mask_actions :
377
+ dest .set (
378
+ "action_mask" ,
379
+ self ._legal_moves_to_index (
380
+ board = self .board , pad = True , return_mask = True
381
+ ),
382
+ )
383
+
361
384
if self .pixels :
362
385
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
363
-
364
- if self .stateful :
365
- mask = self ._legal_moves_to_index (dest , return_mask = True )
366
- self .action_spec .update_mask (mask )
367
-
368
386
return dest
369
387
370
388
_cairosvg_lib = None
@@ -435,16 +453,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
435
453
pgn_string = str (game )
436
454
return pgn_string
437
455
438
- @classmethod
439
- def _get_fen (cls , tensordict ):
440
- fen = tensordict .get ("fen" , None )
441
- return fen
442
-
443
- @classmethod
444
- def _get_pgn (cls , tensordict ):
445
- pgn = tensordict .get ("pgn" , None )
446
- return pgn
447
-
448
456
def get_legal_moves (self , tensordict = None , uci = False ):
449
457
"""List the legal moves in a position.
450
458
@@ -468,7 +476,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
468
476
raise ValueError (
469
477
"tensordict must be given since this env is not stateful"
470
478
)
471
- fen = self . _get_fen ( tensordict ).data
479
+ fen = tensordict . get ( "fen" ).data
472
480
board .set_fen (fen )
473
481
moves = board .legal_moves
474
482
@@ -486,10 +494,10 @@ def _step(self, tensordict):
486
494
fen = None
487
495
if not self .stateful :
488
496
if self .include_fen :
489
- fen = self . _get_fen ( tensordict ).data
497
+ fen = tensordict . get ( "fen" ).data
490
498
board .set_fen (fen )
491
499
elif self .include_pgn :
492
- pgn = self . _get_pgn ( tensordict ).data
500
+ pgn = tensordict . get ( "pgn" ).data
493
501
board = self ._pgn_to_board (pgn , board )
494
502
else :
495
503
raise RuntimeError (
@@ -519,6 +527,15 @@ def _step(self, tensordict):
519
527
if self .include_legal_moves :
520
528
moves_idx = self ._legal_moves_to_index (board = board , pad = True )
521
529
dest .set ("legal_moves" , moves_idx )
530
+ if self .mask_actions :
531
+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
532
+ elif self .mask_actions :
533
+ dest .set (
534
+ "action_mask" ,
535
+ self ._legal_moves_to_index (
536
+ board = self .board , pad = True , return_mask = True
537
+ ),
538
+ )
522
539
523
540
turn = torch .tensor (board .turn )
524
541
done = self ._is_done (board )
@@ -538,11 +555,6 @@ def _step(self, tensordict):
538
555
dest .set ("terminated" , [done ])
539
556
if self .pixels :
540
557
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
541
-
542
- if self .stateful :
543
- mask = self ._legal_moves_to_index (dest , return_mask = True )
544
- self .action_spec .update_mask (mask )
545
-
546
558
return dest
547
559
548
560
def _set_seed (self , * args , ** kwargs ):
0 commit comments