This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 59a2475

Improve CFR (#99)
* add an AbstractCFRPolicy
* improve tabular cfr
* add best response policy
* add nash_conv
* fix external_sampling
* fix outcome_sampling
* move out function from test
* add deepcfr
* update README
* add experiment for DeepCFR
* set seed
* update experiment result
* update dependency
* update experiment list
* use StableRNGs instead
* use StableRNGs in Experiments by default
* resolve test errors
1 parent 6183609 commit 59a2475

20 files changed (+978 -216 lines)

Project.toml

+4 -2

@@ -19,6 +19,7 @@ ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
@@ -29,13 +30,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 AbstractTrees = "0.3"
 BSON = "0.2"
 CUDA = "1"
-Distributions = "0.23, 0.24"
+Distributions = "0.24"
 Flux = "0.11"
 MacroTools = "0.5"
 ReinforcementLearningBase = "0.8.4"
-ReinforcementLearningCore = "0.4.5"
+ReinforcementLearningCore = "0.5"
 Requires = "1"
 Setfield = "0.6, 0.7"
+StableRNGs = "1.0"
 StatsBase = "0.32, 0.33"
 StructArrays = "0.4"
 TensorBoardLogger = "0.1"

README.md

+2 -1

@@ -25,7 +25,7 @@ This project aims to provide some implementations of the most typical reinforcem
 - DDPG
 - TD3
 - SAC
-- CFR/OS-MCCFR/ES-MCCFR
+- CFR/OS-MCCFR/ES-MCCFR/DeepCFR
 - Minimax

 If you are looking for tabular reinforcement learning algorithms, you may refer [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).
@@ -55,6 +55,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - ``E`JuliaRL_DQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`JuliaRL_Minimax_OpenSpiel(tic_tac_toe)` ``
 - ``E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` ``
+- ``E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)` ``
 - ``E`JuliaRL_DQN_SnakeGame` ``
 - ``E`Dopamine_DQN_Atari(pong)` ``
 - ``E`Dopamine_Rainbow_Atari(pong)` ``
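
The new experiment entry follows the same convention as the others in the list. A minimal sketch of launching it, assuming the package's usual run(E`...`) entry point and that OpenSpiel.jl is installed so the leduc_poker game is available:

using ReinforcementLearning   # assumed to re-export the zoo and the E`...` experiment macro
using OpenSpiel               # assumed requirement for the *_OpenSpiel(...) experiments

run(E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)`)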

src/ReinforcementLearningZoo.jl

+1

@@ -6,6 +6,7 @@ export RLZoo
 using ReinforcementLearningBase
 using ReinforcementLearningCore
 using Setfield: @set
+using StableRNGs

 include("patch.jl")
 include("algorithms/algorithms.jl")
src/algorithms/cfr/abstract_cfr_policy.jl

+18 (new file)

abstract type AbstractCFRPolicy <: AbstractPolicy end

function Base.run(p::AbstractCFRPolicy, env::AbstractEnv, stop_condition=StopAfterStep(1), hook=EmptyHook())
    @assert NumAgentStyle(env) isa MultiAgent
    @assert DynamicStyle(env) === SEQUENTIAL
    @assert RewardStyle(env) === TERMINAL_REWARD
    @assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
    @assert DefaultStateStyle(env) isa Information

    RLBase.reset!(env)

    while true
        update!(p, env)
        hook(POST_ACT_STAGE, p, env)
        stop_condition(p, env) && break
    end
    update!(p)
end
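
This `Base.run` method is the common driver for all CFR-style policies: it repeatedly calls `update!(p, env)` until the stop condition fires, then calls `update!(p)` once to finalize the strategy. A minimal sketch of driving a concrete CFR policy through this loop; the `TabularCFRPolicy()` constructor call and the Kuhn poker environment are assumptions for illustration, not the exact API:

using ReinforcementLearningZoo
using ReinforcementLearningBase
using OpenSpiel                      # assumed source of an EXPLICIT_STOCHASTIC Kuhn poker environment

env = OpenSpielEnv("kuhn_poker")     # SEQUENTIAL, TERMINAL_REWARD, explicit chance nodes
p = TabularCFRPolicy()               # hypothetical constructor; see tabular_cfr.jl for the real signature
run(p, env, StopAfterStep(300))      # 300 CFR iterations, then update!(p) finalizes the average policy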
src/algorithms/cfr/best_response_policy.jl

+119 (new file)

export BestResponsePolicy

using Flux: onehot

struct BestResponsePolicy{E, S, A, X, P<:AbstractPolicy} <: AbstractCFRPolicy
    cfr_reach_prob::Dict{S, Vector{Pair{E, Float64}}}
    best_response_action_cache::Dict{S, A}
    best_response_value_cache::Dict{E, Float64}
    best_responder::X
    policy::P
end

"""
    BestResponsePolicy(policy, env, best_responder)

- `policy`, the original policy to be wrapped in the best response policy.
- `env`, the environment to handle.
- `best_responder`, the player who chooses the best response action.
"""
function BestResponsePolicy(policy, env, best_responder; state_type=String, action_type=Int)
    # S = typeof(get_state(env)) # TODO: currently it will break the OpenSpielEnv. Can not get information set for chance player
    # A = eltype(get_actions(env)) # TODO: for chance players it will return ActionProbPair
    S = state_type
    A = action_type
    E = typeof(env)

    p = BestResponsePolicy(
        Dict{S, Vector{Pair{E, Float64}}}(),
        Dict{S, A}(),
        Dict{E, Float64}(),
        best_responder,
        policy,
    )

    e = copy(env)
    @assert e == env "The copy method doesn't seem to be implemented for environment: $env"
    @assert hash(e) == hash(env) "The hash method doesn't seem to be implemented for environment: $env"
    RLBase.reset!(e)  # start from the root!
    init_cfr_reach_prob!(p, e)
    p
end

function (p::BestResponsePolicy)(env::AbstractEnv)
    if get_current_player(env) == p.best_responder
        best_response_action(p, env)
    else
        p.policy(env)
    end
end

function init_cfr_reach_prob!(p, env, reach_prob=1.0)
    if !get_terminal(env)
        if get_current_player(env) == p.best_responder
            push!(get!(p.cfr_reach_prob, get_state(env), []), env => reach_prob)

            for a in get_legal_actions(env)
                init_cfr_reach_prob!(p, child(env, a), reach_prob)
            end
        elseif get_current_player(env) == get_chance_player(env)
            for a::ActionProbPair in get_actions(env)
                init_cfr_reach_prob!(p, child(env, a), reach_prob * a.prob)
            end
        else  # opponents
            for a in get_legal_actions(env)
                init_cfr_reach_prob!(p, child(env, a), reach_prob * get_prob(p.policy, env, a))
            end
        end
    end
end

function best_response_value(p, env)
    get!(p.best_response_value_cache, env) do
        if get_terminal(env)
            get_reward(env, p.best_responder)
        elseif get_current_player(env) == p.best_responder
            a = best_response_action(p, env)
            best_response_value(p, child(env, a))
        elseif get_current_player(env) == get_chance_player(env)
            v = 0.0
            for a::ActionProbPair in get_actions(env)
                v += a.prob * best_response_value(p, child(env, a))
            end
            v
        else
            v = 0.0
            for a in get_legal_actions(env)
                v += get_prob(p.policy, env, a) * best_response_value(p, child(env, a))
            end
            v
        end
    end
end

function best_response_action(p, env)
    get!(p.best_response_action_cache, get_state(env)) do
        best_action, best_action_value = nothing, typemin(Float64)
        for a in get_legal_actions(env)
            # for each information set (`get_state(env)` here), we may have several paths to reach it
            # here we sum the cfr reach prob weighted value to find out the best action
            v = sum(p.cfr_reach_prob[get_state(env)]) do (e, reach_prob)
                reach_prob * best_response_value(p, child(e, a))
            end
            if v > best_action_value
                best_action, best_action_value = a, v
            end
        end
        best_action
    end
end

RLBase.update!(p::BestResponsePolicy, args...) = nothing

function RLBase.get_prob(p::BestResponsePolicy, env::AbstractEnv)
    if get_current_player(env) == p.best_responder
        onehot(p(env), get_actions(env))
    else
        get_prob(p.policy, env)
    end
end
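
BestResponsePolicy caches, per information set, the environments that reach it together with the opponents' and chance reach probabilities, then picks the action that maximizes the reach-weighted best-response value. A minimal sketch of evaluating a fixed policy against its best responder; the environment name, player id, and RandomPolicy opponent are assumptions for illustration (the new nash_conv.jl, not shown in this excerpt, presumably builds on these values to report exploitability):

using ReinforcementLearningZoo
using ReinforcementLearningBase
using OpenSpiel                           # assumed two-player zero-sum game source

env = OpenSpielEnv("kuhn_poker")
fixed = RandomPolicy(env)                 # the policy to be exploited (assumed constructor)
brp = BestResponsePolicy(fixed, env, 0)   # player 0 plays the best response

RLBase.reset!(env)
v = best_response_value(brp, env)         # expected payoff of the best response at the root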

src/algorithms/cfr/cfr.jl

+4

@@ -1,3 +1,7 @@
+include("abstract_cfr_policy.jl")
 include("tabular_cfr.jl")
 include("outcome_sampling_mccfr.jl")
 include("external_sampling_mccfr.jl")
+include("best_response_policy.jl")
+include("nash_conv.jl")
+include("deep_cfr.jl")

src/algorithms/cfr/deep_cfr.jl

+181 (new file)

export DeepCFR

using Statistics: mean
using StatsBase

"""
    DeepCFR(;kwargs...)

Symbols used here follow the paper: [Deep Counterfactual Regret Minimization](https://arxiv.org/abs/1811.00164)

# Keyword arguments

- `K`, number of traversals.
- `t`, the current iteration.
- `Π`, the policy network.
- `V`, a dictionary of each player's advantage network.
- `MΠ`, a strategy memory.
- `MV`, a dictionary of each player's advantage memory.
- `reinitialize_freq=1`, the frequency of reinitializing the value networks.
"""
Base.@kwdef mutable struct DeepCFR{TP, TV, TMP, TMV, I, R, P} <: AbstractCFRPolicy
    Π::TP
    V::TV
    MΠ::TMP
    MV::TMV
    K::Int = 20
    t::Int = 1
    reinitialize_freq::Int = 1
    batch_size_V::Int = 32
    batch_size_Π::Int = 32
    n_training_steps_V::Int = 1
    n_training_steps_Π::Int = 1
    rng::R = Random.GLOBAL_RNG
    initializer::I = glorot_normal(rng)
    max_grad_norm::Float32 = 10.0f0
    # for logging
    Π_losses::Vector{Float32} = zeros(Float32, n_training_steps_Π)
    V_losses::Dict{P, Vector{Float32}} = Dict(k => zeros(Float32, n_training_steps_V) for (k, _) in MV)
    Π_norms::Vector{Float32} = zeros(Float32, n_training_steps_Π)
    V_norms::Dict{P, Vector{Float32}} = Dict(k => zeros(Float32, n_training_steps_V) for (k, _) in MV)
end

function RLBase.get_prob(π::DeepCFR, env::AbstractEnv)
    I = send_to_device(device(π.Π), get_state(env))
    m = send_to_device(device(π.Π), ifelse.(get_legal_actions_mask(env), 0.0f0, -Inf32))
    logits = π.Π(Flux.unsqueeze(I, ndims(I) + 1)) |> vec
    σ = softmax(logits .+ m)
    send_to_host(σ)
end

(π::DeepCFR)(env::AbstractEnv) = sample(π.rng, get_actions(env), Weights(get_prob(π, env), 1.0))

"Run one iteration"
function RLBase.update!(π::DeepCFR, env::AbstractEnv)
    for p in get_players(env)
        if p != get_chance_player(env)
            for k in 1:π.K
                external_sampling!(π, copy(env), p)
            end
            update_advantage_networks(π, p)
        end
    end
    π.t += 1
end

"Update Π (policy network)"
function RLBase.update!(π::DeepCFR)
    Π = π.Π
    Π_losses = π.Π_losses
    Π_norms = π.Π_norms
    D = device(Π)
    MΠ = π.MΠ
    ps = Flux.params(Π)

    for x in ps
        x .= π.initializer(size(x)...)
    end

    for i in 1:π.n_training_steps_Π
        batch_inds = rand(π.rng, 1:length(MΠ), π.batch_size_Π)
        I = send_to_device(D, Flux.batch([MΠ[:I][i] for i in batch_inds]))
        σ = send_to_device(D, Flux.batch([MΠ[:σ][i] for i in batch_inds]))
        t = send_to_device(D, Flux.batch([MΠ[:t][i] / π.t for i in batch_inds]))
        m = send_to_device(D, Flux.batch([ifelse.(MΠ[:m][i], 0.0f0, -Inf32) for i in batch_inds]))
        gs = gradient(ps) do
            logits = Π(I) .+ m
            loss = mean(reshape(t, 1, :) .* ((σ .- softmax(logits)) .^ 2))
            ignore() do
                # println(σ, "!!!", m, "===", Π(I))
                Π_losses[i] = loss
            end
            loss
        end
        Π_norms[i] = clip_by_global_norm!(gs, ps, π.max_grad_norm)
        update!(Π, gs)
    end
end

"Update advantage network"
function update_advantage_networks(π, p)
    V = π.V[p]
    V_losses = π.V_losses[p]
    V_norms = π.V_norms[p]
    MV = π.MV[p]
    if π.t % π.reinitialize_freq == 0
        for x in Flux.params(V)
            # TODO: inplace
            x .= π.initializer(size(x)...)
        end
    end
    if length(MV) >= π.batch_size_V
        for i in 1:π.n_training_steps_V
            batch_inds = rand(π.rng, 1:length(MV), π.batch_size_V)
            I = send_to_device(device(V), Flux.batch([MV[:I][i] for i in batch_inds]))
            r̃ = send_to_device(device(V), Flux.batch([MV[:r̃][i] for i in batch_inds]))
            t = send_to_device(device(V), Flux.batch([MV[:t][i] / π.t for i in batch_inds]))
            m = send_to_device(device(V), Flux.batch([MV[:m][i] for i in batch_inds]))
            ps = Flux.params(V)
            gs = gradient(ps) do
                loss = mean(reshape(t, 1, :) .* ((r̃ .- V(I) .* m) .^ 2))
                ignore() do
                    V_losses[i] = loss
                end
                loss
            end
            V_norms[i] = clip_by_global_norm!(gs, ps, π.max_grad_norm)
            update!(V, gs)
        end
    end
end

"CFR Traversal with External Sampling"
function external_sampling!(π::DeepCFR, env::AbstractEnv, p)
    if get_terminal(env)
        get_reward(env, p)
    elseif get_current_player(env) == get_chance_player(env)
        env(rand(π.rng, get_actions(env)))
        external_sampling!(π, env, p)
    elseif get_current_player(env) == p
        V = π.V[p]
        s = get_state(env)
        I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s) + 1))
        A = get_actions(env)
        m = get_legal_actions_mask(env)
        σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
        v = zeros(length(σ))
        v̄ = 0.0
        for i in 1:length(m)
            if m[i]
                v[i] = external_sampling!(π, child(env, A[i]), p)
                v̄ += σ[i] * v[i]
            end
        end
        push!(π.MV[p]; I = s, t = π.t, r̃ = (v .- v̄) .* m, m = m)
        v̄
    else
        V = π.V[get_current_player(env)]
        s = get_state(env)
        I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s) + 1))
        A = get_actions(env)
        m = get_legal_actions_mask(env)
        σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
        push!(π.MΠ; I = s, t = π.t, σ = σ, m = m)
        a = sample(π.rng, A, Weights(σ, 1.0))
        env(a)
        external_sampling!(π, env, p)
    end
end

"This is the specific regret matching method used in DeepCFR"
function masked_regret_matching(v, m)
    v⁺ = max.(v .* m, 0.0f0)
    s = sum(v⁺)
    if s > 0
        v⁺ ./= s
    else
        fill!(v⁺, 0.0f0)
        v⁺[findmax(v, m)[2]] = 1.0
    end
    v⁺
end
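
masked_regret_matching normalizes the positive parts of the predicted advantages over the legal actions, and falls back to a one-hot on the best legal action when no advantage is positive. A small worked example with made-up advantage values:

v = Float32[0.5, -1.0, 2.0, 0.0]   # predicted advantages for 4 actions
m = Bool[true, false, true, true]  # action 2 is illegal
σ = masked_regret_matching(v, m)
# positive masked advantages are [0.5, 0.0, 2.0, 0.0], their sum is 2.5,
# so σ ≈ [0.2, 0.0, 0.8, 0.0]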
