include("ppo_trajectory.jl")

using Random
+ using Distributions: Categorical, Normal, logpdf
+ using StructArrays

- export PPOLearner
+ export PPOPolicy

"""
-     PPOLearner(;kwargs)
+     PPOPolicy(;kwargs)

# Keyword arguments

@@ -19,9 +21,13 @@ export PPOLearner
- `actor_loss_weight = 1.0f0`,
- `critic_loss_weight = 0.5f0`,
- `entropy_loss_weight = 0.01f0`,
+ - `dist = Categorical`,
- `rng = Random.GLOBAL_RNG`,
+
+ By default, `dist` is set to `Categorical`, which means it only works on environments with discrete actions.
+ To work with environments of continuous actions, set `dist = Normal` and use a `GaussianNetwork` as the actor of the `approximator`.
"""
- mutable struct PPOLearner{A<:ActorCritic,R} <: AbstractLearner
+ mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
    approximator::A
    γ::Float32
    λ::Float32
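For orientation, here is a minimal usage sketch of the `dist` keyword documented above. It is not part of the commit: `ac_discrete` and `ac_gaussian` are hypothetical `ActorCritic` approximators (the Gaussian one wrapping a `GaussianNetwork` actor in a `NeuralNetworkApproximator`), and only keyword arguments visible in this diff are used.

```julia
# Hypothetical sketch, not from the commit. `ac_discrete` / `ac_gaussian` are
# placeholder ActorCritic approximators assumed to be constructed elsewhere.
using Random

# Discrete action spaces (the default): get_prob yields Categorical distributions.
policy_discrete = PPOPolicy(
    approximator = ac_discrete,
    dist = Categorical,
    γ = 0.99f0,
    λ = 0.95f0,
    rng = MersenneTwister(123),
)

# Continuous action spaces: the actor is a GaussianNetwork and get_prob yields
# Normal distributions built from its (μ, σ) output.
policy_continuous = PPOPolicy(
    approximator = ac_gaussian,
    dist = Normal,
    γ = 0.99f0,
    λ = 0.95f0,
    rng = MersenneTwister(123),
)
```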
@@ -41,7 +47,7 @@ mutable struct PPOLearner{A<:ActorCritic,R} <: AbstractLearner
    loss::Matrix{Float32}
end

- function PPOLearner(;
+ function PPOPolicy(;
    approximator,
    γ = 0.99f0,
    λ = 0.95f0,
@@ -52,9 +58,10 @@ function PPOLearner(;
    actor_loss_weight = 1.0f0,
    critic_loss_weight = 0.5f0,
    entropy_loss_weight = 0.01f0,
+     dist = Categorical,
    rng = Random.GLOBAL_RNG,
)
-     PPOLearner(
+     PPOPolicy{typeof(approximator),dist,typeof(rng)}(
        approximator,
        γ,
        λ,
@@ -74,21 +81,33 @@ function PPOLearner(;
    )
end

- function (learner::PPOLearner)(env::MultiThreadEnv)
-     learner.approximator.actor(send_to_device(
-         device(learner.approximator),
-         get_state(env),
-     )) |> send_to_host
+ function RLBase.get_prob(p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}},Normal}, state::AbstractArray)
+     p.approximator.actor(send_to_device(
+         device(p.approximator),
+         state,
+     )) |> send_to_host |> StructArray{Normal}
+ end
+
+ function RLBase.get_prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray)
+     logits = p.approximator.actor(send_to_device(
+         device(p.approximator),
+         state,
+     )) |> softmax |> send_to_host
+     [Categorical(x; check_args = false) for x in eachcol(logits)]
end

- function (learner::PPOLearner)(env)
+ RLBase.get_prob(p::PPOPolicy, env::MultiThreadEnv) = get_prob(p, get_state(env))
+
+ function RLBase.get_prob(p::PPOPolicy, env::AbstractEnv)
    s = get_state(env)
    s = Flux.unsqueeze(s, ndims(s) + 1)
-     s = send_to_device(device(learner.approximator), s)
-     learner.approximator.actor(s) |> vec |> send_to_host
+     get_prob(p, s)[1]
end

- function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
+ (p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, get_prob(p, env))
+ (p::PPOPolicy)(env::AbstractEnv) = rand(p.rng, get_prob(p, env))
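A small illustrative check (not part of the commit) of what the two `get_prob` methods above produce and how the call operators consume it: the Gaussian branch turns a `(μ, σ)` tuple into element-wise `Normal`s via `StructArray`, while the discrete branch builds one `Categorical` per column, so `rand` and `logpdf` operate per environment.

```julia
using Distributions: Categorical, Normal, logpdf
using StructArrays

# Continuous branch: a (μ, σ) tuple of arrays becomes element-wise Normals
# (here a 2-element vector), mirroring `|> StructArray{Normal}` above.
μ = [0.1f0, -0.3f0]
σ = [0.5f0, 0.2f0]
normals = StructArray{Normal}((μ, σ))

# Discrete branch: one Categorical per column of the action-probability matrix,
# mirroring the comprehension over `eachcol(logits)` above.
probs = [0.2 0.7; 0.8 0.3]                 # n_actions × n_envs
cats = [Categorical(x; check_args = false) for x in eachcol(probs)]

actions = rand.(cats)                      # one sampled action per environment
logps = logpdf.(cats, actions)             # matching log-probabilities
```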
+
+ function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
    isfull(t) || return

    states = t[:state]
@@ -98,16 +117,16 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
    terminals = t[:terminal]
    states_plus = t[:full_state]

-     rng = learner.rng
-     AC = learner.approximator
-     γ = learner.γ
-     λ = learner.λ
-     n_epochs = learner.n_epochs
-     n_microbatches = learner.n_microbatches
-     clip_range = learner.clip_range
-     w₁ = learner.actor_loss_weight
-     w₂ = learner.critic_loss_weight
-     w₃ = learner.entropy_loss_weight
+     rng = p.rng
+     AC = p.approximator
+     γ = p.γ
+     λ = p.λ
+     n_epochs = p.n_epochs
+     n_microbatches = p.n_microbatches
+     clip_range = p.clip_range
+     w₁ = p.actor_loss_weight
+     w₂ = p.critic_loss_weight
+     w₃ = p.entropy_loss_weight
    D = device(AC)

    n_envs, n_rollout = size(terminals)
@@ -142,60 +161,63 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
            ps = Flux.params(AC)
            gs = gradient(ps) do
                v′ = AC.critic(s) |> vec
-                 logit′ = AC.actor(s)
-                 p′ = softmax(logit′)
-                 log_p′ = logsoftmax(logit′)
-                 log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
+                 if AC.actor isa NeuralNetworkApproximator{<:GaussianNetwork}
+                     μ, σ = AC.actor(s)
+                     log_p′ₐ = normlogpdf(μ, σ, a)
+                     entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
+                 else
+                     # actor is assumed to return discrete logits
+                     logit′ = AC.actor(s)
+                     p′ = softmax(logit′)
+                     log_p′ = logsoftmax(logit′)
+                     log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
+                     entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
+                 end

                ratio = exp.(log_p′ₐ .- log_p)
                surr1 = ratio .* adv
                surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv

                actor_loss = -mean(min.(surr1, surr2))
                critic_loss = mean((r .- v′) .^ 2)
-                 entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
                loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss

                ignore() do
-                     learner.actor_loss[i, epoch] = actor_loss
-                     learner.critic_loss[i, epoch] = critic_loss
-                     learner.entropy_loss[i, epoch] = entropy_loss
-                     learner.loss[i, epoch] = loss
+                     p.actor_loss[i, epoch] = actor_loss
+                     p.critic_loss[i, epoch] = critic_loss
+                     p.entropy_loss[i, epoch] = entropy_loss
+                     p.loss[i, epoch] = loss
                end

                loss
            end

-             learner.norm[i, epoch] = clip_by_global_norm!(gs, ps, learner.max_grad_norm)
+             p.norm[i, epoch] = clip_by_global_norm!(gs, ps, p.max_grad_norm)
            update!(AC, gs)
        end
    end
end

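For intuition about the clipped surrogate objective computed inside the `gradient` block above, here is a tiny standalone numeric sketch (made-up numbers, not part of the commit) showing how `clamp` caps ratios that leave the `1 ± clip_range` band:

```julia
using Statistics: mean

clip_range = 0.2f0
ratio = [0.5f0, 1.0f0, 1.8f0]   # exp.(log_p′ₐ .- log_p) for three samples
adv = [1.0f0, 1.0f0, 1.0f0]     # positive advantages

surr1 = ratio .* adv
surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv
actor_loss = -mean(min.(surr1, surr2))
# min.(surr1, surr2) ≈ [0.5, 1.0, 1.2]: the third ratio is capped at 1.2,
# so pushing that action's probability further yields no additional objective gain.
```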
- function (π::QBasedPolicy{<:PPOLearner})(env::MultiThreadEnv)
-     action_values = π.learner(env)
-     logits = logsoftmax(action_values)
-     actions = π.explorer(action_values)
-     actions_log_prob = logits[CartesianIndex.(actions, 1:size(action_values, 2))]
-     actions, actions_log_prob
- end
+ function (agent::Agent{<:Union{PPOPolicy,RandomStartPolicy{<:PPOPolicy}}})(::Training{PreActStage}, env::MultiThreadEnv)
+     state = get_state(env)
+     dist = get_prob(agent.policy, env)

- (π::QBasedPolicy{<:PPOLearner})(env) = env |> π.learner |> π.explorer
+     # currently RandomPolicy returns a Matrix instead of a (vector of) distribution.
+     if dist isa Matrix{<:Number}
+         dist = [Categorical(x; check_args = false) for x in eachcol(dist)]
+     elseif dist isa Vector{<:Vector{<:Number}}
+         dist = [Categorical(x; check_args = false) for x in dist]
+     end

- function (p::RandomStartPolicy{<:QBasedPolicy{<:PPOLearner}})(env::MultiThreadEnv)
-     p.num_rand_start -= 1
-     if p.num_rand_start < 0
-         p.policy(env)
-     else
-         a = p.random_policy(env)
-         log_p = log.(get_prob(p.random_policy, env, a))
-         a, log_p
+     # !!! a little ugly
+     rng = if agent.policy isa PPOPolicy
+         agent.policy.rng
+     elseif agent.policy isa RandomStartPolicy
+         agent.policy.policy.rng
    end
- end

- function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PreActStage}, env)
-     action, action_log_prob = agent.policy(env)
-     state = get_state(env)
+     action = [rand(rng, d) for d in dist]
+     action_log_prob = [logpdf(d, a) for (d, a) in zip(dist, action)]

    push!(
        agent.trajectory;
        state = state,
@@ -217,12 +239,3 @@ function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PreActStage}, env)

    action
end
-
- function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PostActStage}, env)
-     push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
-     nothing
- end
-
- function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Testing{PreActStage}, env)
-     agent.policy(env)[1] # ignore the log_prob of action
- end