This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 44e3358

Twin Delayed DDPG (TD3) (#89)
* add TD3
* adapt README
1 parent b7bb7a0 commit 44e3358

File tree

5 files changed: +306, -1 lines

README.md

Lines changed: 2 additions & 0 deletions

@@ -23,6 +23,7 @@ This project aims to provide some implementations of the most typical reinforcem
 - A2C
 - PPO
 - DDPG
+- TD3
 - SAC
 - CFR
 - Minimax
@@ -44,6 +45,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - ``E`JuliaRL_A2CGAE_CartPole` `` (Thanks to [@sriram13m](https://github.com/sriram13m))
 - ``E`JuliaRL_PPO_CartPole` ``
 - ``E`JuliaRL_DDPG_Pendulum` ``
+- ``E`JuliaRL_TD3_Pendulum` `` (Thanks to [@rbange](https://github.com/rbange))
 - ``E`JuliaRL_SAC_Pendulum` `` (Thanks to [@rbange](https://github.com/rbange))
 - ``E`JuliaRL_BasicDQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
 - ``E`JuliaRL_DQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
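
A minimal usage sketch for the newly listed experiment, assuming the `E` string macro works here the same way it does for the existing entries:

using ReinforcementLearningZoo

# Run the new built-in TD3 experiment on Pendulum (name taken from the README entry above).
run(E`JuliaRL_TD3_Pendulum`)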

src/algorithms/policy_gradient/policy_gradient.jl

Lines changed: 1 addition & 0 deletions

@@ -2,4 +2,5 @@ include("A2C.jl")
 include("ppo.jl")
 include("A2CGAE.jl")
 include("ddpg.jl")
+include("td3.jl")
 include("sac.jl")

src/algorithms/policy_gradient/td3.jl

Lines changed: 188 additions & 0 deletions (new file)

export TD3Policy, TD3Critic

using Random
using Flux

struct TD3Critic
    critic_1::Flux.Chain
    critic_2::Flux.Chain
end
Flux.@functor TD3Critic
(c::TD3Critic)(s, a) = (inp = vcat(s, a); (c.critic_1(inp), c.critic_2(inp)))

mutable struct TD3Policy{
    BA<:NeuralNetworkApproximator,
    BC<:NeuralNetworkApproximator,
    TA<:NeuralNetworkApproximator,
    TC<:NeuralNetworkApproximator,
    P,
    R<:AbstractRNG,
} <: AbstractPolicy

    behavior_actor::BA
    behavior_critic::BC
    target_actor::TA
    target_critic::TC
    γ::Float32
    ρ::Float32
    batch_size::Int
    start_steps::Int
    start_policy::P
    update_after::Int
    update_every::Int
    policy_freq::Int
    target_act_limit::Float64
    target_act_noise::Float64
    act_limit::Float64
    act_noise::Float64
    step::Int
    rng::R
    replay_counter::Int
    # for logging
    actor_loss::Float32
    critic_loss::Float32
end

"""
    TD3Policy(;kwargs...)

# Keyword arguments

- `behavior_actor`,
- `behavior_critic`,
- `target_actor`,
- `target_critic`,
- `start_policy`,
- `γ = 0.99f0`,
- `ρ = 0.995f0`,
- `batch_size = 64`,
- `start_steps = 10000`,
- `update_after = 1000`,
- `update_every = 50`,
- `policy_freq = 2`, # frequency at which the actor takes a gradient step and the target networks are updated
- `target_act_limit = 1.0`, # clipping bound of the noise added to the target action
- `target_act_noise = 0.1`, # standard deviation of the noise added to the target action
- `act_limit = 1.0`, # clipping bound of the output action
- `act_noise = 0.1`, # standard deviation of the exploration noise added to the output action
- `step = 0`,
- `rng = Random.GLOBAL_RNG`,
"""
function TD3Policy(;
    behavior_actor,
    behavior_critic,
    target_actor,
    target_critic,
    start_policy,
    γ = 0.99f0,
    ρ = 0.995f0,
    batch_size = 64,
    start_steps = 10000,
    update_after = 1000,
    update_every = 50,
    policy_freq = 2,
    target_act_limit = 1.0,
    target_act_noise = 0.1,
    act_limit = 1.0,
    act_noise = 0.1,
    step = 0,
    rng = Random.GLOBAL_RNG,
)
    copyto!(behavior_actor, target_actor)   # force sync
    copyto!(behavior_critic, target_critic) # force sync
    TD3Policy(
        behavior_actor,
        behavior_critic,
        target_actor,
        target_critic,
        γ,
        ρ,
        batch_size,
        start_steps,
        start_policy,
        update_after,
        update_every,
        policy_freq,
        target_act_limit,
        target_act_noise,
        act_limit,
        act_noise,
        step,
        rng,
        1,      # replay_counter, keeps track of the number of updates since the last delayed update
        0.0f0,  # actor_loss (for logging)
        0.0f0,  # critic_loss (for logging)
    )
end

# TODO: handle Training/Testing mode
function (p::TD3Policy)(env)
    p.step += 1

    if p.step <= p.start_steps
        p.start_policy(env)
    else
        D = device(p.behavior_actor)
        s = get_state(env)
        s = Flux.unsqueeze(s, ndims(s) + 1)
        action = p.behavior_actor(send_to_device(D, s)) |> vec |> send_to_host
        clamp(action[] + randn(p.rng) * p.act_noise, -p.act_limit, p.act_limit)
    end
end

function RLBase.update!(p::TD3Policy, traj::CircularCompactSARTSATrajectory)
    length(traj[:terminal]) > p.update_after || return
    p.step % p.update_every == 0 || return

    inds = rand(p.rng, 1:(length(traj[:terminal])-1), p.batch_size)
    s = select_last_dim(traj[:state], inds)
    a = select_last_dim(traj[:action], inds)
    r = select_last_dim(traj[:reward], inds)
    t = select_last_dim(traj[:terminal], inds)
    s′ = select_last_dim(traj[:next_state], inds)

    actor = p.behavior_actor
    critic = p.behavior_critic

    # !!! we have several assumptions here, need to revisit them when we have more complex environments
    # state is a vector
    # action is a scalar
    target_noise = clamp.(
        randn(p.rng, Float32, 1, p.batch_size) .* p.target_act_noise,
        -p.target_act_limit,
        p.target_act_limit,
    )
    # add noise and clip to tanh bounds
    a′ = clamp.(p.target_actor(s′) + target_noise, -1f0, 1f0)

    q_1′, q_2′ = p.target_critic(s′, a′)
    y = r .+ p.γ .* (1 .- t) .* (min.(q_1′, q_2′) |> vec)
    a = Flux.unsqueeze(a, 1)

    gs1 = gradient(Flux.params(critic)) do
        q1, q2 = critic(s, a)
        loss = mse(q1 |> vec, y) + mse(q2 |> vec, y)
        ignore() do
            p.critic_loss = loss
        end
        loss
    end
    update!(critic, gs1)

    if p.replay_counter % p.policy_freq == 0
        gs2 = gradient(Flux.params(actor)) do
            actions = actor(s)
            loss = -mean(critic.model.critic_1(vcat(s, actions)))
            ignore() do
                p.actor_loss = loss
            end
            loss
        end
        update!(actor, gs2)
        # polyak averaging
        for (dest, src) in zip(Flux.params([p.target_actor, p.target_critic]), Flux.params([actor, critic]))
            dest .= p.ρ .* dest .+ (1 - p.ρ) .* src
        end
        p.replay_counter = 1
    end
    p.replay_counter += 1
end
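
A minimal sketch of how the twin critic defined above is queried; the dimensions and layer sizes here are illustrative, not taken from this commit. Both heads score the concatenated state-action input, and the TD3 target uses their element-wise minimum (clipped double-Q learning):

using Flux

ns, na, batch = 3, 1, 32   # illustrative state dim, action dim, batch size
critic = TD3Critic(
    Chain(Dense(ns + na, 30, relu), Dense(30, 1)),
    Chain(Dense(ns + na, 30, relu), Dense(30, 1)),
)
s = rand(Float32, ns, batch)
a = rand(Float32, na, batch)
q1, q2 = critic(s, a)               # each head sees vcat(s, a)
q_min = min.(vec(q1), vec(q2))      # the clipped double-Q value used in the Bellman target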

src/experiments/rl_envs.jl

Lines changed: 114 additions & 0 deletions

@@ -800,6 +800,120 @@ function RLCore.Experiment(
     )
 end
 
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:TD3},
+    ::Val{:Pendulum},
+    ::Nothing;
+    save_dir = nothing,
+    seed = 123,
+)
+    if isnothing(save_dir)
+        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
+        save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_TD3_Pendulum_$(t)")
+    end
+
+    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)
+    rng = MersenneTwister(seed)
+    inner_env = PendulumEnv(T = Float32, rng = rng)
+    action_space = get_actions(inner_env)
+    low = action_space.low
+    high = action_space.high
+    ns = length(get_state(inner_env))
+
+    env = inner_env |> ActionTransformedEnv(x -> low + (x + 1) * 0.5 * (high - low))
+    init = glorot_uniform(rng)
+
+    create_actor() = Chain(
+        Dense(ns, 30, relu; initW = init),
+        Dense(30, 30, relu; initW = init),
+        Dense(30, 1, tanh; initW = init),
+    )
+
+    create_critic_model() = Chain(
+        Dense(ns + 1, 30, relu; initW = init),
+        Dense(30, 30, relu; initW = init),
+        Dense(30, 1; initW = init),
+    )
+
+    create_critic() = TD3Critic(create_critic_model(), create_critic_model())
+
+    agent = Agent(
+        policy = TD3Policy(
+            behavior_actor = NeuralNetworkApproximator(
+                model = create_actor(),
+                optimizer = ADAM(),
+            ),
+            behavior_critic = NeuralNetworkApproximator(
+                model = create_critic(),
+                optimizer = ADAM(),
+            ),
+            target_actor = NeuralNetworkApproximator(
+                model = create_actor(),
+                optimizer = ADAM(),
+            ),
+            target_critic = NeuralNetworkApproximator(
+                model = create_critic(),
+                optimizer = ADAM(),
+            ),
+            γ = 0.99f0,
+            ρ = 0.99f0,
+            batch_size = 64,
+            start_steps = 1000,
+            start_policy = RandomPolicy(ContinuousSpace(-1.0, 1.0); rng = rng),
+            update_after = 1000,
+            update_every = 1,
+            policy_freq = 2,
+            target_act_limit = 1.0,
+            target_act_noise = 0.1,
+            act_limit = 1.0,
+            act_noise = 0.1,
+            rng = rng,
+        ),
+        trajectory = CircularCompactSARTSATrajectory(
+            capacity = 10000,
+            state_type = Float32,
+            state_size = (ns,),
+            action_type = Float32,
+        ),
+    )
+
+    stop_condition = StopAfterStep(10_000)
+    total_reward_per_episode = TotalRewardPerEpisode()
+    time_per_step = TimePerStep()
+    hook = ComposedHook(
+        total_reward_per_episode,
+        time_per_step,
+        DoEveryNStep() do t, agent, env
+            with_logger(lg) do
+                @info(
+                    "training",
+                    actor_loss = agent.policy.actor_loss,
+                    critic_loss = agent.policy.critic_loss
+                )
+            end
+        end,
+        DoEveryNEpisode() do t, agent, env
+            with_logger(lg) do
+                @info "training" reward = total_reward_per_episode.rewards[end] log_step_increment = 0
+            end
+        end,
+        DoEveryNStep(10000) do t, agent, env
+            RLCore.save(save_dir, agent)
+            BSON.@save joinpath(save_dir, "stats.bson") total_reward_per_episode time_per_step
+        end,
+    )
+
+    Experiment(
+        agent,
+        env,
+        stop_condition,
+        hook,
+        Description("# Play Pendulum with TD3", save_dir),
+    )
+end
+
 function RLCore.Experiment(
     ::Val{:JuliaRL},
     ::Val{:PPO},
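
A minimal sketch of invoking the new experiment constructor directly, mirroring the dispatch pattern used in the test suite below; the positional `nothing` and keyword defaults follow the method signature above, and the `save_dir` value here is hypothetical:

using ReinforcementLearningZoo

ex = Experiment(
    Val(:JuliaRL),
    Val(:TD3),
    Val(:Pendulum),
    nothing;
    save_dir = mktempdir(),   # hypothetical; defaults to ./checkpoints/JuliaRL_TD3_Pendulum_<timestamp>
    seed = 123,
)
run(ex)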

test/runtests.jl

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ using OpenSpiel
         mean(Iterators.flatten(res.hook[1].rewards))
     end
 
-    for method in (:DDPG, :SAC)
+    for method in (:DDPG, :SAC, :TD3)
        res = run(Experiment(
            Val(:JuliaRL),
            Val(method),
