Commit 17983d4

Vincent Moens committed
[Feature] ChessEnv
ghstack-source-id: 087c3b1
Pull Request resolved: #2641
1 parent 6c7d233 commit 17983d4

File tree

4 files changed: +200 -1 lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
@@ -345,6 +345,7 @@ TorchRL offers a series of custom built-in environments.
     :toctree: generated/
     :template: rl_template.rst
 
+    ChessEnv
     PendulumEnv
     TicTacToeEnv
     LLMHashingEnv

torchrl/envs/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 from .batched_envs import ParallelEnv, SerialEnv
 from .common import EnvBase, EnvMetaData, make_tensordict
-from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv
+from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
 from .env_creator import env_creator, EnvCreator, get_env_metadata
 from .gym_like import default_info_dict_reader, GymLikeEnv
 from .libs import (

torchrl/envs/custom/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .chess import ChessEnv
 from .llm import LLMHashingEnv
 from .pendulum import PendulumEnv
 from .tictactoeenv import TicTacToeEnv
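
With the two one-line import edits above, the new environment is re-exported at the package level. A quick sanity check, as a minimal sketch (it assumes this commit is installed):

    # After this commit, both import paths resolve to the same class.
    from torchrl.envs import ChessEnv
    from torchrl.envs.custom import ChessEnv as CustomChessEnv
    assert ChessEnv is CustomChessEnv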

torchrl/envs/custom/chess.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Categorical, Composite, NonTensor, Unbounded

from torchrl.envs import EnvBase

from torchrl.envs.utils import _classproperty


class ChessEnv(EnvBase):
    """A chess environment that follows the TorchRL API.

    Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.

    Args:
        stateful (bool): Whether to keep track of the internal state of the board.
            If False, the state will be stored in the observation and passed back
            to the environment on each call. Default: ``False``.

    .. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
        Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
        valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for
        this behavior.

    Examples:
        >>> env = ChessEnv()
        >>> r = env.reset()
        >>> env.rand_step(r)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
                hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None),
                        hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> env.rollout(1000)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                fen: NonTensorStack(
                    ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
                    batch_size=torch.Size([322]),
                    device=None),
                hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        fen: NonTensorStack(
                            ['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ...,
                            batch_size=torch.Size([322]),
                            device=None),
                        hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
                        reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([322]),
                    device=None,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([322]),
            device=None,
            is_shared=False)


    """

    _hash_table: Dict[int, str] = {}

    @_classproperty
    def lib(cls):
        try:
            import chess
        except ImportError:
            raise ImportError(
                "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
            )
        return chess

    def __init__(self, stateful: bool = False):
        chess = self.lib
        super().__init__()
        self.full_observation_spec = Composite(
            hashing=Unbounded(shape=(), dtype=torch.int64),
            fen=NonTensor(shape=()),
            turn=Categorical(n=2, dtype=torch.bool, shape=()),
        )
        self.stateful = stateful
        if not self.stateful:
            self.full_state_spec = self.full_observation_spec.clone()
        self.full_action_spec = Composite(
            action=Categorical(n=-1, shape=(), dtype=torch.int64)
        )
        self.full_reward_spec = Composite(
            reward=Unbounded(shape=(1,), dtype=torch.int32)
        )
        # done spec generated automatically
        self.board = chess.Board()
        if self.stateful:
            self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
        self._set_action_space(tensordict)
        return super().rand_action(tensordict)

    def _reset(self, tensordict=None):
        fen = None
        if tensordict is not None:
            fen = self._get_fen(tensordict)
            dest = tensordict.empty()
        else:
            dest = TensorDict()

        if fen is None:
            self.board.reset()
            fen = self.board.fen()
        else:
            self.board.set_fen(fen.data)

        hashing = hash(fen)

        self._set_action_space()
        turn = self.board.turn
        return dest.set("fen", fen).set("hashing", hashing).set("turn", turn)

    def _set_action_space(self, tensordict: TensorDict | None = None):
        if not self.stateful and tensordict is not None:
            fen = self._get_fen(tensordict).data
            self.board.set_fen(fen)
        self.action_spec.set_provisional_n(self.board.legal_moves.count())

    @classmethod
    def _get_fen(cls, tensordict):
        fen = tensordict.get("fen", None)
        if fen is None:
            hashing = tensordict.get("hashing", None)
            if hashing is not None:
                fen = cls._hash_table.get(hashing.item())
        return fen

    def _step(self, tensordict):
        # action
        action = tensordict.get("action")
        board = self.board
        if not self.stateful:
            fen = self._get_fen(tensordict).data
            board.set_fen(fen)
        action = str(list(board.legal_moves)[action])
        # assert chess.Move.from_uci(action) in board.legal_moves
        board.push_san(action)
        self._set_action_space()

        # Collect data
        fen = self.board.fen()
        dest = tensordict.empty()
        hashing = hash(fen)
        dest.set("fen", fen)
        dest.set("hashing", hashing)

        done = board.is_checkmate()
        turn = torch.tensor(board.turn)
        reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
        done = done | board.is_stalemate() | board.is_game_over()
        dest.set("reward", reward)
        dest.set("turn", turn)
        dest.set("done", [done])
        dest.set("terminated", [done])
        return dest

    def _set_seed(self, *args, **kwargs):
        ...

    def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
        self._set_action_space(tensordict)
        return self.action_spec.cardinality()
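
For context on how the stateless mode described in the docstring is meant to be driven, here is a minimal hand-rolled loop. This is a sketch only: it assumes this commit and the `chess` package are installed, and it uses only the methods defined above (`reset`, `cardinality`, `rand_action`, `step`); `env.rollout(...)` from the docstring examples performs the equivalent loop internally.

    from torchrl.envs.custom.chess import ChessEnv

    env = ChessEnv(stateful=False)  # the board state travels as a "fen" string in the tensordict
    td = env.reset()
    for _ in range(10):
        # cardinality() reads the "fen" entry and sets the provisional `n` of the
        # Categorical action spec to the number of legal moves in that position.
        n_moves = env.cardinality(td)
        td = env.rand_action(td)    # samples an action index in [0, n_moves)
        td = env.step(td)["next"]   # next "fen"/"hashing"/"turn"/"reward" are returned, not stored
        if td["done"].item():
            break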
