This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit d21b82d: Format .jl files (#84)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent: 28fcb99

6 files changed: +91 -51 lines
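The commit title and the github-actions[bot] co-author indicate an automated formatting pass over the Julia sources. As a rough sketch only (the tool is an assumption; nothing in this commit names it), such a pass is typically produced with JuliaFormatter.jl:

using JuliaFormatter   # assumed formatter; not confirmed by this commit
format("src")          # rewrite .jl files under src/ in place
format("test")         # likewise for the test suite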


src/algorithms/cfr/cfr.jl (1 addition, 1 deletion)

@@ -1 +1 @@
-include("tabular_cfr.jl")
+include("tabular_cfr.jl")

(the removed and added lines render identically; the change is not visible in the extracted text)
src/algorithms/cfr/tabular_cfr.jl (50 additions, 19 deletions)

@@ -6,10 +6,10 @@ struct InfoStateNode
     cumulative_strategy::Vector{Float64}
 end
 
-InfoStateNode(n) = InfoStateNode(fill(1/n,n), zeros(n), zeros(n))
+InfoStateNode(n) = InfoStateNode(fill(1 / n, n), zeros(n), zeros(n))
 
 function init_info_state_nodes(env::AbstractEnv)
-    nodes = Dict{String, InfoStateNode}()
+    nodes = Dict{String,InfoStateNode}()
     walk(env) do x
         if !get_terminal(x) && get_current_player(x) != get_chance_player(x)
             get!(nodes, get_state(x), InfoStateNode(length(get_legal_actions(x))))
@@ -24,8 +24,8 @@ end
 See more details: [An Introduction to Counterfactual Regret Minimization](http://modelai.gettysburg.edu/2013/cfr/cfr.pdf)
 """
 struct TabularCFRPolicy{S,T,R<:AbstractRNG} <: AbstractPolicy
-    nodes::Dict{S, InfoStateNode}
-    behavior_policy::QBasedPolicy{TabularLearner{S,T}, WeightedExplorer{true,R}}
+    nodes::Dict{S,InfoStateNode}
+    behavior_policy::QBasedPolicy{TabularLearner{S,T},WeightedExplorer{true,R}}
 end
 
 (p::TabularCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)
@@ -35,7 +35,13 @@ RLBase.get_prob(p::TabularCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_pol
 """
     TabularCFRPolicy(;n_iter::Int, env::AbstractEnv)
 """
-function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG, is_reset_neg_regrets=false, is_linear_averaging=false)
+function TabularCFRPolicy(;
+    n_iter::Int,
+    env::AbstractEnv,
+    rng = Random.GLOBAL_RNG,
+    is_reset_neg_regrets = false,
+    is_linear_averaging = false,
+)
     @assert NumAgentStyle(env) isa MultiAgent
     @assert DynamicStyle(env) === SEQUENTIAL
     @assert RewardStyle(env) === TERMINAL_REWARD
@@ -47,7 +53,8 @@ function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG,
     for i in 1:n_iter
         for p in get_players(env)
             if p != get_chance_player(env)
-                init_reach_prob = Dict(x=>1.0 for x in get_players(env) if x != get_chance_player(env))
+                init_reach_prob =
+                    Dict(x => 1.0 for x in get_players(env) if x != get_chance_player(env))
                 cfr!(nodes, env, p, init_reach_prob, 1.0, is_linear_averaging ? i : 1)
                 update_strategy!(nodes)
 
@@ -60,9 +67,12 @@ function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG,
         end
     end
 
-    behavior_policy = QBasedPolicy(;learner=TabularLearner{String}(), explorer=WeightedExplorer(;is_normalized=true, rng=rng))
+    behavior_policy = QBasedPolicy(;
+        learner = TabularLearner{String}(),
+        explorer = WeightedExplorer(; is_normalized = true, rng = rng),
+    )
 
-    for (k,v) in nodes
+    for (k, v) in nodes
         s = sum(v.cumulative_strategy)
         if s != 0
             update!(behavior_policy, k => v.cumulative_strategy ./ s)
@@ -77,23 +87,39 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
         get_reward(env, player)
     else
         if get_current_player(env) == get_chance_player(env)
-            v = 0.
+            v = 0.0
             for a::ActionProbPair in get_legal_actions(env)
-                v += a.prob * cfr!(nodes, child(env, a), player, reach_probs, chance_player_reach_prob * a.prob, ratio)
+                v +=
+                    a.prob * cfr!(
+                        nodes,
+                        child(env, a),
+                        player,
+                        reach_probs,
+                        chance_player_reach_prob * a.prob,
+                        ratio,
+                    )
             end
             v
         else
-            v = 0.
+            v = 0.0
             node = nodes[get_state(env)]
             legal_actions = get_legal_actions(env)
-            U = player == get_current_player(env) ? Vector{Float64}(undef, length(legal_actions)) : nothing
+            U = player == get_current_player(env) ?
+                Vector{Float64}(undef, length(legal_actions)) : nothing
 
             for (i, action) in enumerate(legal_actions)
                 prob = node.strategy[i]
                 new_reach_probs = copy(reach_probs)
                 new_reach_probs[get_current_player(env)] *= prob
 
-                u = cfr!(nodes, child(env, action), player, new_reach_probs, chance_player_reach_prob, ratio)
+                u = cfr!(
+                    nodes,
+                    child(env, action),
+                    player,
+                    new_reach_probs,
+                    chance_player_reach_prob,
+                    ratio,
+                )
                 isnothing(U) || (U[i] = u)
                 v += prob * u
             end
@@ -102,8 +128,13 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
                 reach_prob = reach_probs[player]
                 counterfactual_reach_prob = reduce(
                     *,
-                    (reach_probs[p] for p in get_players(env) if p != player && p != get_chance_player(env));
-                    init=chance_player_reach_prob)
+                    (
+                        reach_probs[p]
+                        for
+                        p in get_players(env) if p != player && p != get_chance_player(env)
+                    );
+                    init = chance_player_reach_prob,
+                )
                 node.cumulative_regret .+= counterfactual_reach_prob .* (U .- v)
                 node.cumulative_strategy .+= ratio .* reach_prob .* node.strategy
             end
@@ -113,16 +144,16 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
 end
 
 function regret_matching!(strategy, cumulative_regret)
-    s = mapreduce(x->max(0,x), +,cumulative_regret)
+    s = mapreduce(x -> max(0, x), +, cumulative_regret)
     if s > 0
-        strategy .= max.(0., cumulative_regret) ./ s
+        strategy .= max.(0.0, cumulative_regret) ./ s
     else
-        fill!(strategy, 1/length(strategy))
+        fill!(strategy, 1 / length(strategy))
     end
 end
 
 function update_strategy!(nodes)
     for node in values(nodes)
         regret_matching!(node.strategy, node.cumulative_regret)
     end
-end
+end
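The regret_matching! function changed only in spacing, so its behavior can be checked in isolation. A minimal standalone example (toy numbers, no ReinforcementLearning dependencies; the function body is copied verbatim from the reformatted file):

function regret_matching!(strategy, cumulative_regret)
    s = mapreduce(x -> max(0, x), +, cumulative_regret)
    if s > 0
        strategy .= max.(0.0, cumulative_regret) ./ s
    else
        fill!(strategy, 1 / length(strategy))
    end
end

cumulative_regret = [2.0, -1.0, 3.0]  # hypothetical regrets for three actions
strategy = zeros(3)
regret_matching!(strategy, cumulative_regret)
# strategy == [0.4, 0.0, 0.6]: positive regrets 2 and 3 are normalized, the negative one is clipped to 0

fill!(cumulative_regret, -1.0)        # no positive regret anywhere
regret_matching!(strategy, cumulative_regret)
# strategy == [1/3, 1/3, 1/3]: falls back to the uniform strategy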

src/algorithms/searching/minimax.jl (11 additions, 4 deletions)

@@ -10,7 +10,7 @@ The minimax algorithm with [Alpha-beta pruning](https://en.wikipedia.org/wiki/Al
 Base.@kwdef mutable struct MinimaxPolicy{F} <: AbstractPolicy
     maximum_depth::Int = 30
     value_function::F = nothing
-    v::Float64 = 0.
+    v::Float64 = 0.0
 end
 
 (p::MinimaxPolicy)(env::AbstractEnv) = p(env, DynamicStyle(env), NumAgentStyle(env))
@@ -19,7 +19,14 @@ function (p::MinimaxPolicy)(env::AbstractEnv, ::Sequential, ::MultiAgent{2})
     if get_terminal(env)
         rand(get_actions(env)) # just a dummy action
     else
-        a, v = α_β_search(env, p.value_function, p.maximum_depth, -Inf, Inf, get_current_player(env))
+        a, v = α_β_search(
+            env,
+            p.value_function,
+            p.maximum_depth,
+            -Inf,
+            Inf,
+            get_current_player(env),
+        )
         p.v = v # for debug only
         a
     end
@@ -36,7 +43,7 @@ function α_β_search(env::AbstractEnv, value_function, depth, α, β, maximizin
         v = -Inf
         for a in legal_actions
             node = child(env, a)
-            _, v_node = α_β_search(node, value_function, depth-1, α, β, maximizing_role)
+            _, v_node = α_β_search(node, value_function, depth - 1, α, β, maximizing_role)
             if v_node > v
                 v = v_node
                 best_action = a
@@ -51,7 +58,7 @@ function α_β_search(env::AbstractEnv, value_function, depth, α, β, maximizin
         v = Inf
         for a in legal_actions
             node = child(env, a)
-            _, v_node = α_β_search(node, value_function, depth-1, α, β, maximizing_role)
+            _, v_node = α_β_search(node, value_function, depth - 1, α, β, maximizing_role)
             if v_node < v
                 v = v_node
                 best_action = a
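The reformatting above does not touch the search logic; α_β_search is still plain minimax with alpha-beta pruning. For reference, a minimal self-contained sketch of the pruning idea on a toy game tree (a hypothetical helper, not the package's α_β_search, which additionally tracks the chosen action and consults a value_function at the depth limit):

# Leaves are numbers; internal nodes are vectors of children.
function alphabeta(node, depth, α, β, maximizing)
    node isa Number && return float(node)   # leaf: return its static value
    depth == 0 && return 0.0                # toy depth cut-off
    if maximizing
        v = -Inf
        for child in node
            v = max(v, alphabeta(child, depth - 1, α, β, false))
            α = max(α, v)
            α >= β && break                 # β cut-off: the minimizer will avoid this branch
        end
        return v
    else
        v = Inf
        for child in node
            v = min(v, alphabeta(child, depth - 1, α, β, true))
            β = min(β, v)
            β <= α && break                 # α cut-off: the maximizer already has a better option
        end
        return v
    end
end

tree = [[3, 5], [6, [9, 2]]]            # toy two-ply game tree
alphabeta(tree, 10, -Inf, Inf, true)    # 6.0; the leaf 2 is never evaluated (pruned after 9)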

src/algorithms/searching/searching.jl (1 addition, 1 deletion)

@@ -1 +1 @@
-include("minimax.jl")
+include("minimax.jl")

(as above, the removed and added lines render identically; the change is not visible in the extracted text)

src/experiments/open_spiel.jl (24 additions, 22 deletions)

@@ -1,20 +1,15 @@
 using Random
 
-function RLCore.Experiment(
-    ::Val{:JuliaRL},
-    ::Val{:Minimax},
-    ::Val{:OpenSpiel},
-    game;
-)
+function RLCore.Experiment(::Val{:JuliaRL}, ::Val{:Minimax}, ::Val{:OpenSpiel}, game;)
     env = OpenSpielEnv(string(game))
     agents = (
-        Agent(policy=MinimaxPolicy(), role=0),
-        Agent(policy=MinimaxPolicy(), role=1)
+        Agent(policy = MinimaxPolicy(), role = 0),
+        Agent(policy = MinimaxPolicy(), role = 1),
     )
     hooks = (TotalRewardPerEpisode(), TotalRewardPerEpisode())
-    description="""
-    # Play `$game` in OpenSpiel with Minimax
-    """
+    description = """
+    # Play `$game` in OpenSpiel with Minimax
+    """
     Experiment(agents, env, StopAfterEpisode(1), hooks, description)
 end
 
@@ -23,24 +18,31 @@ function RLCore.Experiment(
     ::Val{:TabularCFR},
     ::Val{:OpenSpiel},
     game;
-    n_iter=300,
-    seed=123
+    n_iter = 300,
+    seed = 123,
 )
-    env = OpenSpielEnv(game;default_state_style=RLBase.Information{String}(), is_chance_agent_required=true)
+    env = OpenSpielEnv(
+        game;
+        default_state_style = RLBase.Information{String}(),
+        is_chance_agent_required = true,
+    )
     rng = MersenneTwister(seed)
-    π = TabularCFRPolicy(;n_iter=n_iter, env=env, rng=rng)
+    π = TabularCFRPolicy(; n_iter = n_iter, env = env, rng = rng)
 
     agents = map(get_players(env)) do p
         if p == get_chance_player(env)
-            Agent(;policy=RandomPolicy(), role=p)
+            Agent(; policy = RandomPolicy(), role = p)
        else
-            Agent(;policy=π,role=p)
+            Agent(; policy = π, role = p)
        end
    end
 
-    hooks = [p == get_chance_player(env) ? EmptyHook() : TotalRewardPerEpisode() for p in get_players(env)]
-    description="""
-    # Play `$game` in OpenSpiel with TabularCFRPolicy
-    """
+    hooks = [
+        p == get_chance_player(env) ? EmptyHook() : TotalRewardPerEpisode()
+        for p in get_players(env)
+    ]
+    description = """
+    # Play `$game` in OpenSpiel with TabularCFRPolicy
+    """
     Experiment(agents, env, StopAfterEpisode(100_000), hooks, description)
-end
+end
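These two RLCore.Experiment methods are the definitions that the E`...` string macro in the tests below resolves to. A minimal usage sketch mirroring test/runtests.jl (assuming the package and OpenSpiel are already loaded, as in the test file):

e = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`   # constructs the experiment above; this already runs the 300 CFR iterations
run(e)                                            # plays the 100_000 evaluation episodes with the trained policy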

test/runtests.jl (4 additions, 4 deletions)

@@ -82,12 +82,12 @@ using OpenSpiel
     @testset "TabularCFR" begin
         e = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
         run(e)
-        @test isapprox(mean(e.hook[2].rewards), -1 / 18;atol=0.01)
-        @test isapprox(mean(e.hook[3].rewards), 1 / 18;atol=0.01)
+        @test isapprox(mean(e.hook[2].rewards), -1 / 18; atol = 0.01)
+        @test isapprox(mean(e.hook[3].rewards), 1 / 18; atol = 0.01)
 
         reset!(e.env)
         expected_values = Dict(expected_policy_values(e.agent, e.env))
-        @test isapprox(expected_values[get_role(e.agent[2])], -1/18; atol=0.01)
-        @test isapprox(expected_values[get_role(e.agent[3])], 1/18; atol=0.01)
+        @test isapprox(expected_values[get_role(e.agent[2])], -1 / 18; atol = 0.01)
+        @test isapprox(expected_values[get_role(e.agent[3])], 1 / 18; atol = 0.01)
     end
 end
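For context on the constant in these tests: -1/18 (about -0.056) is the optimal game value of Kuhn poker for the first player, and +1/18 for the second, so the assertions check that both the empirical mean rewards and the expected policy values of the CFR agents land within 0.01 of the game-theoretic optimum.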
