# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict
import time

start_time = time.time()

pgn_or_fen = "fen"
mask_actions = True

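# The env is configured to expose FEN/SAN strings and their hashes; the hashes
# are used below as the forest's observation keys, and `mask_actions` adds an
# "action_mask" entry so that only legal moves are enumerated.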
env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
#   white win = 1
#   draw = 0
#   black win = -1
env = env.append_transform(transform_reward)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

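# Nodes in the forest are identified by these observation keys, so two paths
# that reach the same position (and legal-move mask) map to the same node.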
if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

C = 2.0**0.5


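# Child selection uses a UCB1-style score. Textbook UCB1 is
#   wins_i / visits_i + C * sqrt(ln(parent_visits) / visits_i)
# with C = sqrt(2); the variant below divides both the accumulated reward and
# the exploration bonus by the child's visit count, and gives unvisited
# children infinite priority so they are always tried first.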
def traversal_priority_UCB1(tree):
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
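    """Performs one MCTS iteration: selects a path down the tree by UCB1
    priority, expands the reached leaf, simulates a random rollout from it,
    and backpropagates the rollout reward along the visited path."""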
    done = False
    trees_visited = [tree]

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

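            # Expansion: only expand a leaf that has already been visited (or
            # is the root) and is not terminal; otherwise simulate directly
            # from this state.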
            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next"]["reward"]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

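            # Simulation: if the chosen state is already terminal, use its
            # reward; otherwise play out up to max_rollout_steps random moves.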
            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
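            # Selection: descend to the child with the highest UCB1 priority.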
            priorities = traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

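    # Backpropagation: credit the rollout reward and a visit to every node on
    # the path traversed in this iteration.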
    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)
    tree.wins = torch.zeros_like(td["next", "reward"])
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(td["next", "reward"])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


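# Optional helper that formats a node (last SAN move, resulting position,
# accumulated wins, visit count) for debugging or inspection; it is not called
# anywhere below.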
def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)
    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
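        # Flip the sign for black's moves so a larger value always means a
        # better move for the side to play.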
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f"  {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other move loses to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"

end_time = time.time()
total_time = end_time - start_time

print(f"Took {total_time} s")