Skip to content

Commit 4102861

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents f25adc9 + b6ce510 commit 4102861

File tree

1 file changed

+58
-37
lines changed

1 file changed

+58
-37
lines changed

torchrl/envs/custom/chess.py

Lines changed: 58 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
7676
being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
7777
7878
Examples:
79+
>>> import torch
80+
>>> from torchrl.envs import ChessEnv
81+
>>> _ = torch.manual_seed(0)
7982
>>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
83+
>>> print(env)
84+
TransformedEnv(
85+
env=ChessEnv(),
86+
transform=ActionMask(keys=['action', 'action_mask']))
8087
>>> r = env.reset()
81-
>>> env.rand_step(r)
88+
>>> print(env.rand_step(r))
8289
TensorDict(
8390
fields={
8491
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
92+
action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
8593
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
8694
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
8795
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
8896
next: TensorDict(
8997
fields={
98+
action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
9099
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
91-
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
100+
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
92101
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
93102
pgn: NonTensorData(data=[Event "?"]
94103
[Site "?"]
@@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
97106
[White "?"]
98107
[Black "?"]
99108
[Result "*"]
100-
1. b3 *, batch_size=torch.Size([]), device=None),
109+
110+
1. f4 *, batch_size=torch.Size([]), device=None),
101111
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
102-
san: NonTensorData(data=b3, batch_size=torch.Size([]), device=None),
112+
san: NonTensorData(data=f4, batch_size=torch.Size([]), device=None),
103113
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
104114
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
105115
batch_size=torch.Size([]),
@@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
112122
[White "?"]
113123
[Black "?"]
114124
[Result "*"]
125+
115126
*, batch_size=torch.Size([]), device=None),
116-
san: NonTensorData(data=[SAN][START], batch_size=torch.Size([]), device=None),
127+
san: NonTensorData(data=<start>, batch_size=torch.Size([]), device=None),
117128
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
118129
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
119130
batch_size=torch.Size([]),
120131
device=None,
121132
is_shared=False)
122-
>>> env.rollout(1000)
133+
>>> print(env.rollout(1000))
123134
TensorDict(
124135
fields={
125-
action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
126-
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
136+
action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
137+
action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
138+
done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
127139
fen: NonTensorStack(
128140
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
129-
batch_size=torch.Size([352]),
141+
batch_size=torch.Size([96]),
130142
device=None),
131-
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
143+
legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
132144
next: TensorDict(
133145
fields={
134-
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
146+
action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
147+
done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
135148
fen: NonTensorStack(
136-
['rnbqkbnr/pppppppp/8/8/8/N7/PPPPPPPP/R1BQKBNR b K...,
137-
batch_size=torch.Size([352]),
149+
['rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b ...,
150+
batch_size=torch.Size([96]),
138151
device=None),
139-
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
152+
legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
140153
pgn: NonTensorStack(
141154
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
142-
batch_size=torch.Size([352]),
155+
batch_size=torch.Size([96]),
143156
device=None),
144-
reward: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.float32, is_shared=False),
157+
reward: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.float32, is_shared=False),
145158
san: NonTensorStack(
146-
['Na3', 'a5', 'Nb1', 'Nc6', 'a3', 'g6', 'd4', 'd6'...,
147-
batch_size=torch.Size([352]),
159+
['Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8', 'Na3', 'Ra...,
160+
batch_size=torch.Size([96]),
148161
device=None),
149-
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
150-
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
151-
batch_size=torch.Size([352]),
162+
terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
163+
turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
164+
batch_size=torch.Size([96]),
152165
device=None,
153166
is_shared=False),
154167
pgn: NonTensorStack(
155168
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
156-
batch_size=torch.Size([352]),
169+
batch_size=torch.Size([96]),
157170
device=None),
158171
san: NonTensorStack(
159-
['[SAN][START]', 'Na3', 'a5', 'Nb1', 'Nc6', 'a3', ...,
160-
batch_size=torch.Size([352]),
172+
['<start>', 'Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8',...,
173+
batch_size=torch.Size([96]),
161174
device=None),
162-
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
163-
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
164-
batch_size=torch.Size([352]),
175+
terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
176+
turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
177+
batch_size=torch.Size([96]),
165178
device=None,
166179
is_shared=False)
167180
""" # noqa: D301
@@ -225,13 +238,15 @@ def _legal_moves_to_index(
225238
[self._san_moves.index(board.san(m)) for m in board.legal_moves],
226239
dtype=torch.int64,
227240
)
228-
241+
mask = None
229242
if return_mask:
230-
return self._move_index_to_mask(indices)
243+
mask = self._move_index_to_mask(indices)
231244
if pad:
232245
indices = torch.nn.functional.pad(
233246
indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
234247
)
248+
if return_mask:
249+
return indices, mask
235250
return indices
236251

237252
@classmethod
@@ -369,16 +384,19 @@ def _reset(self, tensordict=None):
369384
dest.set("pgn", pgn)
370385
dest.set("turn", turn)
371386
if self.include_legal_moves:
372-
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
373-
dest.set("legal_moves", moves_idx)
387+
moves_idx = self._legal_moves_to_index(
388+
board=self.board, pad=True, return_mask=self.mask_actions
389+
)
374390
if self.mask_actions:
375-
dest.set("action_mask", self._move_index_to_mask(moves_idx))
391+
moves_idx, mask = moves_idx
392+
dest.set("action_mask", mask)
393+
dest.set("legal_moves", moves_idx)
376394
elif self.mask_actions:
377395
dest.set(
378396
"action_mask",
379397
self._legal_moves_to_index(
380398
board=self.board, pad=True, return_mask=True
381-
),
399+
)[1],
382400
)
383401

384402
if self.pixels:
@@ -525,16 +543,19 @@ def _step(self, tensordict):
525543
dest.set("san", san)
526544

527545
if self.include_legal_moves:
528-
moves_idx = self._legal_moves_to_index(board=board, pad=True)
529-
dest.set("legal_moves", moves_idx)
546+
moves_idx = self._legal_moves_to_index(
547+
board=board, pad=True, return_mask=self.mask_actions
548+
)
530549
if self.mask_actions:
531-
dest.set("action_mask", self._move_index_to_mask(moves_idx))
550+
moves_idx, mask = moves_idx
551+
dest.set("action_mask", mask)
552+
dest.set("legal_moves", moves_idx)
532553
elif self.mask_actions:
533554
dest.set(
534555
"action_mask",
535556
self._legal_moves_to_index(
536557
board=self.board, pad=True, return_mask=True
537-
),
558+
)[1],
538559
)
539560

540561
turn = torch.tensor(board.turn)

0 commit comments

Comments
 (0)