@@ -6,10 +6,10 @@ struct InfoStateNode
     cumulative_strategy::Vector{Float64}
 end
 
-InfoStateNode(n) = InfoStateNode(fill(1/n, n), zeros(n), zeros(n))
+InfoStateNode(n) = InfoStateNode(fill(1 / n, n), zeros(n), zeros(n))
 
 function init_info_state_nodes(env::AbstractEnv)
-    nodes = Dict{String, InfoStateNode}()
+    nodes = Dict{String,InfoStateNode}()
     walk(env) do x
         if !get_terminal(x) && get_current_player(x) != get_chance_player(x)
             get!(nodes, get_state(x), InfoStateNode(length(get_legal_actions(x))))
@@ -24,8 +24,8 @@
 See more details: [An Introduction to Counterfactual Regret Minimization](http://modelai.gettysburg.edu/2013/cfr/cfr.pdf)
 """
 struct TabularCFRPolicy{S,T,R<:AbstractRNG} <: AbstractPolicy
-    nodes::Dict{S, InfoStateNode}
-    behavior_policy::QBasedPolicy{TabularLearner{S,T}, WeightedExplorer{true,R}}
+    nodes::Dict{S,InfoStateNode}
+    behavior_policy::QBasedPolicy{TabularLearner{S,T},WeightedExplorer{true,R}}
 end
 
 (p::TabularCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)
@@ -35,7 +35,13 @@ RLBase.get_prob(p::TabularCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_pol
 """
     TabularCFRPolicy(;n_iter::Int, env::AbstractEnv)
 """
-function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG, is_reset_neg_regrets=false, is_linear_averaging=false)
+function TabularCFRPolicy(;
+    n_iter::Int,
+    env::AbstractEnv,
+    rng = Random.GLOBAL_RNG,
+    is_reset_neg_regrets = false,
+    is_linear_averaging = false,
+)
     @assert NumAgentStyle(env) isa MultiAgent
     @assert DynamicStyle(env) === SEQUENTIAL
     @assert RewardStyle(env) === TERMINAL_REWARD
@@ -47,7 +53,8 @@ function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG,
     for i in 1:n_iter
         for p in get_players(env)
             if p != get_chance_player(env)
-                init_reach_prob = Dict(x=>1.0 for x in get_players(env) if x != get_chance_player(env))
+                init_reach_prob =
+                    Dict(x => 1.0 for x in get_players(env) if x != get_chance_player(env))
                 cfr!(nodes, env, p, init_reach_prob, 1.0, is_linear_averaging ? i : 1)
                 update_strategy!(nodes)
 
@@ -60,9 +67,12 @@ function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG,
         end
     end
 
-    behavior_policy = QBasedPolicy(;learner=TabularLearner{String}(), explorer=WeightedExplorer(;is_normalized=true, rng=rng))
+    behavior_policy = QBasedPolicy(;
+        learner = TabularLearner{String}(),
+        explorer = WeightedExplorer(; is_normalized = true, rng = rng),
+    )
 
-    for (k,v) in nodes
+    for (k, v) in nodes
         s = sum(v.cumulative_strategy)
         if s != 0
             update!(behavior_policy, k => v.cumulative_strategy ./ s)
@@ -77,23 +87,39 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
         get_reward(env, player)
     else
         if get_current_player(env) == get_chance_player(env)
-            v = 0.
+            v = 0.0
             for a::ActionProbPair in get_legal_actions(env)
-                v += a.prob * cfr!(nodes, child(env, a), player, reach_probs, chance_player_reach_prob * a.prob, ratio)
+                v +=
+                    a.prob * cfr!(
+                        nodes,
+                        child(env, a),
+                        player,
+                        reach_probs,
+                        chance_player_reach_prob * a.prob,
+                        ratio,
+                    )
             end
             v
         else
-            v = 0.
+            v = 0.0
             node = nodes[get_state(env)]
             legal_actions = get_legal_actions(env)
-            U = player == get_current_player(env) ? Vector{Float64}(undef, length(legal_actions)) : nothing
+            U = player == get_current_player(env) ?
+                Vector{Float64}(undef, length(legal_actions)) : nothing
 
             for (i, action) in enumerate(legal_actions)
                 prob = node.strategy[i]
                 new_reach_probs = copy(reach_probs)
                 new_reach_probs[get_current_player(env)] *= prob
 
-                u = cfr!(nodes, child(env, action), player, new_reach_probs, chance_player_reach_prob, ratio)
+                u = cfr!(
+                    nodes,
+                    child(env, action),
+                    player,
+                    new_reach_probs,
+                    chance_player_reach_prob,
+                    ratio,
+                )
                 isnothing(U) || (U[i] = u)
                 v += prob * u
             end
@@ -102,8 +128,13 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
                 reach_prob = reach_probs[player]
                 counterfactual_reach_prob = reduce(
                     *,
-                    (reach_probs[p] for p in get_players(env) if p != player && p != get_chance_player(env));
-                    init=chance_player_reach_prob)
+                    (
+                        reach_probs[p]
+                        for
+                        p in get_players(env) if p != player && p != get_chance_player(env)
+                    );
+                    init = chance_player_reach_prob,
+                )
                 node.cumulative_regret .+= counterfactual_reach_prob .* (U .- v)
                 node.cumulative_strategy .+= ratio .* reach_prob .* node.strategy
             end
@@ -113,16 +144,16 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
 end
 
 function regret_matching!(strategy, cumulative_regret)
-    s = mapreduce(x->max(0,x), +,cumulative_regret)
+    s = mapreduce(x -> max(0, x), +, cumulative_regret)
     if s > 0
-        strategy .= max.(0., cumulative_regret) ./ s
+        strategy .= max.(0.0, cumulative_regret) ./ s
     else
-        fill!(strategy, 1/length(strategy))
+        fill!(strategy, 1 / length(strategy))
     end
 end
 
 function update_strategy!(nodes)
     for node in values(nodes)
         regret_matching!(node.strategy, node.cumulative_regret)
     end
-end
+end
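
Note (not part of the diff): a minimal usage sketch of the keyword constructor shown above. Only the TabularCFRPolicy keywords and the callable-policy syntax come from the code in this diff; the package import and KuhnPokerEnv are assumptions about the surrounding ReinforcementLearning.jl setup and may need adjusting.

    # Hypothetical usage sketch; adjust names to your own setup.
    using ReinforcementLearning      # assumed to re-export TabularCFRPolicy and KuhnPokerEnv

    env = KuhnPokerEnv()             # a sequential, terminal-reward, multi-agent env (assumption)
    policy = TabularCFRPolicy(; n_iter = 300, env = env)   # runs n_iter CFR iterations during construction

    policy(env)                      # forwards to the internal behavior_policy, as defined in the diff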