export DeepCFR

using Statistics: mean
using StatsBase

"""
    DeepCFR(;kwargs...)

Symbols used here follow the paper: [Deep Counterfactual Regret Minimization](https://arxiv.org/abs/1811.00164)

# Keyword arguments

- `K`, the number of traversals per iteration.
- `t`, the current iteration counter (starts at 1 and is advanced by `update!(π, env)`).
- `Π`, the policy network.
- `V`, a dictionary of each player's advantage network.
- `MΠ`, a strategy memory.
- `MV`, a dictionary of each player's advantage memory.
- `reinitialize_freq=1`, the frequency (in iterations) of reinitializing the advantage networks.
- `batch_size_V=32`, the batch size used when training the advantage networks.
- `batch_size_Π=32`, the batch size used when training the policy network.
- `n_training_steps_V=1`, the number of gradient steps per advantage-network update.
- `n_training_steps_Π=1`, the number of gradient steps when training the policy network.
- `rng=Random.GLOBAL_RNG`.
- `initializer=glorot_normal(rng)`, used to (re)initialize network parameters.
- `max_grad_norm=10.0f0`, the threshold used for gradient clipping.
"""
Base.@kwdef mutable struct DeepCFR{TP, TV, TMP, TMV, I, R, P} <: AbstractCFRPolicy
    Π::TP
    V::TV
    MΠ::TMP
    MV::TMV
    K::Int = 20
    t::Int = 1
    reinitialize_freq::Int = 1
    batch_size_V::Int = 32
    batch_size_Π::Int = 32
    n_training_steps_V::Int = 1
    n_training_steps_Π::Int = 1
    rng::R = Random.GLOBAL_RNG
    initializer::I = glorot_normal(rng)
    max_grad_norm::Float32 = 10.0f0
    # for logging
    Π_losses::Vector{Float32} = zeros(Float32, n_training_steps_Π)
    V_losses::Dict{P, Vector{Float32}} = Dict(k => zeros(Float32, n_training_steps_V) for (k, _) in MV)
    Π_norms::Vector{Float32} = zeros(Float32, n_training_steps_Π)
    V_norms::Dict{P, Vector{Float32}} = Dict(k => zeros(Float32, n_training_steps_V) for (k, _) in MV)
end
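
# A construction sketch (hypothetical helper names; not part of this file). `Π` and each
# `V[p]` are expected to be trainable approximators (anything that `Flux.params`, `device`
# and `update!(model, gs)` understand), while `MΠ` and each `MV[p]` must support `length`,
# column access such as `MΠ[:I]`, and keyword `push!`, as used by the functions below:
#
#     Π   = build_policy_approximator()                        # hypothetical helper
#     V   = Dict(p => build_advantage_approximator() for p in (1, 2))
#     MΠ  = build_strategy_memory()                            # hypothetical helper
#     MV  = Dict(p => build_advantage_memory() for p in (1, 2))
#     cfr = DeepCFR(Π = Π, V = V, MΠ = MΠ, MV = MV, K = 20)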

function RLBase.get_prob(π::DeepCFR, env::AbstractEnv)
    I = send_to_device(device(π.Π), get_state(env))
    m = send_to_device(device(π.Π), ifelse.(get_legal_actions_mask(env), 0.0f0, -Inf32))
    logits = π.Π(Flux.unsqueeze(I, ndims(I) + 1)) |> vec
    σ = softmax(logits .+ m)
    send_to_host(σ)
end
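
# Illegal actions get a logit offset of `-Inf32` before the softmax, so they receive
# probability exactly zero, e.g. `softmax([1f0, 2f0] .+ [0f0, -Inf32]) == [1f0, 0f0]`.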

(π::DeepCFR)(env::AbstractEnv) = sample(π.rng, get_actions(env), Weights(get_prob(π, env), 1.0))

"Run one iteration."
function RLBase.update!(π::DeepCFR, env::AbstractEnv)
    for p in get_players(env)
        if p != get_chance_player(env)
            for k in 1:π.K
                external_sampling!(π, copy(env), p)
            end
            update_advantage_networks(π, p)
        end
    end
    π.t += 1
end
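
# Typical usage (a sketch, not part of the original file): call the method above once per
# CFR iteration, then fit the policy network from the accumulated strategy memory with the
# zero-argument `update!` below:
#
#     for _ in 1:n_iterations          # `n_iterations` is up to the caller
#         update!(cfr, env)            # K traversals per player + advantage-network updates
#     end
#     update!(cfr)                     # train Π on MΠ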

"Update Π (the policy network)."
function RLBase.update!(π::DeepCFR)
    Π = π.Π
    Π_losses = π.Π_losses
    Π_norms = π.Π_norms
    D = device(Π)
    MΠ = π.MΠ
    ps = Flux.params(Π)

    # Π is retrained from scratch on samples from the strategy memory, so reset its parameters first.
    for x in ps
        x .= π.initializer(size(x)...)
    end

    for i in 1:π.n_training_steps_Π
        batch_inds = rand(π.rng, 1:length(MΠ), π.batch_size_Π)
        I = send_to_device(D, Flux.batch([MΠ[:I][i] for i in batch_inds]))
        σ = send_to_device(D, Flux.batch([MΠ[:σ][i] for i in batch_inds]))
        t = send_to_device(D, Flux.batch([MΠ[:t][i] / π.t for i in batch_inds]))
        m = send_to_device(D, Flux.batch([ifelse.(MΠ[:m][i], 0.0f0, -Inf32) for i in batch_inds]))
        gs = gradient(ps) do
            logits = Π(I) .+ m
            loss = mean(reshape(t, 1, :) .* ((σ .- softmax(logits)) .^ 2))
            ignore() do
                Π_losses[i] = loss
            end
            loss
        end
        Π_norms[i] = clip_by_global_norm!(gs, ps, π.max_grad_norm)
        update!(Π, gs)
    end
end
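
# The loss above is the iteration-weighted squared error between the sampled strategies
# stored in MΠ and the network's masked softmax output: a sample recorded at iteration t
# is weighted by t/T (T being the current iteration), so later strategies count more.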

"Update the advantage network of player `p`."
function update_advantage_networks(π, p)
    V = π.V[p]
    V_losses = π.V_losses[p]
    V_norms = π.V_norms[p]
    MV = π.MV[p]
    if π.t % π.reinitialize_freq == 0
        for x in Flux.params(V)
            # TODO: inplace
            x .= π.initializer(size(x)...)
        end
    end
    if length(MV) >= π.batch_size_V
        for i in 1:π.n_training_steps_V
            batch_inds = rand(π.rng, 1:length(MV), π.batch_size_V)
            I = send_to_device(device(V), Flux.batch([MV[:I][i] for i in batch_inds]))
            r̃ = send_to_device(device(V), Flux.batch([MV[:r̃][i] for i in batch_inds]))
            t = send_to_device(device(V), Flux.batch([MV[:t][i] / π.t for i in batch_inds]))
            m = send_to_device(device(V), Flux.batch([MV[:m][i] for i in batch_inds]))
            ps = Flux.params(V)
            gs = gradient(ps) do
                loss = mean(reshape(t, 1, :) .* ((r̃ .- V(I) .* m) .^ 2))
                ignore() do
                    V_losses[i] = loss
                end
                loss
            end
            V_norms[i] = clip_by_global_norm!(gs, ps, π.max_grad_norm)
            update!(V, gs)
        end
    end
end
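
# As with the policy network, samples are weighted by t/T. The targets are the sampled
# advantages r̃ stored during traversal, and the predictions are masked to legal actions,
# so the network is only penalized on actions that were actually available. Training only
# starts once the memory holds at least `batch_size_V` samples.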

"CFR traversal with external sampling."
function external_sampling!(π::DeepCFR, env::AbstractEnv, p)
    if get_terminal(env)
        get_reward(env, p)
    elseif get_current_player(env) == get_chance_player(env)
        env(rand(π.rng, get_actions(env)))
        external_sampling!(π, env, p)
    elseif get_current_player(env) == p
        V = π.V[p]
        s = get_state(env)
        I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s) + 1))
        A = get_actions(env)
        m = get_legal_actions_mask(env)
        σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
        v = zeros(length(σ))
        v̄ = 0.0
        for i in 1:length(m)
            if m[i]
                v[i] = external_sampling!(π, child(env, A[i]), p)
                v̄ += σ[i] * v[i]
            end
        end
        push!(π.MV[p], I = s, t = π.t, r̃ = (v .- v̄) .* m, m = m)
        v̄
    else
        V = π.V[get_current_player(env)]
        s = get_state(env)
        I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s) + 1))
        A = get_actions(env)
        m = get_legal_actions_mask(env)
        σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
        push!(π.MΠ, I = s, t = π.t, σ = σ, m = m)
        a = sample(π.rng, A, Weights(σ, 1.0))
        env(a)
        external_sampling!(π, env, p)
    end
end
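
# Bookkeeping done by the traversal above (a sketch of the stored targets): at the
# traverser's nodes every legal action is explored and the sampled advantages
# r̃ = (v .- v̄) .* m are pushed to MV[p]; at opponent nodes a single action is sampled
# from σ and the strategy itself is pushed to MΠ. For example, with
#
#     v = [1.0, 3.0, 0.0]; σ = [0.5, 0.5, 0.0]; m = Bool[1, 1, 0]
#
# the node value is v̄ = 2.0 and the stored targets are r̃ = [-1.0, 1.0, 0.0].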

"This is the specific regret matching method used in DeepCFR."
function masked_regret_matching(v, m)
    v⁺ = max.(v .* m, 0.0f0)
    s = sum(v⁺)
    if s > 0
        v⁺ ./= s
    else
        fill!(v⁺, 0.0f0)
        # relies on a masked `findmax(values, mask)` method, expected to be provided elsewhere in the package
        v⁺[findmax(v, m)[2]] = 1.0
    end
    v⁺
end
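
# Example: with `v = [2.0, 1.0, 5.0]` and `m = Bool[1, 1, 0]`, the positive masked regrets
# are `[2.0, 1.0, 0.0]`, so the returned strategy is `[2/3, 1/3, 0.0]`. If no legal action
# has positive regret, all probability mass goes to the legal action with the largest regret.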