This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit cf9bf19

Soft Actor Critic (#71)
* initial SAC implementation
* PR review fixes
1 parent 6409d3a commit cf9bf19

File tree (6 files changed, +268 -3 lines):

- Project.toml
- README.md
- src/algorithms/policy_gradient/policy_gradient.jl
- src/algorithms/policy_gradient/sac.jl
- src/experiments/rl_envs.jl
- test/runtests.jl

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"

README.md

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-<div align="center">
+<div align="center">
 <a href="https://en.wikipedia.org/wiki/Tangram"> <img src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Tangram-man.svg" width="200"> </a>
 <p> <a href="https://wiki.c2.com/?MakeItWorkMakeItRightMakeItFast">"Make It Work Make It Right Make It Fast"</a></p>
 <p>― <a href="https://wiki.c2.com/?KentBeck">KentBeck</a></p>
@@ -23,6 +23,7 @@ This project aims to provide some implementations of the most typical reinforcem
 - A2C
 - PPO
 - DDPG
+- SAC
 
 If you are looking for tabular reinforcement learning algorithms, you may refer to [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).
 
@@ -41,6 +42,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - ``E`JuliaRL_A2CGAE_CartPole` `` (Thanks to [@sriram13m](https://github.com/sriram13m))
 - ``E`JuliaRL_PPO_CartPole` ``
 - ``E`JuliaRL_DDPG_Pendulum` ``
+- ``E`JuliaRL_SAC_Pendulum` `` (Thanks to [@rbange](https://github.com/rbange))
 - ``E`JuliaRL_BasicDQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`JuliaRL_DQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`Dopamine_DQN_Atari(pong)` ``
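For reference (not part of this commit), a minimal usage sketch of the newly listed experiment, assuming the package is loaded as ReinforcementLearningZoo (the package name is not stated on this page) and following the same exported ``E`...` `` experiment syntax used by the other entries:

    using ReinforcementLearningZoo   # assumed package name

    res = run(E`JuliaRL_SAC_Pendulum`)   # same pattern as the existing Pendulum experiments
    res.hook.rewards                     # per-episode rewards collected by the experiment's hook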

src/algorithms/policy_gradient/policy_gradient.jl

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@ include("A2C.jl")
 include("ppo.jl")
 include("A2CGAE.jl")
 include("ddpg.jl")
+include("sac.jl")

src/algorithms/policy_gradient/sac.jl

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+export SACPolicy, SACPolicyNetwork
+
+using Random
+using Flux
+using Flux.Losses: mse
+using Distributions: Normal, logpdf
+
+# Define SAC Actor
+struct SACPolicyNetwork
+    pre::Chain
+    mean::Chain
+    log_std::Chain
+end
+Flux.@functor SACPolicyNetwork
+(m::SACPolicyNetwork)(state) = (x = m.pre(state); (m.mean(x), m.log_std(x)))
+
+mutable struct SACPolicy{
+    BA<:NeuralNetworkApproximator,
+    BC1<:NeuralNetworkApproximator,
+    BC2<:NeuralNetworkApproximator,
+    P,
+    R<:AbstractRNG,
+} <: AbstractPolicy
+
+    policy::BA
+    qnetwork1::BC1
+    qnetwork2::BC2
+    target_qnetwork1::BC1
+    target_qnetwork2::BC2
+    γ::Float32
+    ρ::Float32
+    α::Float32
+    batch_size::Int
+    start_steps::Int
+    start_policy::P
+    update_after::Int
+    update_every::Int
+    step::Int
+    rng::R
+end
+
+"""
+    SACPolicy(;kwargs...)
+
+# Keyword arguments
+
+- `policy`,
+- `qnetwork1`,
+- `qnetwork2`,
+- `target_qnetwork1`,
+- `target_qnetwork2`,
+- `start_policy`,
+- `γ = 0.99f0`,
+- `ρ = 0.995f0`,
+- `α = 0.2f0`,
+- `batch_size = 32`,
+- `start_steps = 10000`,
+- `update_after = 1000`,
+- `update_every = 50`,
+- `step = 0`,
+- `rng = Random.GLOBAL_RNG`,
+"""
+function SACPolicy(;
+    policy,
+    qnetwork1,
+    qnetwork2,
+    target_qnetwork1,
+    target_qnetwork2,
+    start_policy,
+    γ = 0.99f0,
+    ρ = 0.995f0,
+    α = 0.2f0,
+    batch_size = 32,
+    start_steps = 10000,
+    update_after = 1000,
+    update_every = 50,
+    step = 0,
+    rng = Random.GLOBAL_RNG,
+)
+    copyto!(qnetwork1, target_qnetwork1) # force sync
+    copyto!(qnetwork2, target_qnetwork2) # force sync
+    SACPolicy(
+        policy,
+        qnetwork1,
+        qnetwork2,
+        target_qnetwork1,
+        target_qnetwork2,
+        γ,
+        ρ,
+        α,
+        batch_size,
+        start_steps,
+        start_policy,
+        update_after,
+        update_every,
+        step,
+        rng,
+    )
+end
+
+# TODO: handle Training/Testing mode
+function (p::SACPolicy)(env)
+    p.step += 1
+
+    if p.step <= p.start_steps
+        p.start_policy(env)
+    else
+        D = device(p.policy)
+        s = get_state(env)
+        s = Flux.unsqueeze(s, ndims(s) + 1)
+        # trainmode:
+        action = evaluate(p, s)[1][] # returns action as scalar
+
+        # testmode:
+        # if testing, don't sample an action, but act deterministically by
+        # taking the "mean" action
+        # action = p.policy(s)[1][] # returns action as scalar
+    end
+end
+
+"""
+This function is compatible with a multidimensional action space.
+"""
+function evaluate(p::SACPolicy, state)
+    μ, log_σ = p.policy(state)
+    π_dist = Normal.(μ, exp.(log_σ))
+    z = rand.(p.rng, π_dist)
+    logp_π = sum(logpdf.(π_dist, z), dims = 1)
+    logp_π -= sum((2f0 .* (log(2f0) .- z - softplus.(-2f0 * z))), dims = 1)
+    return tanh.(z), logp_π
+end
+
+function RLBase.update!(p::SACPolicy, traj::CircularCompactSARTSATrajectory)
+    length(traj[:terminal]) > p.update_after || return
+    p.step % p.update_every == 0 || return
+
+    inds = rand(p.rng, 1:(length(traj[:terminal])-1), p.batch_size)
+    s = select_last_dim(traj[:state], inds)
+    a = select_last_dim(traj[:action], inds)
+    r = select_last_dim(traj[:reward], inds)
+    t = select_last_dim(traj[:terminal], inds)
+    s′ = select_last_dim(traj[:next_state], inds)
+
+    γ, ρ, α = p.γ, p.ρ, p.α
+
+    # !!! we have several assumptions here, need revisiting when we have more complex environments
+    # state is a vector
+    # action is a scalar
+    a′, log_π = evaluate(p, s′)
+    q′_input = vcat(s′, a′)
+    q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input))
+
+    y = r .+ γ .* (1 .- t) .* vec((q′ .- α .* log_π))
+
+    # Train Q Networks
+    a = Flux.unsqueeze(a, 1)
+    q_input = vcat(s, a)
+
+    q_grad_1 = gradient(Flux.params(p.qnetwork1)) do
+        q1 = p.qnetwork1(q_input) |> vec
+        mse(q1, y)
+    end
+    update!(p.qnetwork1, q_grad_1)
+    q_grad_2 = gradient(Flux.params(p.qnetwork2)) do
+        q2 = p.qnetwork2(q_input) |> vec
+        mse(q2, y)
+    end
+    update!(p.qnetwork2, q_grad_2)
+
+    # Train Policy
+    p_grad = gradient(Flux.params(p.policy)) do
+        a, log_π = evaluate(p, s)
+        q_input = vcat(s, a)
+        q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input))
+        mean(α .* log_π .- q)
+    end
+    update!(p.policy, p_grad)
+
+    # polyak averaging
+    for (dest, src) in zip(
+        Flux.params([p.target_qnetwork1, p.target_qnetwork2]),
+        Flux.params([p.qnetwork1, p.qnetwork2]),
+    )
+        dest .= ρ .* dest .+ (1 - ρ) .* src
+    end
+end
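A side note on `evaluate` above (not part of the commit): the second `logp_π` line is the change-of-variables correction for the `tanh` squashing, written in the numerically stable form 2*(log(2) - z - softplus(-2z)) = log(1 - tanh(z)^2). A self-contained Julia sketch checking that identity:

    using Flux: softplus   # softplus is re-exported by Flux

    z = randn(Float32, 5)
    lhs = log.(1 .- tanh.(z) .^ 2)
    rhs = 2f0 .* (log(2f0) .- z .- softplus.(-2f0 .* z))
    maximum(abs.(lhs .- rhs)) < 1f-4   # true, up to Float32 rounding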

src/experiments/rl_envs.jl

Lines changed: 72 additions & 0 deletions
@@ -954,3 +954,75 @@ function RLCore.Experiment(
 
     Experiment(agent, env, stop_condition, hook, description)
 end
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:SAC},
+    ::Val{:Pendulum},
+    ::Nothing;
+    seed = 123,
+)
+    rng = MersenneTwister(seed)
+    inner_env = PendulumEnv(T = Float32, rng = rng)
+    action_space = get_actions(inner_env)
+    low = action_space.low
+    high = action_space.high
+    ns = length(get_state(inner_env))
+
+    env = inner_env |> ActionTransformedEnv(x -> low + (x + 1) * 0.5 * (high - low))
+    init = glorot_uniform(rng)
+
+    create_policy_net() = NeuralNetworkApproximator(
+        model = SACPolicyNetwork(
+            Chain(Dense(ns, 30, relu), Dense(30, 30, relu)),
+            Chain(Dense(30, 1, initW = init)),
+            Chain(Dense(
+                30,
+                1,
+                x -> min(max(x, typeof(x)(-20)), typeof(x)(2)),
+                initW = init,
+            )),
+        ),
+        optimizer = ADAM(0.003),
+    )
+
+    create_q_net() = NeuralNetworkApproximator(
+        model = Chain(
+            Dense(ns + 1, 30, relu; initW = init),
+            Dense(30, 30, relu; initW = init),
+            Dense(30, 1; initW = init),
+        ),
+        optimizer = ADAM(0.003),
+    )
+
+    agent = Agent(
+        policy = SACPolicy(
+            policy = create_policy_net(),
+            qnetwork1 = create_q_net(),
+            qnetwork2 = create_q_net(),
+            target_qnetwork1 = create_q_net(),
+            target_qnetwork2 = create_q_net(),
+            γ = 0.99f0,
+            ρ = 0.995f0,
+            α = 0.2f0,
+            batch_size = 64,
+            start_steps = 1000,
+            start_policy = RandomPolicy(ContinuousSpace(-1.0, 1.0); rng = rng),
+            update_after = 1000,
+            update_every = 1,
+            rng = rng,
+        ),
+        trajectory = CircularCompactSARTSATrajectory(
+            capacity = 10000,
+            state_type = Float32,
+            state_size = (ns,),
+            action_type = Float32,
+        ),
+    )
+
+    description = """
+    # Play Pendulum with SAC
+    """
+
+    Experiment(agent, env, StopAfterStep(10000), TotalRewardPerEpisode(), description)
+end
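One implementation detail in the experiment above, illustrated with assumed numbers (the real bounds come from `get_actions(inner_env)`): `evaluate` returns `tanh`-squashed actions in [-1, 1], and `ActionTransformedEnv` maps them onto the environment's action range. A quick sketch of that rescaling, assuming Pendulum-style torque bounds:

    low, high = -2.0, 2.0                        # assumed bounds, for illustration only
    rescale(x) = low + (x + 1) * 0.5 * (high - low)
    rescale(-1.0), rescale(0.0), rescale(1.0)    # (-2.0, 0.0, 2.0)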

test/runtests.jl

Lines changed: 5 additions & 2 deletions
@@ -41,8 +41,11 @@ using Random
             mean(Iterators.flatten(res.hook.rewards))
         end
 
-        res = run(E`JuliaRL_DDPG_Pendulum`)
-        @info "stats for DDPG Pendulum" avg_reward = mean(res.hook.rewards)
+        for method in (:DDPG, :SAC)
+            res = run(Experiment(Val(:JuliaRL), Val(method), Val(:Pendulum), nothing))
+            @info "stats for $method" avg_reward =
+                mean(Iterators.flatten(res.hook.rewards))
+        end
     end
 end
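As a hedged aside (not from this page): with the repository checked out locally, the updated test set above should be runnable through the standard Pkg entry point; the package name is assumed here:

    using Pkg
    Pkg.test("ReinforcementLearningZoo")   # assumed package name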
