Skip to content

[Feature] Added EXP3 Scoring function in continuation with pr #2358 #3013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions torchrl/data/map/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,6 @@ def vertices(
if h in memo and not use_path:
continue
memo.add(h)
r = tree.rollout
if r is not None:
r = r["next", "observation"]
if use_path:
result[cur_path] = tree
elif use_id:
Expand Down
5 changes: 5 additions & 0 deletions torchrl/modules/mcts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .scores import PUCTScore, UCBScore
211 changes: 211 additions & 0 deletions torchrl/modules/mcts/scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import math
from abc import abstractmethod
from enum import Enum

import torch

from tensordict import NestedKey, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from torch import nn


class MCTSScore(TensorDictModuleBase):
    """Base class for modules that score a node of a Monte-Carlo tree search.

    Concrete scores subclass this and implement :meth:`forward`.
    """

    @abstractmethod
    def forward(self, node):
        """Compute the score for ``node``; to be implemented by subclasses."""


class PUCTScore(MCTSScore):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add some docstrings there!
Should include an example

c: float

def __init__(
    self,
    *,
    c: float,
    win_count_key: NestedKey = "win_count",
    visits_key: NestedKey = "visits",
    total_visits_key: NestedKey = "total_visits",
    prior_prob_key: NestedKey = "prior_prob",
    score_key: NestedKey = "score",
):
    """Configure the PUCT score module.

    Args:
        c: constant stored on the module and used by ``forward`` as the
            multiplier of the exploration term.
        win_count_key: key under which the node's win count is read.
        visits_key: key under which the node's visit count is read.
        total_visits_key: key under which the total visit count is read.
        prior_prob_key: key under which the prior probability is read.
        score_key: key under which the computed score is written.
    """
    super().__init__()
    self.c = c

    # Remember where each quantity lives in the node.
    self.score_key = score_key
    self.prior_prob_key = prior_prob_key
    self.total_visits_key = total_visits_key
    self.visits_key = visits_key
    self.win_count_key = win_count_key

    # Advertise the module's inputs and outputs.
    self.in_keys = [win_count_key, prior_prob_key, total_visits_key, visits_key]
    self.out_keys = [score_key]

def forward(self, node: TensorDictBase) -> TensorDictBase:
    """Compute the PUCT score of ``node`` and store it under ``score_key``.

    The score is::

        win_count / visits + c * prior_prob * sqrt(total_visits) / (1 + visits)

    Entries with zero visits use ``0`` for the ``win_count / visits`` term
    so that unvisited entries get a finite score instead of NaN/inf.

    Args:
        node: container providing the win-count, visits, total-visits and
            prior-probability entries under the configured keys.

    Returns:
        The same ``node``, with the score written under ``score_key``.
    """
    win_count = node.get(self.win_count_key)
    visits = node.get(self.visits_key)
    n_total = node.get(self.total_visits_key)
    prior_prob = node.get(self.prior_prob_key)

    # 0 / 0 would be NaN (and x / 0 would be inf); replace those entries by 0.
    mean_value = win_count / visits
    mean_value = torch.where(
        visits == 0, torch.zeros_like(mean_value), mean_value
    )

    exploration = self.c * prior_prob * n_total.sqrt() / (1 + visits)
    node.set(self.score_key, mean_value + exploration)
    return node


class UCBScore(MCTSScore):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

c: float

def __init__(
    self,
    *,
    c: float,
    win_count_key: NestedKey = "win_count",
    visits_key: NestedKey = "visits",
    total_visits_key: NestedKey = "total_visits",
    score_key: NestedKey = "score",
):
    """Configure the UCB score module.

    Args:
        c: constant stored on the module and used by ``forward`` as the
            multiplier of the exploration term.
        win_count_key: key under which the node's win count is read.
        visits_key: key under which the node's visit count is read.
        total_visits_key: key under which the total visit count is read.
        score_key: key under which the computed score is written.
    """
    super().__init__()
    self.c = c

    # Remember where each quantity lives in the node.
    self.score_key = score_key
    self.total_visits_key = total_visits_key
    self.visits_key = visits_key
    self.win_count_key = win_count_key

    # Advertise the module's inputs and outputs.
    self.in_keys = [win_count_key, total_visits_key, visits_key]
    self.out_keys = [score_key]

