This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 2674ba1

Rename PPOLearner to PPOPolicy and make it support continuous space (#93)
* initial changes
* added experiment of PPO with Pendulum
* fix test errors
* update README
* bump version and update dependency
* fix conflict
* avoid calculating distribution twice!
* update dependency
1 parent 73d11f6 commit 2674ba1

7 files changed: +260 −126 lines

Project.toml

Lines changed: 4 additions & 3 deletions

@@ -1,7 +1,7 @@
 name = "ReinforcementLearningZoo"
 uuid = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
 authors = ["Jun Tian <tianjun.cpp@gmail.com>"]
-version = "0.1.7"
+version = "0.2.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -21,6 +21,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
 TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
@@ -31,8 +32,8 @@ CUDA = "1"
 Distributions = "0.23"
 Flux = "0.11"
 MacroTools = "0.5"
-ReinforcementLearningBase = "0.8"
-ReinforcementLearningCore = "0.4.2"
+ReinforcementLearningBase = "0.8.4"
+ReinforcementLearningCore = "0.4.5"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StatsBase = "0.32, 0.33"
README.md

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ Some built-in experiments are exported to help new users to easily run benchmarks
 - ``E`JuliaRL_DDPG_Pendulum` ``
 - ``E`JuliaRL_TD3_Pendulum` `` (Thanks to [@rbange](https://github.com/rbange))
 - ``E`JuliaRL_SAC_Pendulum` `` (Thanks to [@rbange](https://github.com/rbange))
+- ``E`JuliaRL_PPO_Pendulum` ``
 - ``E`JuliaRL_BasicDQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`JuliaRL_DQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`JuliaRL_Minimax_OpenSpiel(tic_tac_toe)` ``
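
The newly listed ``E`JuliaRL_PPO_Pendulum` `` entry follows the same pattern as the other built-in experiments. A minimal usage sketch, assuming the usual `run(E`...`)` entry point that ReinforcementLearningZoo exports for its built-in experiments:

    using ReinforcementLearningZoo

    # builds and runs the new PPO + Pendulum experiment added by this commit
    run(E`JuliaRL_PPO_Pendulum`)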
Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
+include("vpg.jl")
 include("A2C.jl")
 include("ppo.jl")
 include("A2CGAE.jl")
 include("ddpg.jl")
 include("td3.jl")
 include("sac.jl")
-include("vpg.jl")

src/algorithms/policy_gradient/ppo.jl

Lines changed: 76 additions & 63 deletions

@@ -1,11 +1,13 @@
 include("ppo_trajectory.jl")
 
 using Random
+using Distributions: Categorical, Normal, logpdf
+using StructArrays
 
-export PPOLearner
+export PPOPolicy
 
 """
-    PPOLearner(;kwargs)
+    PPOPolicy(;kwargs)
 
 # Keyword arguments
 
@@ -19,9 +21,13 @@ export PPOLearner
 - `actor_loss_weight = 1.0f0`,
 - `critic_loss_weight = 0.5f0`,
 - `entropy_loss_weight = 0.01f0`,
+- `dist = Categorical`,
 - `rng = Random.GLOBAL_RNG`,
+
+By default, `dist` is set to `Categorical`, which means it only works on environments
+with discrete actions. To work with continuous actions, set `dist` to `Normal` and use a `GaussianNetwork` as the actor.
 """
-mutable struct PPOLearner{A<:ActorCritic,R} <: AbstractLearner
+mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
     approximator::A
     γ::Float32
     λ::Float32
@@ -41,7 +47,7 @@ mutable struct PPOLearner{A<:ActorCritic,R} <: AbstractLearner
     loss::Matrix{Float32}
 end
 
-function PPOLearner(;
+function PPOPolicy(;
     approximator,
     γ = 0.99f0,
     λ = 0.95f0,
@@ -52,9 +58,10 @@ function PPOLearner(;
     actor_loss_weight = 1.0f0,
     critic_loss_weight = 0.5f0,
     entropy_loss_weight = 0.01f0,
+    dist = Categorical,
     rng = Random.GLOBAL_RNG,
 )
-    PPOLearner(
+    PPOPolicy{typeof(approximator),dist,typeof(rng)}(
        approximator,
        γ,
        λ,
@@ -74,21 +81,33 @@ function PPOLearner(;
     )
 end
 
-function (learner::PPOLearner)(env::MultiThreadEnv)
-    learner.approximator.actor(send_to_device(
-        device(learner.approximator),
-        get_state(env),
-    )) |> send_to_host
+function RLBase.get_prob(p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}},Normal}, state::AbstractArray)
+    p.approximator.actor(send_to_device(
+        device(p.approximator),
+        state,
+    )) |> send_to_host |> StructArray{Normal}
+end
+
+function RLBase.get_prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray)
+    logits = p.approximator.actor(send_to_device(
+        device(p.approximator),
+        state,
+    )) |> softmax |> send_to_host
+    [Categorical(x; check_args = false) for x in eachcol(logits)]
 end
 
-function (learner::PPOLearner)(env)
+RLBase.get_prob(p::PPOPolicy, env::MultiThreadEnv) = get_prob(p, get_state(env))
+
+function RLBase.get_prob(p::PPOPolicy, env::AbstractEnv)
     s = get_state(env)
     s = Flux.unsqueeze(s, ndims(s) + 1)
-    s = send_to_device(device(learner.approximator), s)
-    learner.approximator.actor(s) |> vec |> send_to_host
+    get_prob(p, s)[1]
 end
 
-function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
+(p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, get_prob(p, env))
+(p::PPOPolicy)(env::AbstractEnv) = rand(p.rng, get_prob(p, env))
+
+function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
     isfull(t) || return
 
     states = t[:state]
@@ -98,16 +117,16 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
     terminals = t[:terminal]
     states_plus = t[:full_state]
 
-    rng = learner.rng
-    AC = learner.approximator
-    γ = learner.γ
-    λ = learner.λ
-    n_epochs = learner.n_epochs
-    n_microbatches = learner.n_microbatches
-    clip_range = learner.clip_range
-    w₁ = learner.actor_loss_weight
-    w₂ = learner.critic_loss_weight
-    w₃ = learner.entropy_loss_weight
+    rng = p.rng
+    AC = p.approximator
+    γ = p.γ
+    λ = p.λ
+    n_epochs = p.n_epochs
+    n_microbatches = p.n_microbatches
+    clip_range = p.clip_range
+    w₁ = p.actor_loss_weight
+    w₂ = p.critic_loss_weight
+    w₃ = p.entropy_loss_weight
     D = device(AC)
 
     n_envs, n_rollout = size(terminals)
@@ -142,60 +161,63 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
            ps = Flux.params(AC)
            gs = gradient(ps) do
                v′ = AC.critic(s) |> vec
-                logit′ = AC.actor(s)
-                p′ = softmax(logit′)
-                log_p′ = logsoftmax(logit′)
-                log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
+                if AC.actor isa NeuralNetworkApproximator{<:GaussianNetwork}
+                    μ, σ = AC.actor(s)
+                    log_p′ₐ = normlogpdf(μ, σ, a)
+                    entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
+                else
+                    # actor is assumed to return discrete logits
+                    logit′ = AC.actor(s)
+                    p′ = softmax(logit′)
+                    log_p′ = logsoftmax(logit′)
+                    log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
+                    entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
+                end
 
                ratio = exp.(log_p′ₐ .- log_p)
                surr1 = ratio .* adv
                surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv
 
                actor_loss = -mean(min.(surr1, surr2))
                critic_loss = mean((r .- v′) .^ 2)
-                entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
                loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
 
                ignore() do
-                    learner.actor_loss[i, epoch] = actor_loss
-                    learner.critic_loss[i, epoch] = critic_loss
-                    learner.entropy_loss[i, epoch] = entropy_loss
-                    learner.loss[i, epoch] = loss
+                    p.actor_loss[i, epoch] = actor_loss
+                    p.critic_loss[i, epoch] = critic_loss
+                    p.entropy_loss[i, epoch] = entropy_loss
+                    p.loss[i, epoch] = loss
                end
 
                loss
            end
 
-            learner.norm[i, epoch] = clip_by_global_norm!(gs, ps, learner.max_grad_norm)
+            p.norm[i, epoch] = clip_by_global_norm!(gs, ps, p.max_grad_norm)
            update!(AC, gs)
        end
    end
 end
 
-function (π::QBasedPolicy{<:PPOLearner})(env::MultiThreadEnv)
-    action_values = π.learner(env)
-    logits = logsoftmax(action_values)
-    actions = π.explorer(action_values)
-    actions_log_prob = logits[CartesianIndex.(actions, 1:size(action_values, 2))]
-    actions, actions_log_prob
-end
+function (agent::Agent{<:Union{PPOPolicy,RandomStartPolicy{<:PPOPolicy}}})(::Training{PreActStage}, env::MultiThreadEnv)
+    state = get_state(env)
+    dist = get_prob(agent.policy, env)
 
-(π::QBasedPolicy{<:PPOLearner})(env) = env |> π.learner |> π.explorer
+    # currently RandomPolicy returns a Matrix instead of a (vector of) distribution.
+    if dist isa Matrix{<:Number}
+        dist = [Categorical(x; check_args = false) for x in eachcol(dist)]
+    elseif dist isa Vector{<:Vector{<:Number}}
+        dist = [Categorical(x; check_args = false) for x in dist]
+    end
 
-function (p::RandomStartPolicy{<:QBasedPolicy{<:PPOLearner}})(env::MultiThreadEnv)
-    p.num_rand_start -= 1
-    if p.num_rand_start < 0
-        p.policy(env)
-    else
-        a = p.random_policy(env)
-        log_p = log.(get_prob(p.random_policy, env, a))
-        a, log_p
+    # !!! a little ugly
+    rng = if agent.policy isa PPOPolicy
+        agent.policy.rng
+    elseif agent.policy isa RandomStartPolicy
+        agent.policy.policy.rng
    end
-end
 
-function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PreActStage}, env)
-    action, action_log_prob = agent.policy(env)
-    state = get_state(env)
+    action = [rand(rng, d) for d in dist]
+    action_log_prob = [logpdf(d, a) for (d, a) in zip(dist, action)]
    push!(
        agent.trajectory;
        state = state,
@@ -217,12 +239,3 @@ function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PreActStage}, env)
 
    action
 end
-
-function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PostActStage}, env)
-    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
-    nothing
-end
-
-function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Testing{PreActStage}, env)
-    agent.policy(env)[1] # ignore the log_prob of action
-end

src/algorithms/policy_gradient/vpg.jl

Lines changed: 11 additions & 4 deletions

@@ -6,12 +6,19 @@ using ReinforcementLearningCore
 
 export VPGPolicy, GaussianNetwork
 
-struct GaussianNetwork
-    pre::Chain
-    μ::Chain
-    σ::Chain
+"""
+    GaussianNetwork(;pre=identity, μ, σ)
+
+`σ` should return the log of std, `exp` will be applied to it automatically.
+"""
+Base.@kwdef struct GaussianNetwork{P,U,S}
+    pre::P = identity
+    μ::U
+    σ::S
 end
+
 Flux.@functor GaussianNetwork
+
 function (m::GaussianNetwork)(S)
     x = m.pre(S)
     m.μ(x), m.σ(x) .|> exp
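
As a quick usage sketch of the reworked `GaussianNetwork` (the layer sizes below are hypothetical and not part of this commit): calling the network returns the tuple `(μ, σ)`, where `σ` has already been exponentiated from the log-std produced by the `σ` head.

    using ReinforcementLearningZoo, Flux

    net = GaussianNetwork(
        pre = Chain(Dense(3, 16, relu)),  # assumed 3-dimensional state
        μ = Chain(Dense(16, 1)),          # mean head
        σ = Chain(Dense(16, 1)),          # log-std head; exp is applied in the forward pass
    )

    s = rand(Float32, 3, 5)   # a batch of 5 states
    μ, σ = net(s)             # both 1×5; σ is strictly positive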

src/experiments/atari.jl

Lines changed: 22 additions & 25 deletions

@@ -753,25 +753,22 @@ function RLCore.Experiment(
         policy = RandomStartPolicy(
             num_rand_start = 1000,
             random_policy = RandomPolicy(get_actions(env); rng = rng),
-            policy = QBasedPolicy(
-                learner = PPOLearner(
-                    approximator = ActorCritic(
-                        actor = Chain(model, Dense(512, N_ACTIONS; initW = init)),
-                        critic = Chain(model, Dense(512, 1; initW = init)),
-                        optimizer = ADAM(INIT_LEARNING_RATE), # decrease learning rate with a hook
-                    ) |> gpu,
-                    γ = 0.99f0,
-                    λ = 0.98f0,
-                    clip_range = INIT_CLIP_RANGE, # decrease with a hook
-                    max_grad_norm = 1.0f0,
-                    n_microbatches = 4,
-                    n_epochs = 4,
-                    actor_loss_weight = 1.0f0,
-                    critic_loss_weight = 0.5f0,
-                    entropy_loss_weight = 0.01f0,
-                    rng = rng,
-                ),
-                explorer = BatchExplorer(GumbelSoftmaxExplorer(; rng = rng)),
+            policy = PPOPolicy(
+                approximator = ActorCritic(
+                    actor = Chain(model, Dense(512, N_ACTIONS; initW = init)),
+                    critic = Chain(model, Dense(512, 1; initW = init)),
+                    optimizer = ADAM(INIT_LEARNING_RATE), # decrease learning rate with a hook
+                ) |> gpu,
+                γ = 0.99f0,
+                λ = 0.98f0,
+                clip_range = INIT_CLIP_RANGE, # decrease with a hook
+                max_grad_norm = 1.0f0,
+                n_microbatches = 4,
+                n_epochs = 4,
+                actor_loss_weight = 1.0f0,
+                critic_loss_weight = 0.5f0,
+                entropy_loss_weight = 0.01f0,
+                rng = rng,
             ),
         ),
         trajectory = PPOTrajectory(;
@@ -803,19 +800,19 @@ function RLCore.Experiment(
         total_batch_reward_per_episode,
         batch_steps_per_episode,
         DoEveryNStep(UPDATE_FREQ) do t, agent, env
-            learner = agent.policy.policy.learner
+            p = agent.policy.policy
             with_logger(lg) do
-                @info "training" loss = mean(learner.loss) actor_loss =
-                    mean(learner.actor_loss) critic_loss = mean(learner.critic_loss) entropy_loss =
-                    mean(learner.entropy_loss) norm = mean(learner.norm) log_step_increment =
+                @info "training" loss = mean(p.loss) actor_loss =
+                    mean(p.actor_loss) critic_loss = mean(p.critic_loss) entropy_loss =
+                    mean(p.entropy_loss) norm = mean(p.norm) log_step_increment =
                     UPDATE_FREQ
             end
         end,
         DoEveryNStep(UPDATE_FREQ) do t, agent, env
             decay = (N_TRAINING_STEPS - t) / N_TRAINING_STEPS
-            agent.policy.policy.learner.approximator.optimizer.eta =
+            agent.policy.policy.approximator.optimizer.eta =
                 INIT_LEARNING_RATE * decay
-            agent.policy.policy.learner.clip_range = INIT_CLIP_RANGE * Float32(decay)
+            agent.policy.policy.clip_range = INIT_CLIP_RANGE * Float32(decay)
         end,
         DoEveryNStep() do t, agent, env
             with_logger(lg) do
