
Commit 34aea90

Add MCCFR (#90)
* add outcome_sampling_mccfr
* add esmccfr
* update README.md
1 parent 7829fa5 commit 34aea90

File tree

5 files changed: +198 −2 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ This project aims to provide some implementations of the most typical reinforcement
 - DDPG
 - TD3
 - SAC
-- CFR
+- CFR/OS-MCCFR/ES-MCCFR
 - Minimax
 
 If you are looking for tabular reinforcement learning algorithms, you may refer [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).

src/algorithms/cfr/cfr.jl

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 include("tabular_cfr.jl")
+include("outcome_sampling_mccfr.jl")
+include("external_sampling_mccfr.jl")

src/algorithms/cfr/external_sampling_mccfr.jl

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
export ExternalSamplingMCCFRPolicy

using Random
using StatsBase: sample, Weights

"""
    ExternalSamplingMCCFRPolicy

This implementation uses stochastically-weighted averaging.

Ref:

- [MONTE CARLO SAMPLING AND REGRET MINIMIZATION FOR EQUILIBRIUM COMPUTATION AND DECISION-MAKING IN LARGE EXTENSIVE FORM GAMES](http://mlanctot.info/files/papers/PhD_Thesis_MarcLanctot.pdf)
- [Monte Carlo Sampling for Regret Minimization in Extensive Games](https://papers.nips.cc/paper/3713-monte-carlo-sampling-for-regret-minimization-in-extensive-games.pdf)
"""
struct ExternalSamplingMCCFRPolicy{S,T,R<:AbstractRNG} <: AbstractPolicy
    nodes::Dict{S,InfoStateNode}
    behavior_policy::QBasedPolicy{TabularLearner{S,T},WeightedExplorer{true,R}}
end

(p::ExternalSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)

RLBase.get_prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_policy, env)

function ExternalSamplingMCCFRPolicy(;
    env::AbstractEnv,
    n_iter::Int,
    rng = Random.GLOBAL_RNG,
)
    @assert NumAgentStyle(env) isa MultiAgent
    @assert DynamicStyle(env) === SEQUENTIAL
    @assert RewardStyle(env) === TERMINAL_REWARD
    @assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
    @assert DefaultStateStyle(env) === Information{String}()

    nodes = init_info_state_nodes(env)

    for i in 1:n_iter
        for p in get_players(env)
            if p != get_chance_player(env)
                external_sampling(copy(env), p, nodes, rng)
            end
        end
    end

    behavior_policy = QBasedPolicy(;
        learner = TabularLearner{String}(),
        explorer = WeightedExplorer(; is_normalized = true, rng = rng),
    )

    for (k, v) in nodes
        s = sum(v.cumulative_strategy)
        if s != 0
            update!(behavior_policy, k => v.cumulative_strategy ./ s)
        end
    end

    ExternalSamplingMCCFRPolicy(nodes, behavior_policy)
end

function external_sampling(env, i, nodes, rng)
    current_player = get_current_player(env)

    if get_terminal(env)
        get_reward(env, i)
    elseif current_player == get_chance_player(env)
        env(rand(rng, get_actions(env)))
        external_sampling(env, i, nodes, rng)
    else
        I = get_state(env)
        node = nodes[I]
        regret_matching!(node)
        σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy
        n = length(node.strategy)

        if i == current_player
            u = zeros(n)
            ū = 0  # expected value of the current strategy at this infostate
            for (aᵢ, a) in enumerate(get_legal_actions(env))
                u[aᵢ] = external_sampling(child(env, a), i, nodes, rng)
                ū += σ[aᵢ] * u[aᵢ]
            end
            rI .+= u .- ū  # accumulate counterfactual regrets
            ū
        else
            a′ = sample(rng, Weights(σ, 1.0))
            env(get_legal_actions(env)[a′])
            u = external_sampling(env, i, nodes, rng)
            sI .+= σ  # stochastically-weighted average strategy update
            u
        end
    end
end
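
Not part of the commit, but for orientation, a minimal usage sketch of the new policy. It assumes the surrounding package (and RLBase) are loaded and that an environment satisfying the asserted traits is available, e.g. KuhnPokerEnv from ReinforcementLearningEnvironments; the environment name and iteration count here are illustrative, not taken from this diff.

using Random
using ReinforcementLearningEnvironments: KuhnPokerEnv  # assumed environment, not part of this commit

env = KuhnPokerEnv()
rng = MersenneTwister(123)

# Run external-sampling MCCFR for a fixed number of iterations,
# then query the learned average strategy at the current information state.
p = ExternalSamplingMCCFRPolicy(; env = env, n_iter = 10_000, rng = rng)
get_prob(p, env)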

src/algorithms/cfr/outcome_sampling_mccfr.jl

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
export OutcomeSamplingMCCFRPolicy

using Random
using StatsBase: sample, Weights

"""
    OutcomeSamplingMCCFRPolicy

This implementation uses stochastically-weighted averaging.

Ref:

- [MONTE CARLO SAMPLING AND REGRET MINIMIZATION FOR EQUILIBRIUM COMPUTATION AND DECISION-MAKING IN LARGE EXTENSIVE FORM GAMES](http://mlanctot.info/files/papers/PhD_Thesis_MarcLanctot.pdf)
- [Monte Carlo Sampling for Regret Minimization in Extensive Games](https://papers.nips.cc/paper/3713-monte-carlo-sampling-for-regret-minimization-in-extensive-games.pdf)
"""
struct OutcomeSamplingMCCFRPolicy{S,T,R<:AbstractRNG} <: AbstractPolicy
    nodes::Dict{S,InfoStateNode}
    behavior_policy::QBasedPolicy{TabularLearner{S,T},WeightedExplorer{true,R}}
end

(p::OutcomeSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)

RLBase.get_prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_policy, env)

function OutcomeSamplingMCCFRPolicy(;
    env::AbstractEnv,
    n_iter::Int,
    rng = Random.GLOBAL_RNG,
    ϵ = 0.6,
)
    @assert NumAgentStyle(env) isa MultiAgent
    @assert DynamicStyle(env) === SEQUENTIAL
    @assert RewardStyle(env) === TERMINAL_REWARD
    @assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
    @assert DefaultStateStyle(env) === Information{String}()

    nodes = init_info_state_nodes(env)

    for i in 1:n_iter
        for p in get_players(env)
            if p != get_chance_player(env)
                outcome_sampling(copy(env), p, nodes, ϵ, 1.0, 1.0, 1.0, rng)
            end
        end
    end

    behavior_policy = QBasedPolicy(;
        learner = TabularLearner{String}(),
        explorer = WeightedExplorer(; is_normalized = true, rng = rng),
    )

    for (k, v) in nodes
        s = sum(v.cumulative_strategy)
        if s != 0
            update!(behavior_policy, k => v.cumulative_strategy ./ s)
        end
    end

    OutcomeSamplingMCCFRPolicy(nodes, behavior_policy)
end

function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
    current_player = get_current_player(env)

    if get_terminal(env)
        get_reward(env, i) / s, 1.0
    elseif current_player == get_chance_player(env)
        env(rand(rng, get_actions(env)))
        outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
    else
        I = get_state(env)
        node = nodes[I]
        regret_matching!(node)
        σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy
        n = length(node.strategy)

        if i == current_player
            # ϵ-greedy sampling: explore uniformly with probability ϵ
            aᵢ = rand(rng) >= ϵ ? sample(rng, Weights(σ, 1.0)) : rand(rng, 1:n)
            pᵢ = σ[aᵢ] * (1 - ϵ) + ϵ / n
            πᵢ′, π₋ᵢ′, s′ = πᵢ * pᵢ, π₋ᵢ, s * pᵢ
        else
            aᵢ = sample(rng, Weights(σ, 1.0))
            pᵢ = σ[aᵢ]
            πᵢ′, π₋ᵢ′, s′ = πᵢ, π₋ᵢ * pᵢ, s * pᵢ
        end

        env(get_legal_actions(env)[aᵢ])
        u, πₜₐᵢₗ = outcome_sampling(env, i, nodes, ϵ, πᵢ′, π₋ᵢ′, s′, rng)

        if i == current_player
            w = u * π₋ᵢ
            rI .+= w * πₜₐᵢₗ .* ((1:n .== aᵢ) .- σ[aᵢ])
        else
            sI .+= π₋ᵢ / s .* σ
        end

        u, πₜₐᵢₗ * σ[aᵢ]
    end
end
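
Again a hedged usage sketch, not taken from the diff: outcome sampling updates a single sampled trajectory per iteration, so it typically needs more iterations than external sampling, and ϵ controls how often the traversing player explores uniformly instead of following its current strategy. KuhnPokerEnv and the numbers below are illustrative assumptions.

using Random
using ReinforcementLearningEnvironments: KuhnPokerEnv  # assumed environment, not part of this commit

env = KuhnPokerEnv()

# One sampled trajectory per iteration; ϵ = 0.6 matches the constructor default.
p = OutcomeSamplingMCCFRPolicy(; env = env, n_iter = 100_000, rng = MersenneTwister(123), ϵ = 0.6)
get_prob(p, env)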

src/algorithms/cfr/tabular_cfr.jl

Lines changed: 3 additions & 1 deletion
@@ -144,6 +144,8 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
     end
 end
 
+regret_matching!(node::InfoStateNode) = regret_matching!(node.strategy, node.cumulative_regret)
+
 function regret_matching!(strategy, cumulative_regret)
     s = mapreduce(x -> max(0, x), +, cumulative_regret)
     if s > 0
@@ -155,6 +157,6 @@ end
 
 function update_strategy!(nodes)
     for node in values(nodes)
-        regret_matching!(node.strategy, node.cumulative_regret)
+        regret_matching!(node)
     end
 end
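
For reference, regret_matching! normalizes the positive part of the cumulative regrets into the current strategy. Below is a standalone sketch of that computation with made-up numbers; it is independent of InfoStateNode, and the uniform fallback shown is the standard regret-matching convention (the else branch of the repo's implementation is not visible in this hunk).

# Hypothetical cumulative regrets for three actions at one information state.
cumulative_regret = [2.0, -1.0, 3.0]

positive = max.(0, cumulative_regret)   # [2.0, 0.0, 3.0]
s = sum(positive)                       # 5.0

# Positive regrets normalized into a strategy; fall back to uniform when none are positive.
strategy = s > 0 ? positive ./ s : fill(1 / length(cumulative_regret), length(cumulative_regret))
# strategy == [0.4, 0.0, 0.6]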