def forward(self, node: TensorDictBase) -> TensorDictBase:
    """Compute the UCB score of ``node`` and store it under ``score_key``.

    The score is::

        win_count / visits + c * sqrt(total_visits) / (1 + visits)

    Entries with zero visits use ``0`` for the ``win_count / visits`` term
    so that unvisited entries get a finite score instead of NaN/inf.

    Args:
        node: container providing the win-count, visits and total-visits
            entries under the configured keys.

    Returns:
        The same ``node``, with the score written under ``score_key``.
    """
    win_count = node.get(self.win_count_key)
    visits = node.get(self.visits_key)
    n_total = node.get(self.total_visits_key)

    # 0 / 0 would be NaN (and x / 0 would be inf); replace those entries by 0.
    mean_value = win_count / visits
    mean_value = torch.where(
        visits == 0, torch.zeros_like(mean_value), mean_value
    )

    exploration = self.c * n_total.sqrt() / (1 + visits)
    node.set(self.score_key, mean_value + exploration)
    return node


class EXP3Score(MCTSScore):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

def __init__(
    self,
    *,
    gamma: float = 0.1,
    weights_key: NestedKey = "weights",
    action_prob_key: NestedKey = "action_prob",
    reward_key: NestedKey = "reward",
    score_key: NestedKey = "score",
    num_actions_key: NestedKey = "num_actions",
):
    """Configure the EXP3 score module.

    Args:
        gamma: mixing coefficient, required to lie in ``[0, 1]``.
        weights_key: key of the per-action weights in the node.
        action_prob_key: key under which the action probabilities are
            written in addition to ``score_key``.
        reward_key: key name stored on the module.
        score_key: key under which the computed probabilities are written.
        num_actions_key: key of the number of actions in the node.

    Raises:
        ValueError: if ``gamma`` is not within ``[0, 1]``.
    """
    super().__init__()

    gamma_in_range = 0 <= gamma <= 1
    if not gamma_in_range:
        raise ValueError(f"gamma must be between 0 and 1, got {gamma}")
    self.gamma = gamma

    # Remember where each quantity lives in the node.
    self.num_actions_key = num_actions_key
    self.score_key = score_key
    self.reward_key = reward_key
    self.action_prob_key = action_prob_key
    self.weights_key = weights_key

    # Advertise the module's inputs and outputs.
    self.in_keys = [weights_key, num_actions_key]
    self.out_keys = [score_key]

def forward(self, node: TensorDictBase) -> TensorDictBase:
    """Compute the EXP3 action probabilities and store them in ``node``.

    With ``k`` the size of the last dimension of the weights, the
    probabilities are::

        p_i = (1 - gamma) * w_i / sum_j(w_j) + gamma / k

    If ``node`` has no weights entry yet, weights of ones with shape
    ``(*node.batch_size, num_actions)`` are created and stored first.

    Args:
        node: container providing the number of actions and, optionally,
            the current weights under the configured keys.

    Returns:
        The same ``node``, with the probabilities written under
        ``score_key`` (and under ``action_prob_key`` when it differs).

    Raises:
        ValueError: if weights must be created but the number of actions
            is neither an ``int`` nor a one-element tensor, or if the last
            dimension of the weights disagrees with the number of actions.
    """
    num_actions = node.get(self.num_actions_key)

    # Reduce num_actions to a plain Python number when it is an int or a
    # one-element tensor; None means "cannot be interpreted as a scalar".
    if isinstance(num_actions, torch.Tensor) and num_actions.numel() == 1:
        declared_k = num_actions.item()
    elif isinstance(num_actions, int):
        declared_k = num_actions
    else:
        declared_k = None

    if self.weights_key not in node.keys(include_nested=True):
        if declared_k is None:
            raise ValueError(
                f"'{self.num_actions_key}' ('num_actions') must be an integer or a scalar tensor."
            )
        # First call on this node: start from uniform (all-ones) weights.
        weights = torch.ones(
            (*node.batch_size, int(declared_k)), device=node.device
        )
        node.set(self.weights_key, weights)
    else:
        weights = node.get(self.weights_key)

    k = weights.shape[-1]
    if declared_k is not None and k != declared_k:
        raise ValueError(
            f"Shape of weights {weights.shape} implies {k} actions, "
            f"but num_actions is {declared_k}."
        )

    sum_weights = torch.sum(weights, dim=-1, keepdim=True)
    # Guard the normalisation against an all-zero weight vector.
    sum_weights = torch.where(
        sum_weights == 0, torch.ones_like(sum_weights), sum_weights
    )

    p_i = (1 - self.gamma) * (weights / sum_weights) + (self.gamma / k)
    node.set(self.score_key, p_i)
    if self.action_prob_key != self.score_key:
        node.set(self.action_prob_key, p_i)
    return node

