diff --git a/test/test_mcts.py b/test/test_mcts.py new file mode 100644 index 00000000000..5ad4467a875 --- /dev/null +++ b/test/test_mcts.py @@ -0,0 +1,847 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import pytest +import torch +from tensordict import TensorDict +from torchrl.modules.mcts.scores import EXP3Score, PUCTScore, UCBScore + +# Sample TensorDict for testing +def create_node( + num_actions, weights=None, batch_size=None, device="cpu", custom_keys=None +): + if custom_keys is None: + custom_keys = { + "num_actions_key": "num_actions", + "weights_key": "weights", + "score_key": "score", + } + + if batch_size: + data = { + custom_keys["num_actions_key"]: torch.tensor( + [num_actions] * batch_size, device=device + ) + } + if weights is not None: + if weights.ndim == 1: + weights = weights.unsqueeze(0).repeat(batch_size, 1) + data[custom_keys["weights_key"]] = weights.to(device) + td = TensorDict(data, batch_size=[batch_size], device=device) + else: + data = { + custom_keys["num_actions_key"]: torch.tensor(num_actions, device=device) + } + if weights is not None: + data[custom_keys["weights_key"]] = weights.to(device) + td = TensorDict(data, batch_size=[], device=device) + return td + + +# Sample TensorDict node for UCBScore +def create_ucb_node( + win_count, visits, total_visits, batch_size=None, device="cpu", custom_keys=None +): + if custom_keys is None: + custom_keys = { + "win_count_key": "win_count", + "visits_key": "visits", + "total_visits_key": "total_visits", + "score_key": "score", + } + + win_count = torch.as_tensor(win_count, device=device, dtype=torch.float32) + visits = torch.as_tensor(visits, device=device, dtype=torch.float32) + total_visits = torch.as_tensor(total_visits, device=device, dtype=torch.float32) + + if batch_size: + if win_count.ndim == 0: + win_count = win_count.unsqueeze(0).repeat(batch_size) + elif win_count.shape[0] != batch_size: + raise ValueError("Batch size mismatch for win_count") + if visits.ndim == 0: + visits = visits.unsqueeze(0).repeat(batch_size) + elif visits.shape[0] != batch_size: + raise ValueError("Batch size mismatch for visits") + if total_visits.ndim == 0: + total_visits = total_visits.unsqueeze(0).repeat(batch_size) + elif total_visits.shape[0] != batch_size and total_visits.numel() != 1: + raise ValueError("Batch size mismatch for total_visits") + if total_visits.numel() == 1 and batch_size > 1: + total_visits = total_visits.repeat(batch_size) + + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + } + td = TensorDict( + data, + batch_size=[batch_s for batch_s in batch_size] + if isinstance(batch_size, (list, tuple)) + else [batch_size], + device=device, + ) + else: + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + } + td = TensorDict( + data, + batch_size=win_count.shape[:-1] if win_count.ndim > 1 else [], + device=device, + ) + + return td + + +# Helper function to create a sample TensorDict node for PUCTScore +def create_puct_node( + win_count, + visits, + total_visits, + prior_prob, + batch_size=None, + device="cpu", + custom_keys=None, +): + if custom_keys is None: + custom_keys = { + "win_count_key": "win_count", + "visits_key": "visits", + "total_visits_key": "total_visits", + 
"prior_prob_key": "prior_prob", + "score_key": "score", + } + + win_count = torch.as_tensor(win_count, device=device, dtype=torch.float32) + visits = torch.as_tensor(visits, device=device, dtype=torch.float32) + total_visits = torch.as_tensor(total_visits, device=device, dtype=torch.float32) + prior_prob = torch.as_tensor(prior_prob, device=device, dtype=torch.float32) + + if batch_size: + if win_count.ndim == 0: + win_count = win_count.unsqueeze(0).repeat(batch_size) + elif win_count.shape[0] != batch_size: + raise ValueError("Batch size mismatch for win_count") + if visits.ndim == 0: + visits = visits.unsqueeze(0).repeat(batch_size) + elif visits.shape[0] != batch_size: + raise ValueError("Batch size mismatch for visits") + if prior_prob.ndim == 0: + prior_prob = prior_prob.unsqueeze(0).repeat(batch_size) + elif prior_prob.shape[0] != batch_size: + raise ValueError("Batch size mismatch for prior_prob") + + if ( + total_visits.numel() == 1 and batch_size > 1 + ): # scalar total_visits for batch + total_visits = total_visits.repeat(batch_size) + elif total_visits.ndim == 0: + total_visits = total_visits.unsqueeze(0).repeat( + batch_size + ) # make it (batch_size,) + elif total_visits.shape[0] != batch_size: + raise ValueError("Batch size mismatch for total_visits") + + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + custom_keys["prior_prob_key"]: prior_prob, + } + if isinstance(batch_size, (list, tuple)): + td_batch_size = batch_size + else: + td_batch_size = [batch_size] + td = TensorDict(data, batch_size=td_batch_size, device=device) + + else: + data = { + custom_keys["win_count_key"]: win_count, + custom_keys["visits_key"]: visits, + custom_keys["total_visits_key"]: total_visits, + custom_keys["prior_prob_key"]: prior_prob, + } + td_batch_size = win_count.shape[:-1] if win_count.ndim > 1 else [] + + td = TensorDict(data, batch_size=td_batch_size, device=device) + + return td + + +class TestEXP3Score: + @pytest.fixture + def default_scorer(self): + return EXP3Score() + + @pytest.fixture + def custom_key_names(self): + return { + "weights_key": "custom_weights", + "score_key": "custom_scores", + "num_actions_key": "custom_num_actions", + "action_prob_key": "custom_actions_prob", + "reward_key": "custom_reward", + } + + @pytest.mark.parametrize("gamma_val", [0.1, 0.5, 0.9]) + def test_initialization(self, gamma_val): + scorer = EXP3Score(gamma=gamma_val) + assert scorer.gamma == gamma_val + scorer_default = EXP3Score() + assert scorer_default.gamma == 0.1 + + def test_forward_initial_weights(self, default_scorer): + num_actions = 3 + node = create_node(num_actions=num_actions) + + default_scorer.forward(node) + + assert default_scorer.weights_key in node.keys() + expected_weights = torch.ones(num_actions) + torch.testing.assert_close( + node.get(default_scorer.weights_key), expected_weights + ) + + expected_scores = torch.ones(num_actions) / num_actions + torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores) + torch.testing.assert_close( + node.get(default_scorer.score_key).sum(), torch.tensor(1.0) + ) + + def test_forward_custom_weights(self, default_scorer): + num_actions = 3 + weights = torch.tensor([1.0, 2.0, 3.0]) + node = create_node(num_actions=num_actions, weights=weights) + + default_scorer.forward(node) + + gamma = default_scorer.gamma + sum_w = weights.sum() + expected_scores = (1 - gamma) * (weights / sum_w) + (gamma / num_actions) + + 
torch.testing.assert_close(node.get(default_scorer.score_key), expected_scores) + torch.testing.assert_close( + node.get(default_scorer.score_key).sum(), torch.tensor(1.0) + ) + + @pytest.mark.parametrize("batch_s", [2, 4]) + def test_forward_batch(self, default_scorer, batch_s): + num_actions = 3 + node_initial = create_node(num_actions=num_actions, batch_size=batch_s) + default_scorer.forward(node_initial) + + expected_weights_initial = torch.ones(batch_s, num_actions) + torch.testing.assert_close( + node_initial.get(default_scorer.weights_key), expected_weights_initial + ) + + expected_scores_initial = torch.ones(batch_s, num_actions) / num_actions + torch.testing.assert_close( + node_initial.get(default_scorer.score_key), expected_scores_initial + ) + torch.testing.assert_close( + node_initial.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s) + ) + + weights_custom = torch.rand(batch_s, num_actions) + 0.1 + node_custom = create_node( + num_actions=num_actions, weights=weights_custom, batch_size=batch_s + ) + default_scorer.forward(node_custom) + + gamma = default_scorer.gamma + sum_w_custom = weights_custom.sum(dim=-1, keepdim=True) + expected_scores_custom = (1 - gamma) * (weights_custom / sum_w_custom) + ( + gamma / num_actions + ) + torch.testing.assert_close( + node_custom.get(default_scorer.score_key), + expected_scores_custom, + atol=1e-6, + rtol=1e-6, + ) + torch.testing.assert_close( + node_custom.get(default_scorer.score_key).sum(dim=-1), torch.ones(batch_s) + ) + + def test_update_weights_single_node(self, default_scorer): + num_actions = 3 + action_idx = 0 + reward = 1.0 + node = create_node(num_actions=num_actions) + + default_scorer.forward(node) + initial_weights = node.get(default_scorer.weights_key).clone() + prob_i = node.get(default_scorer.score_key)[action_idx] + + default_scorer.update_weights(node, action_idx, reward) + + updated_weights = node.get(default_scorer.weights_key) + gamma = default_scorer.gamma + k = num_actions + + expected_new_weight_val = initial_weights[action_idx] * math.exp( + (gamma / k) * (reward / prob_i) + ) + + torch.testing.assert_close( + updated_weights[action_idx], torch.tensor(expected_new_weight_val) + ) + torch.testing.assert_close( + updated_weights[action_idx + 1 :], initial_weights[action_idx + 1 :] + ) + + default_scorer.forward(node) + sum_w_updated = updated_weights.sum() + expected_scores_after_update = (1 - gamma) * ( + updated_weights / sum_w_updated + ) + (gamma / k) + torch.testing.assert_close( + node.get(default_scorer.score_key), expected_scores_after_update + ) + + def test_update_weights_zero_reward(self, default_scorer): + num_actions = 3 + action_idx = 1 + reward = 0.0 + weights = torch.tensor([1.0, 2.0, 1.5]) + node = create_node(num_actions=num_actions, weights=weights) + + default_scorer.forward(node) + initial_weights = node.get(default_scorer.weights_key).clone() + prob_i = node.get(default_scorer.score_key)[action_idx] + + default_scorer.update_weights(node, action_idx, reward) + updated_weights = node.get(default_scorer.weights_key) + gamma = default_scorer.gamma + k = num_actions + + expected_new_weight_val = initial_weights[action_idx] * math.exp( + (gamma / k) * (reward / prob_i) + ) + torch.testing.assert_close(updated_weights[action_idx], expected_new_weight_val) + torch.testing.assert_close( + updated_weights[action_idx], initial_weights[action_idx] + ) + + @pytest.mark.parametrize("batch_s", [2, 3]) + def test_update_weights_batch(self, default_scorer, batch_s): + num_actions = 3 + node = 
create_node(num_actions=num_actions, batch_size=batch_s) + default_scorer.forward(node) + + initial_weights_batch = node.get(default_scorer.weights_key).clone() + probs_batch = node.get(default_scorer.score_key).clone() + + rewards = torch.rand(batch_s) + action_indices = torch.randint(0, num_actions, (batch_s,)) + + expected_updated_weights_batch = initial_weights_batch.clone() + gamma = default_scorer.gamma + k = num_actions + + for i in range(batch_s): + action_idx = action_indices[i].item() + reward = rewards[i].item() + + single_node_td = node[i] + + current_weight_item = initial_weights_batch[i, action_idx] + prob_i_item = probs_batch[i, action_idx] + + exp_val = math.exp((gamma / k) * (reward / prob_i_item)) + expected_updated_weights_batch[i, action_idx] = ( + current_weight_item * exp_val + ) + + node_item_to_update = node[i : i + 1] + default_scorer.update_weights(node_item_to_update, action_idx, reward) + + torch.testing.assert_close( + node.get(default_scorer.weights_key), + expected_updated_weights_batch, + atol=1e-5, + rtol=1e-5, + ) + + def test_single_action(self, default_scorer): + num_actions = 1 + node = create_node(num_actions=num_actions) + default_scorer.forward(node) + + assert default_scorer.weights_key in node.keys() + torch.testing.assert_close( + node.get(default_scorer.weights_key), torch.ones(num_actions) + ) + torch.testing.assert_close( + node.get(default_scorer.score_key), torch.ones(num_actions) + ) # p_i = 1.0 + + action_idx = 0 + reward = 0.5 + initial_weights = node.get(default_scorer.weights_key).clone() + prob_i = node.get(default_scorer.score_key)[action_idx] + + default_scorer.update_weights(node, action_idx, reward) + updated_weights = node.get(default_scorer.weights_key) + gamma = default_scorer.gamma + k = num_actions + + expected_new_weight_val = initial_weights[action_idx] * math.exp( + (gamma / k) * (reward / prob_i) + ) + torch.testing.assert_close( + updated_weights[action_idx], torch.tensor(expected_new_weight_val) + ) + + @pytest.mark.parametrize( + "gamma_val, expected_behavior", [(0.0, "exploitation"), (1.0, "exploration")] + ) + def test_gamma_extremes(self, gamma_val, expected_behavior): + scorer = EXP3Score(gamma=gamma_val) + num_actions = 3 + weights = torch.tensor([1.0, 2.0, 7.0]) + node = create_node(num_actions=num_actions, weights=weights) + + scorer.forward(node) + scores = node.get(scorer.score_key) + + if expected_behavior == "exploitation": + expected_scores = weights / weights.sum() + torch.testing.assert_close(scores, expected_scores) + elif expected_behavior == "exploration": + expected_scores = torch.ones(num_actions) / num_actions + torch.testing.assert_close(scores, expected_scores) + + def test_custom_keys(self, custom_key_names): + gamma = 0.2 + scorer = EXP3Score( + gamma=gamma, + weights_key=custom_key_names["weights_key"], + score_key=custom_key_names["score_key"], + num_actions_key=custom_key_names["num_actions_key"], + action_prob_key=custom_key_names["action_prob_key"], + ) + num_actions = 2 + + node1 = create_node(num_actions=num_actions, custom_keys=custom_key_names) + scorer.forward(node1) + + assert custom_key_names["weights_key"] in node1.keys() + expected_weights1 = torch.ones(num_actions) + torch.testing.assert_close( + node1.get(custom_key_names["weights_key"]), expected_weights1 + ) + expected_scores1 = torch.ones(num_actions) / num_actions + torch.testing.assert_close( + node1.get(custom_key_names["score_key"]), expected_scores1 + ) + if ( + scorer.action_prob_key != scorer.score_key + ): # Check if 
action_prob_key was also populated + torch.testing.assert_close( + node1.get(custom_key_names["action_prob_key"]), expected_scores1 + ) + + weights2_val = torch.tensor([1.0, 3.0]) + node2 = create_node( + num_actions=num_actions, weights=weights2_val, custom_keys=custom_key_names + ) + scorer.forward(node2) + + sum_w2 = weights2_val.sum() + expected_scores2 = (1 - gamma) * (weights2_val / sum_w2) + (gamma / num_actions) + torch.testing.assert_close( + node2.get(custom_key_names["score_key"]), expected_scores2 + ) + + action_idx = 0 + reward = 1.0 + initial_weights2 = node2.get(custom_key_names["weights_key"]).clone() + prob_i2 = node2.get(custom_key_names["score_key"])[action_idx] + + scorer.update_weights(node2, action_idx, reward) + updated_weights2 = node2.get(custom_key_names["weights_key"]) + k = num_actions + + expected_new_weight_val2 = initial_weights2[action_idx] * math.exp( + (gamma / k) * (reward / prob_i2) + ) + torch.testing.assert_close( + updated_weights2[action_idx], torch.tensor(expected_new_weight_val2) + ) + + def test_forward_raises_error_on_mismatched_num_actions(self, default_scorer): + num_actions_prop = 3 + weights = torch.tensor([1.0, 2.0, 3.0, 4.0]) # K=4 from weights + node = create_node( + num_actions=num_actions_prop, weights=weights + ) # num_actions=3 + + with pytest.raises( + ValueError, + match="Shape of weights .* implies 4 actions, but num_actions is 3", + ): + default_scorer.forward(node) + + weights_ok = torch.tensor([1.0, 2.0, 3.0]) + node_ok = create_node( + num_actions=torch.tensor(4), weights=weights_ok + ) # num_actions=4 from tensor + + with pytest.raises( + ValueError, + match="Shape of weights .* implies 3 actions, but num_actions is 4", + ): + default_scorer.forward(node_ok) + + def test_update_weights_handles_prob_zero(self, default_scorer): + num_actions = 2 + action_idx = 0 + reward = 1.0 + scorer_exploit = EXP3Score(gamma=0.0) + weights = torch.tensor([0.0, 1.0]) + node = create_node(num_actions=num_actions, weights=weights) + + scorer_exploit.forward(node) # p_0 will be 0 + assert node.get(scorer_exploit.score_key)[0] == 0.0 + + with pytest.warns( + UserWarning, match="Probability p_i\\(t\\) for action 0 is 0.0" + ): + scorer_exploit.update_weights(node, action_idx, reward) + torch.testing.assert_close( + node.get(scorer_exploit.weights_key)[action_idx], torch.tensor(0.0) + ) + + def test_init_raises_error_gamma_out_of_range(self): + with pytest.raises(ValueError, match="gamma must be between 0 and 1"): + EXP3Score(gamma=-0.1) + with pytest.raises(ValueError, match="gamma must be between 0 and 1"): + EXP3Score(gamma=1.1) + + def test_update_weights_reward_warning(self, default_scorer): + num_actions = 2 + node = create_node(num_actions=num_actions) + default_scorer.forward(node) + with pytest.warns( + UserWarning, match="Reward .* is outside the expected \\[0,1\\] range" + ): + default_scorer.update_weights(node, 0, 1.5) + with pytest.warns( + UserWarning, match="Reward .* is outside the expected \\[0,1\\] range" + ): + default_scorer.update_weights(node, 0, -0.5) + initial_weight = node.get(default_scorer.weights_key)[0].clone() + default_scorer.update_weights(node, 0, 1.5) + assert node.get(default_scorer.weights_key)[0] != initial_weight # it changed + + +class TestUCBScore: + @pytest.fixture + def default_ucb_scorer(self): + return UCBScore(c=math.sqrt(2)) + + @pytest.fixture + def ucb_custom_key_names(self): + return { + "win_count_key": "custom_wins", + "visits_key": "custom_visits", + "total_visits_key": "custom_total_visits", + 
"score_key": "custom_ucb_score", + } + + @pytest.mark.parametrize("c_val", [0.5, 1.0, math.sqrt(2), 5.0]) + def test_initialization(self, c_val): + scorer = UCBScore(c=c_val) + assert scorer.c == c_val + + def test_forward_basic(self, default_ucb_scorer): + win_count = torch.tensor([10.0, 5.0, 20.0]) + visits = torch.tensor([15.0, 10.0, 25.0]) + total_visits_parent = torch.tensor(50.0) + + node = create_ucb_node( + win_count=win_count, visits=visits, total_visits=total_visits_parent + ) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + exploitation_term = win_count / visits + exploration_term = c * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close( + node.get(default_ucb_scorer.score_key), expected_scores + ) + + def test_forward_zero_visits(self, default_ucb_scorer): + win_count = torch.tensor([0.0, 0.0]) + visits = torch.tensor([10.0, 0.0]) + total_visits_parent = torch.tensor(10.0) + + node = create_ucb_node( + win_count=win_count, visits=visits, total_visits=total_visits_parent + ) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + scores = node.get(default_ucb_scorer.score_key) + + expected_score_0 = ( + win_count[0] / visits[0] + ) + c * total_visits_parent.sqrt() / (1 + visits[0]) + torch.testing.assert_close(scores[0], expected_score_0) + assert torch.isnan( + scores[1] + ), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." + + @pytest.mark.parametrize("batch_s", [2, 3]) + def test_forward_batch(self, default_ucb_scorer, batch_s): + win_count = torch.rand(batch_s, 2) * 10 + visits = torch.rand(batch_s, 2) * 5 + 1 + total_visits_parent = torch.rand(batch_s) * 20 + float(batch_s) + + node = create_ucb_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + batch_size=batch_s, + ) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + exploitation_term = win_count / visits + exploration_term = c * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close( + node.get(default_ucb_scorer.score_key), expected_scores + ) + + def test_forward_exploration_term(self, default_ucb_scorer): + win_count = torch.tensor([0.0, 0.0, 0.0]) + visits = torch.tensor([10.0, 5.0, 1.0]) + total_visits_parent = torch.tensor(100.0) + + node = create_ucb_node( + win_count=win_count, visits=visits, total_visits=total_visits_parent + ) + default_ucb_scorer.forward(node) + + c = default_ucb_scorer.c + expected_scores = c * total_visits_parent.sqrt() / (1 + visits) + + torch.testing.assert_close( + node.get(default_ucb_scorer.score_key), expected_scores + ) + + def test_custom_keys(self, ucb_custom_key_names): + c_val = 1.5 + scorer = UCBScore( + c=c_val, + win_count_key=ucb_custom_key_names["win_count_key"], + visits_key=ucb_custom_key_names["visits_key"], + total_visits_key=ucb_custom_key_names["total_visits_key"], + score_key=ucb_custom_key_names["score_key"], + ) + + win_count = torch.tensor([1.0, 2.0]) + visits = torch.tensor([3.0, 4.0]) + total_visits_parent = torch.tensor(10.0) + + node = create_ucb_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + custom_keys=ucb_custom_key_names, + ) + scorer.forward(node) + + exploitation = win_count / visits + exploration = c_val * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation + exploration + + assert ucb_custom_key_names["score_key"] in 
node.keys() + torch.testing.assert_close( + node.get(ucb_custom_key_names["score_key"]), expected_scores + ) + + assert "score" not in node.keys() + assert "win_count" not in node.keys() + assert "visits" not in node.keys() + assert "total_visits" not in node.keys() + + +class TestPUCTScore: + @pytest.fixture + def default_puct_scorer(self): + return PUCTScore(c=5.0) + + @pytest.fixture + def puct_custom_key_names(self): + return { + "win_count_key": "custom_puct_wins", + "visits_key": "custom_puct_visits", + "total_visits_key": "custom_puct_total_visits", + "prior_prob_key": "custom_puct_priors", + "score_key": "custom_puct_score", + } + + @pytest.mark.parametrize("c_val", [0.5, 1.0, 5.0, 10.0]) + def test_initialization(self, c_val): + scorer = PUCTScore(c=c_val) + assert scorer.c == c_val + + def test_forward_basic(self, default_puct_scorer): + win_count = torch.tensor([10.0, 5.0, 20.0]) + visits = torch.tensor([15.0, 10.0, 25.0]) + prior_prob = torch.tensor([0.4, 0.3, 0.3]) + total_visits_parent = torch.tensor(50.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + exploitation_term = win_count / visits + exploration_term = c * prior_prob * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close( + node.get(default_puct_scorer.score_key), expected_scores + ) + + def test_forward_zero_visits(self, default_puct_scorer): + win_count = torch.tensor([0.0, 0.0]) + visits = torch.tensor([10.0, 0.0]) + prior_prob = torch.tensor([0.6, 0.4]) + total_visits_parent = torch.tensor(10.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + scores = node.get(default_puct_scorer.score_key) + + expected_score_0 = (win_count[0] / visits[0]) + c * prior_prob[ + 0 + ] * total_visits_parent.sqrt() / (1 + visits[0]) + torch.testing.assert_close(scores[0], expected_score_0) + + assert torch.isnan( + scores[1] + ), "Score for unvisited action (0 visits, 0 wins) should be NaN due to 0/0, unless handled." 
+ + @pytest.mark.parametrize("batch_s", [2, 3]) + def test_forward_batch(self, default_puct_scorer, batch_s): + num_actions = 2 + win_count = torch.rand(batch_s, num_actions) * 10 + visits = torch.rand(batch_s, num_actions) * 5 + 1 + prior_prob = torch.rand(batch_s, num_actions) + prior_prob = prior_prob / prior_prob.sum(dim=-1, keepdim=True) + total_visits_parent = torch.rand(batch_s) * 20 + float(batch_s) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + batch_size=batch_s, + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + exploitation_term = win_count / visits + exploration_term = ( + c * prior_prob * total_visits_parent.unsqueeze(-1).sqrt() / (1 + visits) + ) + expected_scores = exploitation_term + exploration_term + + torch.testing.assert_close( + node.get(default_puct_scorer.score_key), + expected_scores, + atol=1e-6, + rtol=1e-6, + ) + + def test_forward_exploration_term(self, default_puct_scorer): + num_actions = 3 + win_count = torch.zeros(num_actions) + visits = torch.tensor([10.0, 5.0, 1.0]) + prior_prob = torch.tensor([0.3, 0.5, 0.2]) + total_visits_parent = torch.tensor(100.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + ) + default_puct_scorer.forward(node) + + c = default_puct_scorer.c + # exploitation_term is effectively 0 + expected_scores = c * prior_prob * total_visits_parent.sqrt() / (1 + visits) + + torch.testing.assert_close( + node.get(default_puct_scorer.score_key), expected_scores + ) + + def test_custom_keys(self, puct_custom_key_names): + c_val = 2.5 + scorer = PUCTScore( + c=c_val, + win_count_key=puct_custom_key_names["win_count_key"], + visits_key=puct_custom_key_names["visits_key"], + total_visits_key=puct_custom_key_names["total_visits_key"], + prior_prob_key=puct_custom_key_names["prior_prob_key"], + score_key=puct_custom_key_names["score_key"], + ) + + win_count = torch.tensor([1.0, 2.0]) + visits = torch.tensor([3.0, 4.0]) + prior_prob = torch.tensor([0.5, 0.5]) + total_visits_parent = torch.tensor(10.0) + + node = create_puct_node( + win_count=win_count, + visits=visits, + total_visits=total_visits_parent, + prior_prob=prior_prob, + custom_keys=puct_custom_key_names, + ) + scorer.forward(node) + + exploitation = win_count / visits + exploration = c_val * prior_prob * total_visits_parent.sqrt() / (1 + visits) + expected_scores = exploitation + exploration + + assert puct_custom_key_names["score_key"] in node.keys() + torch.testing.assert_close( + node.get(puct_custom_key_names["score_key"]), expected_scores + ) + + # Check that default keys are not present + assert "score" not in node.keys() + assert "win_count" not in node.keys() + assert "visits" not in node.keys() + assert "total_visits" not in node.keys() + assert "prior_prob" not in node.keys() diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 645f7704ddd..fd3f84913ee 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -163,9 +163,6 @@ def vertices( if h in memo and not use_path: continue memo.add(h) - r = tree.rollout - if r is not None: - r = r["next", "observation"] if use_path: result[cur_path] = tree elif use_id: diff --git a/torchrl/modules/mcts/__init__.py b/torchrl/modules/mcts/__init__.py new file mode 100644 index 00000000000..b225f4c0cca --- /dev/null +++ b/torchrl/modules/mcts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .scores import EXP3Score, PUCTScore, UCBScore
diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py
new file mode 100644
index 00000000000..2b9de222f62
--- /dev/null
+++ b/torchrl/modules/mcts/scores.py
@@ -0,0 +1,513 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
+import math
+import warnings
+from abc import abstractmethod
+from enum import Enum
+
+import torch
+
+from tensordict import NestedKey, TensorDictBase
+from tensordict.nn import TensorDictModuleBase
+
+
+class MCTSScore(TensorDictModuleBase):
+    @abstractmethod
+    def forward(self, node):
+        pass
+
+
+class PUCTScore(MCTSScore):
+    """Computes the PUCT (Polynomial Upper Confidence Trees) score for MCTS.
+
+    PUCT is a widely used score in MCTS algorithms, notably in AlphaGo and AlphaZero,
+    to balance exploration and exploitation. It incorporates prior probabilities from a
+    policy network, encouraging exploration of actions deemed promising by the policy,
+    while also considering visit counts and accumulated rewards.
+
+    The formula used is:
+    `score = (win_count / visits) + c * prior_prob * sqrt(total_visits) / (1 + visits)`
+
+    Where:
+    - `win_count`: Sum of rewards (or win counts) for the action.
+    - `visits`: Visit count for the action.
+    - `total_visits`: Visit count of the parent node (N).
+    - `prior_prob`: Prior probability of selecting the action (e.g., from a policy network).
+    - `c`: The exploration constant, controlling the trade-off between exploitation
+      (first term) and exploration (second term).
+
+    Args:
+        c (float): The exploration constant.
+        win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
+            containing the sum of rewards (or win counts) for each action.
+            Defaults to "win_count".
+        visits_key (NestedKey, optional): Key for the tensor containing the visit
+            count for each action. Defaults to "visits".
+        total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
+            representing the visit count of the parent node (N). Defaults to "total_visits".
+        prior_prob_key (NestedKey, optional): Key for the tensor containing the
+            prior probabilities for each action. Defaults to "prior_prob".
+        score_key (NestedKey, optional): Key where the calculated PUCT scores
+            will be stored in the output `TensorDictBase`. Defaults to "score".
+
+    Input Keys:
+        - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions)
+          or matching `visits_key`.
+        - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions). If an action
+          has zero visits, its exploitation term (win_count / visits) will result in NaN
+          if win_count is also zero, or +/-inf if win_count is non-zero. The exploration
+          term will still be valid due to `(1 + visits)`.
+        - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs,
+          representing the parent node's visit count.
+        - `prior_prob_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
+          prior probabilities.
+
+    Output Keys:
+        - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing
+          the calculated PUCT scores.
+ """ + + c: float + + def __init__( + self, + *, + c: float, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + prior_prob_key: NestedKey = "prior_prob", + score_key: NestedKey = "score", + ): + super().__init__() + self.c = c + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.prior_prob_key = prior_prob_key + self.score_key = score_key + self.in_keys = [ + self.win_count_key, + self.prior_prob_key, + self.total_visits_key, + self.visits_key, + ] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + win_count = node.get(self.win_count_key) + visits = node.get(self.visits_key) + n_total = node.get(self.total_visits_key) + prior_prob = node.get(self.prior_prob_key) + node.set( + self.score_key, + (win_count / visits) + self.c * prior_prob * n_total.sqrt() / (1 + visits), + ) + return node + + +class UCBScore(MCTSScore): + """Computes the UCB (Upper Confidence Bound) score, specifically UCB1, for MCTS. + + UCB1 is a classic algorithm for the multi-armed bandit problem that balances + exploration and exploitation. In MCTS, it's used to select which action to + explore from a given node. The score encourages trying actions with high + empirical rewards and actions that have been visited less frequently. + + The formula used is: + `score = (win_count / visits) + c * sqrt(log(total_visits) / visits)` + However, the implementation here uses `1 + visits` in the denominator of the + exploration term to handle cases where `visits` might be zero for an action, + preventing division by zero and ensuring unvisited actions get a high exploration score. + The formula implemented is: + `score = (win_count / visits) + c * sqrt(total_visits) / (1 + visits)` + Note: The standard UCB1 formula's exploration term is `c * sqrt(log(N) / N_i)`, + where N is parent visits and N_i is action visits. This implementation uses `sqrt(N)` + instead of `sqrt(log N)`. For the canonical UCB1 `sqrt(log N / N_i)` term, + total_visits would need to be `log(parent_visits)` and then use `c * sqrt(total_visits / visits_i)`. + The current form is simpler and common in some MCTS variants. + + Args: + c (float): The exploration constant. A common value is `sqrt(2)`. + win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase` + containing the sum of rewards (or win counts) for each action. + Defaults to "win_count". + visits_key (NestedKey, optional): Key for the tensor containing the visit + count for each action. Defaults to "visits". + total_visits_key (NestedKey, optional): Key for the tensor (or scalar) + representing the visit count of the parent node (N). This is used in the + exploration term. Defaults to "total_visits". + score_key (NestedKey, optional): Key where the calculated UCB scores + will be stored in the output `TensorDictBase`. Defaults to "score". + + Input Keys: + - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions). + - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions). If an action + has zero visits, its exploitation term (win_count / visits) will result in NaN + if win_count is also zero, or +/-inf if win_count is non-zero. The exploration + term remains well-defined due to `(1 + visits)`. + - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs, + representing the parent node's visit count (N). 
+ + Output Keys: + - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing + the calculated UCB scores. + """ + + c: float + + def __init__( + self, + *, + c: float, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + score_key: NestedKey = "score", + ): + super().__init__() + self.c = c + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.score_key = score_key + self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + win_count = node.get(self.win_count_key) + visits = node.get(self.visits_key) + n_total = node.get(self.total_visits_key) + node.set( + self.score_key, + (win_count / visits) + self.c * n_total.sqrt() / (1 + visits), + ) + return node + + +class EXP3Score(MCTSScore): + """Computes action selection probabilities for the EXP3 algorithm in MCTS. + + EXP3 (Exponential-weight algorithm for Exploration and Exploitation) is a bandit + algorithm that performs well in adversarial or non-stationary environments. + It maintains weights for each action and adjusts them based on received rewards. + Actions are chosen probabilistically based on these weights, with a mechanism + to ensure a minimum level of exploration. + + The `forward` method calculates the probability distribution over actions: + `p_i(t) = (1 - gamma) * (w_i(t) / sum_weights) + (gamma / K)` + where `w_i(t)` are the current weights, `sum_weights` is the sum of all weights, + `gamma` is the exploration factor, and `K` is the number of actions. + These probabilities are typically stored in the `score_key` and used for action selection. + + The `update_weights` method updates the weights after an action is chosen and a + reward is observed. This method is typically called after a simulation/rollout + and backpropagation phase in MCTS. The update rule is: + `w_i(t+1) = w_i(t) * exp((gamma / K) * (reward / p_i(t)))` + where `reward` is the reward for the chosen action (typically normalized to [0,1]) + and `p_i(t)` is the probability with which the action was chosen. + + Reference: "Bandit based Monte-Carlo Planning" (Kocsis & Szepesvari, 2006), though + the specific EXP3 formulation can vary (e.g., "Regret Analysis of Stochastic and + Nonstochastic Multi-armed Bandit Problems", Bubeck & Cesa-Bianchi, 2012 for EXP3 details). + + Args: + gamma (float, optional): Exploration factor, balancing uniform exploration + and exploitation of current weights. Must be in [0, 1]. Defaults to 0.1. + weights_key (NestedKey, optional): Key in the input `TensorDictBase` for + the tensor containing current action weights. If not found during the first + `forward` call, weights are initialized to ones. Defaults to "weights". + action_prob_key (NestedKey, optional): Key to store the calculated action + probabilities `p_i(t)`. If different from `score_key`, it allows storing + these probabilities separately, which might be useful if `score_key` is + used for a different purpose by the selection strategy. Defaults to "action_prob". + The `update_weights` method will look for `p_i(t)` in `score_key`. + score_key (NestedKey, optional): Key where the calculated action probabilities + (scores for MCTS selection) will be stored. Defaults to "score". + num_actions_key (NestedKey, optional): Key for the number of available + actions (K). 
Used for weight initialization and in formulas. Defaults to "num_actions".
+
+    Input Keys for `forward`:
+        - `weights_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
+          current weights. Initialized to ones if not present on first call.
+        - `num_actions_key` (int or torch.Tensor): Scalar representing K, the number of actions.
+
+    Output Keys for `forward`:
+        - `score_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
+          the calculated action probabilities `p_i(t)`.
+        - `action_prob_key` (torch.Tensor, optional): Same as `score_key` if this key
+          is set and different from `score_key`.
+
+    `update_weights` Method:
+        This method is designed to be called externally after an action has been
+        selected (using probabilities from `forward`) and a reward obtained.
+        Args for `update_weights(node: TensorDictBase, action_idx: int, reward: float)`:
+            - `node`: The `TensorDictBase` for the current MCTS node, containing
+              at least `weights_key` and `score_key` (with `p_i(t)` values).
+            - `action_idx`: The index of the action that was chosen.
+            - `reward`: The reward received for the chosen action (assumed to be in [0,1]).
+    """
+
+    def __init__(
+        self,
+        *,
+        gamma: float = 0.1,
+        weights_key: NestedKey = "weights",
+        action_prob_key: NestedKey = "action_prob",
+        reward_key: NestedKey = "reward",
+        score_key: NestedKey = "score",
+        num_actions_key: NestedKey = "num_actions",
+    ):
+        super().__init__()
+        if not 0 <= gamma <= 1:
+            raise ValueError(f"gamma must be between 0 and 1, got {gamma}")
+        self.gamma = gamma
+        self.weights_key = weights_key
+        self.action_prob_key = action_prob_key
+        self.reward_key = reward_key
+        self.score_key = score_key
+        self.num_actions_key = num_actions_key
+
+        self.in_keys = [self.weights_key, self.num_actions_key]
+        self.out_keys = [self.score_key]
+
+    def forward(self, node: TensorDictBase) -> TensorDictBase:
+        num_actions = node.get(self.num_actions_key)
+
+        if self.weights_key not in node.keys(include_nested=True):
+            batch_size = node.batch_size
+            if isinstance(num_actions, torch.Tensor) and num_actions.numel() == 1:
+                k = int(num_actions.item())
+            elif isinstance(num_actions, int):
+                k = num_actions
+            else:
+                raise ValueError(
+                    f"'{self.num_actions_key}' ('num_actions') must be an integer or a scalar tensor."
+                )
+            weights_shape = (*batch_size, k)
+            weights = torch.ones(weights_shape, device=node.device)
+            node.set(self.weights_key, weights)
+        else:
+            weights = node.get(self.weights_key)
+
+        k = weights.shape[-1]
+        if isinstance(num_actions, torch.Tensor) and num_actions.numel() == 1:
+            if k != num_actions.item():
+                raise ValueError(
+                    f"Shape of weights {weights.shape} implies {k} actions, "
+                    f"but num_actions is {num_actions.item()}."
+                )
+        elif isinstance(num_actions, int):
+            if k != num_actions:
+                raise ValueError(
+                    f"Shape of weights {weights.shape} implies {k} actions, "
+                    f"but num_actions is {num_actions}."
+                )
+
+        sum_weights = torch.sum(weights, dim=-1, keepdim=True)
+        sum_weights = torch.where(
+            sum_weights == 0, torch.ones_like(sum_weights), sum_weights
+        )
+
+        p_i = (1 - self.gamma) * (weights / sum_weights) + (self.gamma / k)
+        node.set(self.score_key, p_i)
+        if self.action_prob_key != self.score_key:
+            node.set(self.action_prob_key, p_i)
+        return node
+
+    def update_weights(
+        self, node: TensorDictBase, action_idx: int, reward: float
+    ) -> None:
+        """Updates the weight of the chosen action based on the reward.
+
+        w_i(t+1) = w_i(t) * exp((gamma / K) * (reward / p_i(t)))
+        Assumes reward is in [0, 1].
+ """ + if not (0 <= reward <= 1): + ValueError( + f"Reward {reward} is outside the expected [0, 1] range for EXP3." + ) + + weights = node.get(self.weights_key) + action_probs = node.get(self.score_key) + k = weights.shape[-1] + + if weights.ndim == 1: + current_weight = weights[action_idx] + prob_i = action_probs[action_idx] + elif weights.ndim > 1: + current_weight = weights[..., action_idx] + prob_i = action_probs[..., action_idx] + else: + raise ValueError(f"Invalid weights dimensions: {weights.ndim}") + + if torch.any(prob_i <= 0): + ValueError( + f"Probability p_i(t) for action {action_idx} is {prob_i}, which is <= 0." + "This might lead to issues in weight update." + ) + prob_i = torch.clamp(prob_i, min=1e-9) + + reward_tensor = torch.as_tensor( + reward, device=current_weight.device, dtype=current_weight.dtype + ) + exponent = (self.gamma / k) * (reward_tensor / prob_i) + new_weight = current_weight * torch.exp(exponent) + + if weights.ndim == 1: + weights[action_idx] = new_weight + else: + weights[..., action_idx] = new_weight + node.set(self.weights_key, weights) + + +class UCB1TunedScore(MCTSScore): + """Computes the UCB1-Tuned score for MCTS, using variance estimation. + + UCB1-Tuned is an enhancement of the UCB1 algorithm that incorporates an estimate + of the variance of rewards for each action. This allows for a more refined + balance between exploration and exploitation, potentially leading to better + performance, especially when reward variances differ significantly across actions. + + The score for an action `i` is calculated as: + `score_i = avg_reward_i + sqrt(log(N) / N_i * min(0.25, V_i))` + + The variance estimate `V_i` for action `i` is calculated as: + `V_i = (sum_squared_rewards_i / N_i) - avg_reward_i^2 + sqrt(exploration_constant * log(N) / N_i)` + + Where: + - `avg_reward_i`: Average reward obtained from action `i`. + - `N_i`: Number of times action `i` has been visited. + - `N`: Total number of times the parent node has been visited. + - `sum_squared_rewards_i`: Sum of the squares of rewards received from action `i`. + - `exploration_constant`: A constant used in the bias correction term of `V_i`. + Auer et al. (2002) suggest a value of 2.0 for rewards in the range [0,1]. + - The term `min(0.25, V_i)` implies that rewards are scaled to `[0,1]`, as 0.25 is + the maximum variance for a distribution in this range (e.g., Bernoulli(0.5)). + + Reference: "Finite-time Analysis of the Multiarmed Bandit Problem" + (Auer, Cesa-Bianchi, Fischer, 2002). + + Args: + exploration_constant (float, optional): The constant `C` used in the bias + correction term for the variance estimate `V_i`. Defaults to `2.0`, + as suggested for rewards in `[0,1]`. + win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase` + containing the sum of rewards for each action (Q_i * N_i). Defaults to "win_count". + visits_key (NestedKey, optional): Key for the tensor containing the visit + count for each action (N_i). Defaults to "visits". + total_visits_key (NestedKey, optional): Key for the tensor (or scalar) + representing the visit count of the parent node (N). Defaults to "total_visits". + sum_squared_rewards_key (NestedKey, optional): Key for the tensor containing + the sum of squared rewards received for each action. This is crucial for + calculating the empirical variance. Defaults to "sum_squared_rewards". + score_key (NestedKey, optional): Key where the calculated UCB1-Tuned scores + will be stored in the output `TensorDictBase`. Defaults to "score". 
+
+    Input Keys:
+        - `win_count_key` (torch.Tensor): Sum of rewards for each action.
+        - `visits_key` (torch.Tensor): Visit counts for each action (N_i).
+        - `total_visits_key` (torch.Tensor): Parent node's visit count (N).
+        - `sum_squared_rewards_key` (torch.Tensor): Sum of squared rewards for each action.
+
+    Output Keys:
+        - `score_key` (torch.Tensor): Calculated UCB1-Tuned scores for each action.
+
+    Important Notes:
+        - **Unvisited Nodes**: Actions with zero visits (`visits_key` is 0) are assigned a
+          very large positive score to ensure they are selected for exploration.
+        - **Reward Range**: The `min(0.25, V_i)` term is theoretically most sound when
+          rewards are normalized to the range `[0, 1]`.
+        - **Logarithm of N**: `log(N)` (log of parent visits) is calculated using `torch.log(torch.clamp(N, min=1.0))`
+          to prevent issues with `N=0` or `N` between 0 and 1.
+    """
+
+    def __init__(
+        self,
+        *,
+        win_count_key: NestedKey = "win_count",
+        visits_key: NestedKey = "visits",
+        total_visits_key: NestedKey = "total_visits",
+        sum_squared_rewards_key: NestedKey = "sum_squared_rewards",
+        score_key: NestedKey = "score",
+        exploration_constant: float = 2.0,
+    ):
+        super().__init__()
+        self.win_count_key = win_count_key
+        self.visits_key = visits_key
+        self.total_visits_key = total_visits_key
+        self.sum_squared_rewards_key = sum_squared_rewards_key
+        self.score_key = score_key
+        self.exploration_constant = exploration_constant
+
+        self.in_keys = [
+            self.win_count_key,
+            self.visits_key,
+            self.total_visits_key,
+            self.sum_squared_rewards_key,
+        ]
+        self.out_keys = [self.score_key]
+
+    def forward(self, node: TensorDictBase) -> TensorDictBase:
+        q_sum_i = node.get(self.win_count_key)
+        n_i = node.get(self.visits_key)
+        n_parent = node.get(self.total_visits_key)
+        sum_sq_rewards_i = node.get(self.sum_squared_rewards_key)
+
+        if n_parent.ndim > 0 and n_parent.ndim < q_sum_i.ndim:
+            n_parent_expanded = n_parent.unsqueeze(-1)
+        else:
+            n_parent_expanded = n_parent
+
+        safe_n_parent_for_log = torch.clamp(n_parent_expanded, min=1.0)
+        log_n_parent = torch.log(safe_n_parent_for_log)
+
+        scores = torch.zeros_like(q_sum_i, device=q_sum_i.device)
+
+        visited_mask = n_i > 0
+
+        if torch.any(visited_mask):
+            q_sum_i_v = q_sum_i[visited_mask]
+            n_i_v = n_i[visited_mask]
+            sum_sq_rewards_i_v = sum_sq_rewards_i[visited_mask]
+
+            log_n_parent_v = log_n_parent.expand_as(n_i)[visited_mask]
+
+            avg_reward_i_v = q_sum_i_v / n_i_v
+
+            empirical_variance_v = (sum_sq_rewards_i_v / n_i_v) - avg_reward_i_v.pow(2)
+            bias_correction_v = (
+                self.exploration_constant * log_n_parent_v / n_i_v
+            ).sqrt()
+
+            v_i_v = empirical_variance_v + bias_correction_v
+            v_i_v = v_i_v.clamp(min=0)
+
+            min_variance_term_v = torch.min(torch.full_like(v_i_v, 0.25), v_i_v)
+            exploration_component_v = (
+                log_n_parent_v / n_i_v * min_variance_term_v
+            ).sqrt()
+
+            scores[visited_mask] = avg_reward_i_v + exploration_component_v
+
+        unvisited_mask = ~visited_mask
+        if torch.any(unvisited_mask):
+            scores[unvisited_mask] = torch.finfo(scores.dtype).max / 10.0
+
+        node.set(self.score_key, scores)
+        return node
+
+
+class MCTSScores(Enum):
+    PUCT = functools.partial(PUCTScore, c=5)  # AlphaGo default value
+    UCB = functools.partial(UCBScore, c=math.sqrt(2))  # default from Auer et al. 2002
+    UCB1_TUNED = functools.partial(
+        UCB1TunedScore, exploration_constant=2.0
+    )  # Auer et al. (2002) C=2 for rewards in [0,1]
+    EXP3 = functools.partial(EXP3Score, gamma=0.1)
+    PUCT_VARIANT = "PUCT-Variant"
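For reviewers trying the patch out, here is a minimal sketch of how the UCB and PUCT scorers are driven with a TensorDict node, mirroring the new tests; the per-action statistics below are made up purely for illustration.

import math

import torch
from tensordict import TensorDict
from torchrl.modules.mcts.scores import PUCTScore, UCBScore

# Per-action statistics for a node with three children (illustrative values).
node = TensorDict(
    {
        "win_count": torch.tensor([10.0, 5.0, 20.0]),
        "visits": torch.tensor([15.0, 10.0, 25.0]),
        "total_visits": torch.tensor(50.0),
        "prior_prob": torch.tensor([0.4, 0.3, 0.3]),
    },
    batch_size=[],
)

ucb = UCBScore(c=math.sqrt(2))
ucb.forward(node)  # writes the per-action scores under node["score"]
print(node["score"])

puct = PUCTScore(c=5.0)
puct.forward(node)  # overwrites node["score"] with the PUCT values
best_action = node["score"].argmax().item()  # selection picks the highest score

The MCTSScores enum wraps the same constructors in functools.partial, so MCTSScores.UCB.value() builds an equivalent scorer with the default exploration constant.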
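EXP3Score is stateful (it keeps per-action weights on the node), so its intended cycle is forward, sample, update_weights, then forward again. Below is a small sketch of that loop with a toy Bernoulli bandit standing in for real MCTS rollouts; the true_rewards values are invented for the example.

import torch
from tensordict import TensorDict
from torchrl.modules.mcts.scores import EXP3Score

scorer = EXP3Score(gamma=0.1)
node = TensorDict({"num_actions": torch.tensor(3)}, batch_size=[])

# Stand-in for the value of each action; in MCTS this would come from rollouts.
true_rewards = torch.tensor([0.2, 0.8, 0.5])

for _ in range(100):
    scorer.forward(node)  # initializes weights on the first call, then refreshes p_i(t)
    probs = node["score"]
    action = torch.multinomial(probs, 1).item()
    reward = torch.bernoulli(true_rewards[action]).item()  # reward stays in [0, 1]
    scorer.update_weights(node, action, reward)

print(node["weights"])  # the best arm should have accumulated the largest weight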
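UCB1TunedScore additionally needs a running sum of squared rewards so it can estimate per-action variance; it is not re-exported from torchrl/modules/mcts/__init__.py, so this sketch imports it from the scores module directly. The statistics are again illustrative only.

import torch
from tensordict import TensorDict
from torchrl.modules.mcts.scores import UCB1TunedScore

node = TensorDict(
    {
        # Action 0: 10 visits, rewards summing to 6.0 with squared sum 4.2.
        # Action 1: never visited; it receives a very large score by construction.
        "win_count": torch.tensor([6.0, 0.0]),
        "visits": torch.tensor([10.0, 0.0]),
        "total_visits": torch.tensor(10.0),
        "sum_squared_rewards": torch.tensor([4.2, 0.0]),
    },
    batch_size=[],
)

scorer = UCB1TunedScore()  # exploration_constant defaults to 2.0
scorer.forward(node)
print(node["score"])  # finite UCB1-Tuned value, then a huge placeholder for the unvisited arm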