Skip to content

Commit f25adc9

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 056a1a2 + 913b123 commit f25adc9

File tree

4 files changed

+139
-55
lines changed

4 files changed

+139
-55
lines changed

test/test_env.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
from torchrl.envs.transforms.transforms import (
132132
AutoResetEnv,
133133
AutoResetTransform,
134+
Tokenizer,
134135
Transform,
135136
)
136137
from torchrl.envs.utils import (
@@ -3346,10 +3347,6 @@ def test_batched_dynamic(self, break_when_any_done):
33463347
)
33473348
del env_no_buffers
33483349
gc.collect()
3349-
# print(dummy_rollouts)
3350-
# print(rollout_no_buffers_serial)
3351-
# # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3352-
# assert_allclose_td(a, b)
33533350
assert_allclose_td(
33543351
dummy_rollouts.exclude("action"),
33553352
rollout_no_buffers_serial.exclude("action"),
@@ -3463,6 +3460,8 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
34633460
include_hash=include_hash,
34643461
include_san=include_san,
34653462
)
3463+
# mask_actions is always True here, so ChessEnv appends an ActionMask transform and returns a TransformedEnv
3464+
assert isinstance(env, TransformedEnv)
34663465
check_env_specs(env)
34673466
if include_hash:
34683467
if include_fen:
@@ -3560,8 +3559,8 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
35603559
)
35613560
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
35623561
td = env.reset(TensorDict({"fen": fen}))
3563-
assert td["fen"] == fen
35643562
if include_fen:
3563+
assert td["fen"] == fen
35653564
assert env.board.fen() == fen
35663565
assert td["turn"] == env.lib.WHITE
35673566
assert not td["done"]
@@ -3666,6 +3665,27 @@ def test_reward(
36663665
assert td["reward"] == expected_reward
36673666
assert td["turn"] == (not expected_turn)
36683667

3668+
def test_chess_tokenized(self):
3669+
env = ChessEnv(include_fen=True, stateful=True, include_san=True)
3670+
assert isinstance(env.observation_spec["fen"], NonTensor)
3671+
env = env.append_transform(
3672+
Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"])
3673+
)
3674+
assert isinstance(env.observation_spec["fen"], NonTensor)
3675+
env.transform.transform_output_spec(env.base_env.output_spec)
3676+
env.transform.transform_input_spec(env.base_env.input_spec)
3677+
r = env.rollout(10, return_contiguous=False)
3678+
assert "fen_tokenized" in r
3679+
assert "fen" in r
3680+
assert "fen_tokenized" in r["next"]
3681+
assert "fen" in r["next"]
3682+
ftd = env.fake_tensordict()
3683+
assert "fen_tokenized" in ftd
3684+
assert "fen" in ftd
3685+
assert "fen_tokenized" in ftd["next"]
3686+
assert "fen" in ftd["next"]
3687+
env.check_env_specs()
3688+
36693689

36703690
class TestCustomEnvs:
36713691
def test_tictactoe_env(self):

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
50425042

50435043
def __eq__(self, other):
50445044
return (
5045-
type(self) is type(other)
5045+
type(self) == type(other)
50465046
and self.shape == other.shape
50475047
and self._device == other._device
50485048
and set(self._specs.keys()) == set(other._specs.keys())

torchrl/envs/custom/chess.py

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77
import importlib.util
88
import io
99
import pathlib
10-
from typing import Dict, Optional
10+
from typing import Dict
1111

1212
import torch
1313
from PIL import Image
1414
from tensordict import TensorDict, TensorDictBase
15-
from torchrl.data import Bounded, Categorical, Composite, NonTensor, Unbounded
15+
from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded
1616

1717
from torchrl.envs import EnvBase
1818
from torchrl.envs.common import _EnvPostInit
1919

2020
from torchrl.envs.utils import _classproperty
2121

2222

23-
class _HashMeta(_EnvPostInit):
23+
class _ChessMeta(_EnvPostInit):
2424
def __call__(cls, *args, **kwargs):
2525
instance = super().__call__(*args, **kwargs)
2626
if kwargs.get("include_hash"):
@@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
3737
if instance.include_pgn:
3838
in_keys.append("pgn")
3939
out_keys.append("pgn_hash")
40-
return instance.append_transform(Hash(in_keys, out_keys))
40+
instance = instance.append_transform(Hash(in_keys, out_keys))
41+
if kwargs.get("mask_actions", True):
42+
from torchrl.envs import ActionMask
43+
44+
instance = instance.append_transform(ActionMask())
4145
return instance
4246

4347

44-
class ChessEnv(EnvBase, metaclass=_HashMeta):
48+
class ChessEnv(EnvBase, metaclass=_ChessMeta):
4549
r"""A chess environment that follows the TorchRL API.
4650
4751
This environment simulates a chess game using the `chess` library. It supports various state representations
@@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
6367
include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
6468
include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
6569
include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
70+
mask_actions (bool): Whether to append a :class:`~torchrl.envs.ActionMask` transform to the env
71+
so that actions are masked with the legal moves of the current position. Default: ``True``.
6672
pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
6773
6874
.. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
@@ -200,16 +206,15 @@ def _legal_moves_to_index(
200206
) -> torch.Tensor:
201207
if not self.stateful:
202208
if tensordict is None:
203-
raise RuntimeError(
204-
"rand_action requires a tensordict when stateful is False."
205-
)
206-
if self.include_fen:
207-
fen = self._get_fen(tensordict)
209+
# no tensordict was passed: rely on the board's current state
210+
pass
211+
elif self.include_fen:
212+
fen = tensordict.get("fen", None)
208213
fen = fen.data
209214
self.board.set_fen(fen)
210215
board = self.board
211216
elif self.include_pgn:
212-
pgn = self._get_pgn(tensordict)
217+
pgn = tensordict.get("pgn")
213218
pgn = pgn.data
214219
board = self._pgn_to_board(pgn, self.board)
215220

@@ -222,15 +227,19 @@ def _legal_moves_to_index(
222227
)
223228

224229
if return_mask:
225-
return torch.zeros(len(self.san_moves), dtype=torch.bool).index_fill_(
226-
0, indices, True
227-
)
230+
return self._move_index_to_mask(indices)
228231
if pad:
229232
indices = torch.nn.functional.pad(
230233
indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
231234
)
232235
return indices
233236

237+
@classmethod
238+
def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
239+
return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_(
240+
0, indices, True
241+
)
242+
234243
def __init__(
235244
self,
236245
*,
@@ -240,6 +249,7 @@ def __init__(
240249
include_pgn: bool = False,
241250
include_legal_moves: bool = False,
242251
include_hash: bool = False,
252+
mask_actions: bool = True,
243253
pixels: bool = False,
244254
):
245255
chess = self.lib
@@ -250,6 +260,7 @@ def __init__(
250260
self.include_san = include_san
251261
self.include_fen = include_fen
252262
self.include_pgn = include_pgn
263+
self.mask_actions = mask_actions
253264
self.include_legal_moves = include_legal_moves
254265
if include_legal_moves:
255266
# 218 max possible legal moves per chess board position
@@ -274,8 +285,10 @@ def __init__(
274285

275286
self.stateful = stateful
276287

277-
if not self.stateful:
278-
self.full_state_spec = self.full_observation_spec.clone()
288+
# The state spec is deliberately a loose copy of the observation spec: it does not matter
289+
# if some of these keys are missing from the input, and it lets the env be reset from a
290+
# ``fen`` passed to the reset method.
291+
self.full_state_spec = self.full_observation_spec.clone()
279292

280293
self.pixels = pixels
281294
if pixels:
@@ -295,16 +308,16 @@ def __init__(
295308
self.full_reward_spec = Composite(
296309
reward=Unbounded(shape=(1,), dtype=torch.float32)
297310
)
311+
if self.mask_actions:
312+
self.full_observation_spec["action_mask"] = Binary(
313+
n=len(self.san_moves), dtype=torch.bool
314+
)
315+
298316
# done spec generated automatically
299317
self.board = chess.Board()
300318
if self.stateful:
301319
self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))
302320

303-
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
304-
mask = self._legal_moves_to_index(tensordict, return_mask=True)
305-
self.action_spec.update_mask(mask)
306-
return super().rand_action(tensordict)
307-
308321
def _is_done(self, board):
309322
return board.is_game_over() | board.is_fifty_moves()
310323

@@ -314,11 +327,11 @@ def _reset(self, tensordict=None):
314327
if tensordict is not None:
315328
dest = tensordict.empty()
316329
if self.include_fen:
317-
fen = self._get_fen(tensordict)
330+
fen = tensordict.get("fen", None)
318331
if fen is not None:
319332
fen = fen.data
320333
elif self.include_pgn:
321-
pgn = self._get_pgn(tensordict)
334+
pgn = tensordict.get("pgn", None)
322335
if pgn is not None:
323336
pgn = pgn.data
324337
else:
@@ -358,13 +371,18 @@ def _reset(self, tensordict=None):
358371
if self.include_legal_moves:
359372
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
360373
dest.set("legal_moves", moves_idx)
374+
if self.mask_actions:
375+
dest.set("action_mask", self._move_index_to_mask(moves_idx))
376+
elif self.mask_actions:
377+
dest.set(
378+
"action_mask",
379+
self._legal_moves_to_index(
380+
board=self.board, pad=True, return_mask=True
381+
),
382+
)
383+
361384
if self.pixels:
362385
dest.set("pixels", self._get_tensor_image(board=self.board))
363-
364-
if self.stateful:
365-
mask = self._legal_moves_to_index(dest, return_mask=True)
366-
self.action_spec.update_mask(mask)
367-
368386
return dest
369387

370388
_cairosvg_lib = None
@@ -435,16 +453,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
435453
pgn_string = str(game)
436454
return pgn_string
437455

438-
@classmethod
439-
def _get_fen(cls, tensordict):
440-
fen = tensordict.get("fen", None)
441-
return fen
442-
443-
@classmethod
444-
def _get_pgn(cls, tensordict):
445-
pgn = tensordict.get("pgn", None)
446-
return pgn
447-
448456
def get_legal_moves(self, tensordict=None, uci=False):
449457
"""List the legal moves in a position.
450458
@@ -468,7 +476,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
468476
raise ValueError(
469477
"tensordict must be given since this env is not stateful"
470478
)
471-
fen = self._get_fen(tensordict).data
479+
fen = tensordict.get("fen").data
472480
board.set_fen(fen)
473481
moves = board.legal_moves
474482

@@ -486,10 +494,10 @@ def _step(self, tensordict):
486494
fen = None
487495
if not self.stateful:
488496
if self.include_fen:
489-
fen = self._get_fen(tensordict).data
497+
fen = tensordict.get("fen").data
490498
board.set_fen(fen)
491499
elif self.include_pgn:
492-
pgn = self._get_pgn(tensordict).data
500+
pgn = tensordict.get("pgn").data
493501
board = self._pgn_to_board(pgn, board)
494502
else:
495503
raise RuntimeError(
@@ -519,6 +527,15 @@ def _step(self, tensordict):
519527
if self.include_legal_moves:
520528
moves_idx = self._legal_moves_to_index(board=board, pad=True)
521529
dest.set("legal_moves", moves_idx)
530+
if self.mask_actions:
531+
dest.set("action_mask", self._move_index_to_mask(moves_idx))
532+
elif self.mask_actions:
533+
dest.set(
534+
"action_mask",
535+
self._legal_moves_to_index(
536+
board=self.board, pad=True, return_mask=True
537+
),
538+
)
522539

523540
turn = torch.tensor(board.turn)
524541
done = self._is_done(board)
@@ -538,11 +555,6 @@ def _step(self, tensordict):
538555
dest.set("terminated", [done])
539556
if self.pixels:
540557
dest.set("pixels", self._get_tensor_image(board=self.board))
541-
542-
if self.stateful:
543-
mask = self._legal_moves_to_index(dest, return_mask=True)
544-
self.action_spec.update_mask(mask)
545-
546558
return dest
547559

548560
def _set_seed(self, *args, **kwargs):

0 commit comments

Comments
 (0)