def update_weights(
    self, node: TensorDictBase, action_idx: int, reward: float
) -> None:
    """Update the weight of one action from an observed reward.

    The weight at ``action_idx`` (along the last dimension) becomes::

        w_i * exp((gamma / k) * reward / p_i)

    where ``k`` is the size of the last dimension of the weights and
    ``p_i`` is the probability read from ``score_key``; both the weights
    and the ``score_key`` entry must therefore already be present in
    ``node``. The stored weights tensor is modified in place and written
    back under ``weights_key``.

    Args:
        node: container holding the weights and the probabilities.
        action_idx: index of the action along the last dimension.
        reward: observed reward, expected to lie in ``[0, 1]``.

    Raises:
        ValueError: if ``reward`` is outside ``[0, 1]`` or if the weights
            tensor has no dimension to index.
    """
    if not (0 <= reward <= 1):
        raise ValueError(
            f"Reward {reward} is outside the expected [0, 1] range for EXP3."
        )

    weights = node.get(self.weights_key)
    action_probs = node.get(self.score_key)

    # Validate the rank before touching shape[-1], which would otherwise
    # fail with an IndexError on a 0-d tensor.
    if weights.ndim < 1:
        raise ValueError(f"Invalid weights dimensions: {weights.ndim}")
    k = weights.shape[-1]

    # Ellipsis indexing covers both the 1-D and the batched case.
    current_weight = weights[..., action_idx]
    prob_i = action_probs[..., action_idx]

    if torch.any(prob_i <= 0):
        import warnings

        warnings.warn(
            f"Probability p_i(t) for action {action_idx} is {prob_i}, which is <= 0. "
            "This might lead to issues in weight update."
        )
        # Keep the division below finite.
        prob_i = torch.clamp(prob_i, min=1e-9)

    reward_tensor = torch.as_tensor(
        reward, device=current_weight.device, dtype=current_weight.dtype
    )
    exponent = (self.gamma / k) * (reward_tensor / prob_i)
    new_weight = current_weight * torch.exp(exponent)

    weights[..., action_idx] = new_weight
    node.set(self.weights_key, weights)


class MCTSScores(Enum):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vmoens ?? Any changes needed here ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry but I didn't get that 😅

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add docstrings to each of these classes & public methods :)

# Members of the MCTSScores enum.
# PUCT and UCB hold a functools.partial that pre-binds the constant ``c`` of
# the corresponding score class; calling the partial builds the score module.
# NOTE(review): newer Python versions warn about functools.partial objects
# used as Enum member values (enum.member() is the suggested wrapper) - confirm
# against the Python versions this project supports.
PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value
UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002
# The remaining members carry plain string values rather than constructors.
# NOTE(review): EXP3 is a string here although an EXP3Score class is defined
# in this file - presumably a placeholder; confirm the intended value.
UCB1_TUNED = "UCB1-Tuned"
EXP3 = "EXP3"
PUCT_VARIANT = "PUCT-Variant"
Loading