
Commit 6e645e3

[DRAFT, Example] Add MCTS example
ghstack-source-id: 15144df
Pull Request resolved: pytorch#2796
1 parent a2879d0 commit 6e645e3

3 files changed (+215, -2 lines)


examples/trees/mcts.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torchrl
from tensordict import TensorDict

start_time = time.time()

pgn_or_fen = "fen"
mask_actions = True

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
#   white win = 1
#   draw = 0
#   black win = -1
env = env.append_transform(transform_reward)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

C = 2.0**0.5


def traversal_priority_UCB1(tree):
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
    done = False
    trees_visited = [tree]

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next"]["reward"]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            priorities = traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)
    tree.wins = torch.zeros_like(td["next", "reward"])
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(td["next", "reward"])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)
    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other move loses to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"

end_time = time.time()
total_time = end_time - start_time

print(f"Took {total_time} s")

torchrl/data/map/tree.py

Lines changed: 5 additions & 0 deletions
@@ -1364,6 +1364,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
         """Generates a string representation of a tree in the forest.
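
The new MCTSForest.__contains__ is what lets the example guard its root initialization with a plain membership test. Membership is decided only on the node map's input keys (the configured observation keys), so any extra entries in the root tensordict are ignored. A small usage sketch, assuming the env, forest, and root configured in examples/trees/mcts.py above:

# `root in forest` is False until a first forest.extend(...) populates node_map;
# afterwards the check reduces to root.select(*forest.node_map.in_keys) in forest.node_map.
if root not in forest:
    for action in env.all_actions(root):
        td = env.step(env.reset(root.clone()).update(action))
        forest.extend(td.unsqueeze(0))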

torchrl/envs/custom/chess.py

Lines changed: 7 additions & 2 deletions
@@ -220,12 +220,15 @@ def lib(cls):
         return chess
 
     _san_moves = []
+    _san_move_to_index_map = {}
 
     @_classproperty
    def san_moves(cls):
         if not cls._san_moves:
             with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
                 cls._san_moves.extend(f.read().split("\n"))
+            for idx, san_move in enumerate(cls._san_moves):
+                cls._san_move_to_index_map[san_move] = idx
         return cls._san_moves
 
     def _legal_moves_to_index(
@@ -253,7 +256,7 @@ def _legal_moves_to_index(
         board = self.board
 
         indices = torch.tensor(
-            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            [self._san_move_to_index_map[board.san(m)] for m in board.legal_moves],
             dtype=torch.int64,
         )
         mask = None
@@ -411,7 +414,9 @@ def _reset(self, tensordict=None):
             if move is None:
                 dest.set("san", "<start>")
             else:
-                dest.set("san", self.board.san(move))
+                prev_board = self.board.copy()
+                prev_board.pop()
+                dest.set("san", prev_board.san(move))
         if self.include_fen:
             dest.set("fen", fen)
         if self.include_pgn:
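
Two small ChessEnv fixes ride along with the example: legal-move indexing now goes through the prebuilt _san_move_to_index_map dict (one hash lookup per move instead of a linear list.index scan over all SAN strings), and _reset now renders the last move's SAN from the position in which that move was played. The second change matters because SAN depends on the position the move is made from (disambiguation, check and mate suffixes), so the move must be rendered against the pre-move board. A minimal standalone python-chess sketch of the same pattern, independent of ChessEnv:

import chess

board = chess.Board()
board.push_san("e4")        # board is now the position *after* 1.e4
move = board.peek()         # the move that was just played

# san() must be evaluated on the pre-move position, so rebuild it first:
prev_board = board.copy()
prev_board.pop()
print(prev_board.san(move))  # -> "e4"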
