This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 28fcb99

Multi agent related changes (#83)
* add experiment of snake game
* sync
* add Experiment for Minimax
* add Experiment for CFRPolicy
* fix CFR
* add more experiments
* add more experiments
* update dependency
* bump version
* add more info in README.md
* bugfix
* minor bugfix
* automatically decrease steps in CI
* increase steps in JuliaRL_TabularCFR_OpenSpiel
1 parent 39159f4 commit 28fcb99

12 files changed: 277 additions, 9 deletions

.travis.yml

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 language: julia
 os:
   - linux
-  - osx
+  # - osx
 julia:
   - 1.4
   - nightly

Project.toml

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "ReinforcementLearningZoo"
 uuid = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
 authors = ["Jun Tian <tianjun.cpp@gmail.com>"]
-version = "0.1.6"
+version = "0.1.7"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -32,7 +32,7 @@ Distributions = "0.23"
 Flux = "0.11"
 MacroTools = "0.5"
 ReinforcementLearningBase = "0.8"
-ReinforcementLearningCore = "0.4.1"
+ReinforcementLearningCore = "0.4.2"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StatsBase = "0.32, 0.33"
@@ -41,8 +41,9 @@ Zygote = "0.5"
 julia = "1.4"
 
 [extras]
+OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
 ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "ReinforcementLearningEnvironments"]
+test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel"]
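`OpenSpiel` is added as a test-only dependency: it is listed under `[extras]` and referenced from the `test` target, so it is only pulled in when the package tests run. A minimal sketch of triggering those tests with the standard Pkg API (not part of this commit):

    using Pkg
    Pkg.test("ReinforcementLearningZoo")   # installs the extras listed under the `test` target, including OpenSpiel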

README.md

Lines changed: 6 additions & 1 deletion
@@ -24,6 +24,8 @@ This project aims to provide some implementations of the most typical reinforcem
 - PPO
 - DDPG
 - SAC
+- CFR
+- Minimax
 
 If you are looking for tabular reinforcement learning algorithms, you may refer to [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).
 
@@ -45,6 +47,9 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - ``E`JuliaRL_SAC_Pendulum` `` (Thanks to [@rbange](https://github.com/rbange))
 - ``E`JuliaRL_BasicDQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`JuliaRL_DQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
+- ``E`JuliaRL_Minimax_OpenSpiel(tic_tac_toe)` ``
+- ``E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` ``
+- ``E`JuliaRL_DQN_SnakeGame` ``
 - ``E`Dopamine_DQN_Atari(pong)` ``
 - ``E`Dopamine_Rainbow_Atari(pong)` ``
 - ``E`Dopamine_IQN_Atari(pong)` ``
@@ -56,7 +61,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - Experiments on `CartPole` usually run faster with CPU only due to the overhead of sending data between CPU and GPU.
 - It shouldn't surprise you that our experiments on `CartPole` are much faster than those written in Python. The secret is that our environment is written in Julia!
 - Remember to set `JULIA_NUM_THREADS` to enable multi-threading when using algorithms like `A2C` and `PPO`.
-- Experiments on `Atari` are only available when you have `ArcadeLearningEnvironment.jl` installed and `using ArcadeLearningEnvironment`.
+- Experiments on `Atari` (`OpenSpiel`, `SnakeGame`) are only available after you have `ArcadeLearningEnvironment.jl` (`OpenSpiel.jl`, `SnakeGames.jl`) installed and have run `using ArcadeLearningEnvironment` (`using OpenSpiel`, `using SnakeGames`).
 
 ### Speed
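For example, once the optional package is loaded, one of the newly exported experiments can be started roughly like this (a sketch; it assumes the ``E`...` `` string macro and the `run` entry point from ReinforcementLearningCore are in scope):

    using ReinforcementLearningCore, ReinforcementLearningZoo
    using OpenSpiel   # loading the optional dependency makes the OpenSpiel experiments available

    run(E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`)   # train TabularCFRPolicy on Kuhn poker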

src/ReinforcementLearningZoo.jl

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ function __init__()
     include("experiments/rl_envs.jl")
     @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari.jl")
     @require SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429" include("experiments/snake.jl")
+    @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel.jl")
 end
 end

src/algorithms/algorithms.jl

Lines changed: 2 additions & 0 deletions
@@ -1,2 +1,4 @@
 include("dqns/dqns.jl")
 include("policy_gradient/policy_gradient.jl")
+include("searching/searching.jl")
+include("cfr/cfr.jl")

src/algorithms/cfr/cfr.jl

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+include("tabular_cfr.jl")

src/algorithms/cfr/tabular_cfr.jl

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+export TabularCFRPolicy
+
+struct InfoStateNode
+    strategy::Vector{Float64}
+    cumulative_regret::Vector{Float64}
+    cumulative_strategy::Vector{Float64}
+end
+
+InfoStateNode(n) = InfoStateNode(fill(1/n,n), zeros(n), zeros(n))
+
+function init_info_state_nodes(env::AbstractEnv)
+    nodes = Dict{String, InfoStateNode}()
+    walk(env) do x
+        if !get_terminal(x) && get_current_player(x) != get_chance_player(x)
+            get!(nodes, get_state(x), InfoStateNode(length(get_legal_actions(x))))
+        end
+    end
+    nodes
+end
+
+"""
+    TabularCFRPolicy
+
+See more details: [An Introduction to Counterfactual Regret Minimization](http://modelai.gettysburg.edu/2013/cfr/cfr.pdf)
+"""
+struct TabularCFRPolicy{S,T,R<:AbstractRNG} <: AbstractPolicy
+    nodes::Dict{S, InfoStateNode}
+    behavior_policy::QBasedPolicy{TabularLearner{S,T}, WeightedExplorer{true,R}}
+end
+
+(p::TabularCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)
+
+RLBase.get_prob(p::TabularCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_policy, env)
+
+"""
+    TabularCFRPolicy(;n_iter::Int, env::AbstractEnv)
+"""
+function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG, is_reset_neg_regrets=false, is_linear_averaging=false)
+    @assert NumAgentStyle(env) isa MultiAgent
+    @assert DynamicStyle(env) === SEQUENTIAL
+    @assert RewardStyle(env) === TERMINAL_REWARD
+    @assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
+    @assert DefaultStateStyle(env) === Information{String}()
+
+    nodes = init_info_state_nodes(env)
+
+    for i in 1:n_iter
+        for p in get_players(env)
+            if p != get_chance_player(env)
+                init_reach_prob = Dict(x=>1.0 for x in get_players(env) if x != get_chance_player(env))
+                cfr!(nodes, env, p, init_reach_prob, 1.0, is_linear_averaging ? i : 1)
+                update_strategy!(nodes)
+
+                if is_reset_neg_regrets
+                    for node in values(nodes)
+                        node.cumulative_regret .= max.(node.cumulative_regret, 0)
+                    end
+                end
+            end
+        end
+    end
+
+    behavior_policy = QBasedPolicy(;learner=TabularLearner{String}(), explorer=WeightedExplorer(;is_normalized=true, rng=rng))
+
+    for (k,v) in nodes
+        s = sum(v.cumulative_strategy)
+        if s != 0
+            update!(behavior_policy, k => v.cumulative_strategy ./ s)
+        end
+    end
+
+    TabularCFRPolicy(nodes, behavior_policy)
+end
+
+function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
+    if get_terminal(env)
+        get_reward(env, player)
+    else
+        if get_current_player(env) == get_chance_player(env)
+            v = 0.
+            for a::ActionProbPair in get_legal_actions(env)
+                v += a.prob * cfr!(nodes, child(env, a), player, reach_probs, chance_player_reach_prob * a.prob, ratio)
+            end
+            v
+        else
+            v = 0.
+            node = nodes[get_state(env)]
+            legal_actions = get_legal_actions(env)
+            U = player == get_current_player(env) ? Vector{Float64}(undef, length(legal_actions)) : nothing
+
+            for (i, action) in enumerate(legal_actions)
+                prob = node.strategy[i]
+                new_reach_probs = copy(reach_probs)
+                new_reach_probs[get_current_player(env)] *= prob
+
+                u = cfr!(nodes, child(env, action), player, new_reach_probs, chance_player_reach_prob, ratio)
+                isnothing(U) || (U[i] = u)
+                v += prob * u
+            end
+
+            if player == get_current_player(env)
+                reach_prob = reach_probs[player]
+                counterfactual_reach_prob = reduce(
+                    *,
+                    (reach_probs[p] for p in get_players(env) if p != player && p != get_chance_player(env));
+                    init=chance_player_reach_prob)
+                node.cumulative_regret .+= counterfactual_reach_prob .* (U .- v)
+                node.cumulative_strategy .+= ratio .* reach_prob .* node.strategy
+            end
+            v
+        end
+    end
+end
+
+function regret_matching!(strategy, cumulative_regret)
+    s = mapreduce(x->max(0,x), +, cumulative_regret)
+    if s > 0
+        strategy .= max.(0., cumulative_regret) ./ s
+    else
+        fill!(strategy, 1/length(strategy))
+    end
+end
+
+function update_strategy!(nodes)
+    for node in values(nodes)
+        regret_matching!(node.strategy, node.cumulative_regret)
+    end
+end
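For intuition about the `regret_matching!` rule above, here is a standalone sketch (plain Julia, hypothetical non-mutating helper, not part of the commit): positive cumulative regrets are normalized into a strategy, and when no regret is positive the strategy falls back to uniform.

    # Non-mutating copy of the regret-matching rule, for illustration only.
    function regret_matching(cumulative_regret::Vector{Float64})
        s = sum(x -> max(0.0, x), cumulative_regret)   # total positive regret
        if s > 0
            max.(0.0, cumulative_regret) ./ s          # play in proportion to positive regret
        else
            fill(1 / length(cumulative_regret), length(cumulative_regret))   # uniform fallback
        end
    end

    regret_matching([2.0, -1.0, 1.0])   # ≈ [0.667, 0.0, 0.333]
    regret_matching([-2.0, -1.0])       # == [0.5, 0.5]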

src/algorithms/searching/minimax.jl

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+export MinimaxPolicy
+
+"""
+    MinimaxPolicy(;value_function, depth::Int)
+The minimax algorithm with [Alpha-beta pruning](https://en.wikipedia.org/wiki/Alpha-beta_pruning)
+## Keyword Arguments
+- `maximum_depth::Int=30`, the maximum depth of search.
+- `value_function=nothing`, estimate the value of `env`. `value_function(env) -> Number`. It is only called after searching for `maximum_depth` and the `env` is not terminated yet.
+"""
+Base.@kwdef mutable struct MinimaxPolicy{F} <: AbstractPolicy
+    maximum_depth::Int = 30
+    value_function::F = nothing
+    v::Float64 = 0.
+end
+
+(p::MinimaxPolicy)(env::AbstractEnv) = p(env, DynamicStyle(env), NumAgentStyle(env))
+
+function (p::MinimaxPolicy)(env::AbstractEnv, ::Sequential, ::MultiAgent{2})
+    if get_terminal(env)
+        rand(get_actions(env)) # just a dummy action
+    else
+        a, v = α_β_search(env, p.value_function, p.maximum_depth, -Inf, Inf, get_current_player(env))
+        p.v = v # for debug only
+        a
+    end
+end
+
+function α_β_search(env::AbstractEnv, value_function, depth, α, β, maximizing_role)
+    if get_terminal(env)
+        nothing, get_reward(env, maximizing_role)
+    elseif depth == 0
+        nothing, value_function(env)
+    elseif get_current_player(env) == maximizing_role
+        legal_actions = get_legal_actions(env)
+        best_action = legal_actions[1]
+        v = -Inf
+        for a in legal_actions
+            node = child(env, a)
+            _, v_node = α_β_search(node, value_function, depth-1, α, β, maximizing_role)
+            if v_node > v
+                v = v_node
+                best_action = a
+            end
+            α = max(α, v)
+            α >= β && break # β cut-off
+        end
+        best_action, v
+    else
+        legal_actions = get_legal_actions(env)
+        best_action = legal_actions[1]
+        v = Inf
+        for a in legal_actions
+            node = child(env, a)
+            _, v_node = α_β_search(node, value_function, depth-1, α, β, maximizing_role)
+            if v_node < v
+                v = v_node
+                best_action = a
+            end
+            β = min(β, v)
+            β <= α && break # α cut-off
+        end
+        best_action, v
+    end
+end
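A minimal usage sketch (hypothetical: `env` stands for any two-player, sequential `AbstractEnv`, and the constant heuristic is illustrative only):

    # Depth-limited alpha-beta search; value_function is consulted once maximum_depth is reached.
    policy = MinimaxPolicy(; maximum_depth = 4, value_function = env -> 0.0)
    action = policy(env)   # runs α_β_search from the current position
    policy.v               # value of the chosen action (kept for debugging)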

src/algorithms/searching/searching.jl

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+include("minimax.jl")

src/experiments/open_spiel.jl

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+using Random
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:Minimax},
+    ::Val{:OpenSpiel},
+    game;
+)
+    env = OpenSpielEnv(string(game))
+    agents = (
+        Agent(policy=MinimaxPolicy(), role=0),
+        Agent(policy=MinimaxPolicy(), role=1)
+    )
+    hooks = (TotalRewardPerEpisode(), TotalRewardPerEpisode())
+    description="""
+    # Play `$game` in OpenSpiel with Minimax
+    """
+    Experiment(agents, env, StopAfterEpisode(1), hooks, description)
+end
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:TabularCFR},
+    ::Val{:OpenSpiel},
+    game;
+    n_iter=300,
+    seed=123
+)
+    env = OpenSpielEnv(game; default_state_style=RLBase.Information{String}(), is_chance_agent_required=true)
+    rng = MersenneTwister(seed)
+    π = TabularCFRPolicy(;n_iter=n_iter, env=env, rng=rng)
+
+    agents = map(get_players(env)) do p
+        if p == get_chance_player(env)
+            Agent(;policy=RandomPolicy(), role=p)
+        else
+            Agent(;policy=π, role=p)
+        end
+    end
+
+    hooks = [p == get_chance_player(env) ? EmptyHook() : TotalRewardPerEpisode() for p in get_players(env)]
+    description="""
+    # Play `$game` in OpenSpiel with TabularCFRPolicy
+    """
+    Experiment(agents, env, StopAfterEpisode(100_000), hooks, description)
+end
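The same experiments can also be constructed directly from the definitions above (a sketch; it assumes OpenSpiel.jl is installed and that `run` accepts an `Experiment`, as with the other built-in experiments):

    using ReinforcementLearningCore, ReinforcementLearningZoo, OpenSpiel

    ex = Experiment(Val(:JuliaRL), Val(:TabularCFR), Val(:OpenSpiel), "kuhn_poker"; n_iter = 300, seed = 123)
    run(ex)   # trains TabularCFRPolicy, then plays 100_000 episodes with the recorded hooks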
