This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit ef712d0: fix atari related experiments (#145)
1 parent: 0f13747

8 files changed: 141 additions, 99 deletions


src/algorithms/policy_gradient/ppo.jl

+17 -21

@@ -85,6 +85,7 @@ mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
     critic_loss_weight::Float32
     entropy_loss_weight::Float32
     rng::R
+    n_random_start::Int
     update_freq::Int
     update_step::Int
     # for logging
@@ -98,6 +99,7 @@ end
 function PPOPolicy(;
     approximator,
     update_freq,
+    n_random_start = 0,
     update_step = 0,
     γ = 0.99f0,
     λ = 0.95f0,
@@ -123,6 +125,7 @@ function PPOPolicy(;
        critic_loss_weight,
        entropy_loss_weight,
        rng,
+       n_random_start,
        update_freq,
        update_step,
        zeros(Float32, n_microbatches, n_epochs),
@@ -137,17 +140,25 @@ function RLBase.prob(
     p::PPOPolicy{<:ActorCritic{<:GaussianNetwork},Normal},
     state::AbstractArray,
 )
-    p.approximator.actor(send_to_device(device(p.approximator), state)) |>
-    send_to_host |>
-    StructArray{Normal}
+    if p.update_step < p.n_random_start
+        @error "todo"
+    else
+        p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+        send_to_host |>
+        StructArray{Normal}
+    end
 end

 function RLBase.prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray)
     logits =
         p.approximator.actor(send_to_device(device(p.approximator), state)) |>
         softmax |>
         send_to_host
-    [Categorical(x; check_args = false) for x in eachcol(logits)]
+    if p.update_step < p.n_random_start
+        [Categorical(fill(1/length(x), length(x)); check_args = false) for x in eachcol(logits)]
+    else
+        [Categorical(x; check_args = false) for x in eachcol(logits)]
+    end
 end

 RLBase.prob(p::PPOPolicy, env::MultiThreadEnv) = prob(p, state(env))
@@ -161,29 +172,14 @@ end
 (p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, prob(p, env))
 (p::PPOPolicy)(env::AbstractEnv) = rand(p.rng, prob(p, env))

-function (agent::Agent{<:PPOPolicy})(env::AbstractEnv)
-    dist = prob(agent.policy, env)
-    a = rand(agent.policy.rng, dist)
-    EnrichedAction(a; action_log_prob=logpdf(dist, a))
-end
-
 function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
     dist = prob(agent.policy, env)
     action = rand.(agent.policy.rng, dist)
     EnrichedAction(action; action_log_prob=logpdf.(dist, action))
 end

-function (agent::Agent{<:RandomStartPolicy{<:PPOPolicy}})(env::AbstractEnv)
-    a = agent.policy(env)
-    if a isa EnrichedAction
-        a
-    else
-        EnrichedAction(a; action_log_prob=logpdf(prob(agent.policy, a)))
-    end
-end
-
 function RLBase.update!(p::PPOPolicy, t::Union{PPOTrajectory, MaskedPPOTrajectory}, ::AbstractEnv, ::PreActStage)
-    length(t) == 0 && return # in the first update, only state & action is inserted into trajectory
+    length(t) == 0 && return # in the first update, only state & action are inserted into trajectory
     p.update_step += 1
     if p.update_step % p.update_freq == 0
         _update!(p, t)
@@ -289,7 +285,7 @@ end

 function RLBase.update!(
     trajectory::Union{PPOTrajectory,MaskedPPOTrajectory},
-    policy::Union{PPOPolicy,RandomStartPolicy{<:PPOPolicy}},
+    ::PPOPolicy,
     env::MultiThreadEnv,
     ::PreActStage,
     action::EnrichedAction
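
Note: with the `n_random_start` branch added above, the Categorical case ignores the actor's logits during the first `n_random_start` update steps and samples uniformly over the actions. A minimal, self-contained sketch of that uniform distribution (plain Distributions.jl with a hypothetical action count, not this package's API):

```julia
using Distributions

n_actions = 4                                   # hypothetical number of discrete actions
# `fill(1/n, n)` is the uniform probability vector used in the warm-up branch above.
uniform = Categorical(fill(1 / n_actions, n_actions); check_args = false)

@assert all(probs(uniform) .≈ 1 / n_actions)    # every action is equally likely
rand(uniform)                                   # a uniform random action index in 1:n_actions
```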

src/algorithms/policy_gradient/run.jl

+5 -3

@@ -13,7 +13,7 @@ function RLCore._run(
 )

     while true
-        reset!(env) # this is a soft reset!, only environments reached the end will get reset.
+        reset!(env) # this is a soft reset!; only environments that have reached the end will be reset.
         action = policy(env)
         policy(PRE_ACT_STAGE, env, action)
         hook(PRE_ACT_STAGE, policy, env, action)
@@ -23,9 +23,11 @@ function RLCore._run(
         hook(POST_ACT_STAGE, policy, env)

         if stop_condition(policy, env)
-            policy(PRE_ACT_STAGE, env) # let the policy see the last observation
             break
         end
     end
-    hook
+    action = policy(env)
+    policy(PRE_ACT_STAGE, env, action) # let the policy see the last observation
+    hook(PRE_ACT_STAGE, policy, env, action)
+    nothing
 end
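
Note: a toy sketch (plain Julia, not this package's API) of why the rewritten loop still performs one final `policy(env)` after the `while` loop: advantage-style updates bootstrap the return of the last stored step from the value of the state that follows it, so that trailing observation has to reach the policy and trajectory before `_run` returns.

```julia
# Hypothetical numbers, only to illustrate the bootstrap step.
function discounted_returns(rewards, last_value; γ = 0.99)
    returns = similar(rewards)
    next_v = last_value                 # bootstrap from the trailing state
    for t in length(rewards):-1:1
        returns[t] = rewards[t] + γ * next_v
        next_v = returns[t]
    end
    returns
end

rewards    = [1.0, 0.0, 2.0]   # rewards collected inside the loop
last_value = 0.2               # critic's value of the "last observation" seen after the loop
discounted_returns(rewards, last_value)  # cannot be computed without `last_value`
```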

src/experiments/atari/Dopamine_DQN_Atari.jl

+2 -2

@@ -71,7 +71,7 @@ function RLCore.Experiment(
     MAX_EPISODE_STEPS_EVAL = 27_000
     N_CHECKPOINTS = 3

-    total_reward_per_episode = TotalRewardPerEpisode()
+    total_reward_per_episode = TotalOriginalRewardPerEpisode()
     time_per_step = TimePerStep()
     steps_per_episode = StepsPerEpisode()
     hook = ComposedHook(
@@ -93,7 +93,7 @@ function RLCore.Experiment(
             @info "evaluating agent at $t step..."
             p = agent.policy
             p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
-            h = ComposedHook(TotalRewardPerEpisode(), StepsPerEpisode())
+            h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
            s = @elapsed run(
                 p,
                 atari_env_factory(

src/experiments/atari/Dopamine_IQN_Atari.jl

+2 -2

@@ -86,7 +86,7 @@ function RLCore.Experiment(
     MAX_EPISODE_STEPS_EVAL = 27_000
     N_CHECKPOINTS = 3

-    total_reward_per_episode = TotalRewardPerEpisode()
+    total_reward_per_episode = TotalOriginalRewardPerEpisode()
     time_per_step = TimePerStep()
     steps_per_episode = StepsPerEpisode()
     hook = ComposedHook(
@@ -108,7 +108,7 @@ function RLCore.Experiment(
             @info "evaluating agent at $t step..."
             p = agent.policy
             p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
-            h = ComposedHook(TotalRewardPerEpisode(), StepsPerEpisode())
+            h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
            s = @elapsed run(
                 p,
                 atari_env_factory(

src/experiments/atari/Dopamine_Rainbow_Atari.jl

+2 -2

@@ -74,7 +74,7 @@ function RLCore.Experiment(
     MAX_EPISODE_STEPS_EVAL = 27_000
     N_CHECKPOINTS = 3

-    total_reward_per_episode = TotalRewardPerEpisode()
+    total_reward_per_episode = TotalOriginalRewardPerEpisode()
     time_per_step = TimePerStep()
     steps_per_episode = StepsPerEpisode()
     hook = ComposedHook(
@@ -96,7 +96,7 @@ function RLCore.Experiment(
             @info "evaluating agent at $t step..."
             p = agent.policy
             p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
-            h = ComposedHook(TotalRewardPerEpisode(), StepsPerEpisode())
+            h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
            s = @elapsed run(
                 p,
                 atari_env_factory(

src/experiments/atari/atari.jl

+56 -5

@@ -1,5 +1,7 @@
 using .ArcadeLearningEnvironment
 using .ReinforcementLearningEnvironments
+using BSON
+using Flux:Chain

 function atari_env_factory(
     name,
@@ -8,8 +10,9 @@ function atari_env_factory(
     max_episode_steps = 100_000;
     seed = nothing,
     repeat_action_probability = 0.25,
+    n_replica = 1
 )
-    AtariEnv(;
+    init(seed) = AtariEnv(;
         name = string(name),
         grayscale_obs = true,
         noop_max = 30,
@@ -21,14 +24,62 @@ function atari_env_factory(
         full_action_space = false,
         seed = seed,
     ) |>
-    StateOverriddenEnv(
-        ResizeImage(state_size...), # this implementation is different from cv2.resize https://github.com/google/dopamine/blob/e7d780d7c80954b7c396d984325002d60557f7d1/dopamine/discrete_domains/atari_lib.py#L629
-        StackFrames(state_size..., n_frames),
+    env -> StateOverriddenEnv(
+        env,
+        Chain(ResizeImage(state_size...), StackFrames(state_size..., n_frames))
     ) |>
     StateCachedEnv |>
-    RewardOverriddenEnv(r -> clamp(r, -1, 1))
+    env -> RewardOverriddenEnv(env, r -> clamp(r, -1, 1))
+
+    if n_replica == 1
+        init(seed)
+    else
+        envs = [init(hash(seed+i)) for i in 1:n_replica]
+        states = Flux.batch(state.(envs))
+        rewards = reward.(envs)
+        terminals = is_terminated.(envs)
+        A = Space([action_space(x) for x in envs])
+        S = Space(fill(0..255, size(states)))
+        MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing)
+    end
+end
+
+"Total reward per episode before reward reshaping"
+Base.@kwdef mutable struct TotalOriginalRewardPerEpisode <: AbstractHook
+    rewards::Vector{Float64} = Float64[]
+    reward::Float64 = 0.0
+end
+
+function (hook::TotalOriginalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv)
+    hook.reward += reward(env.env)
+end
+
+function (hook::TotalOriginalRewardPerEpisode)(::PostEpisodeStage, agent, env)
+    push!(hook.rewards, hook.reward)
+    hook.reward = 0
 end

+"Total reward of each inner env per episode before reward reshaping"
+struct TotalBatchOriginalRewardPerEpisode <: AbstractHook
+    rewards::Vector{Vector{Float64}}
+    reward::Vector{Float64}
+end
+
+function TotalBatchOriginalRewardPerEpisode(batch_size::Int)
+    TotalBatchOriginalRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size))
+end
+
+function (hook::TotalBatchOriginalRewardPerEpisode)(::PostActStage, agent, env::MultiThreadEnv{<:RewardOverriddenEnv})
+    for (i, e) in enumerate(env.envs)
+        hook.reward[i] += reward(e.env)
+        if is_terminated(e)
+            push!(hook.rewards[i], hook.reward[i])
+            hook.reward[i] = 0.0
+        end
+    end
+end
+
+
 for f in readdir(@__DIR__)
     if f != splitdir(@__FILE__)[2]
         include(f)
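
Note: a hypothetical usage sketch of the updated factory and the new hooks above (the game name, seed, and frame/size values are illustrative; it assumes the same package environment these experiment files load):

```julia
# Single wrapped env: rewards are clamped to [-1, 1] for learning, while the hook
# below records the un-clipped, original per-episode reward.
env  = atari_env_factory("pong", (80, 104), 4; seed = 123)
hook = TotalOriginalRewardPerEpisode()

# Batched variant: n_replica > 1 returns a MultiThreadEnv, paired with the batch hook.
envs       = atari_env_factory("pong", (80, 104), 4; seed = 123, n_replica = 8)
batch_hook = TotalBatchOriginalRewardPerEpisode(8)
```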

src/experiments/atari/rlpyt_A2C_Atari.jl

+19 -21

@@ -19,15 +19,14 @@ function RLCore.Experiment(
     UPDATE_FREQ = 5
     N_FRAMES = 4
     STATE_SIZE = (80, 104)
-    env = MultiThreadEnv([
-        atari_env_factory(
-            name,
-            STATE_SIZE,
-            N_FRAMES;
-            repeat_action_probability = 0,
-            seed = hash(seed + i),
-        ) for i in 1:N_ENV
-    ])
+    env = atari_env_factory(
+        name,
+        STATE_SIZE,
+        N_FRAMES;
+        repeat_action_probability = 0,
+        seed = seed,
+        n_replica = N_ENV
+    )
     N_ACTIONS = length(action_space(env[1]))

     init = orthogonal(rng)
@@ -77,7 +76,7 @@ function RLCore.Experiment(
     N_CHECKPOINTS = 3
     stop_condition = StopAfterStep(N_TRAINING_STEPS)

-    total_batch_reward_per_episode = TotalBatchRewardPerEpisode(N_ENV)
+    total_batch_reward_per_episode = TotalBatchOriginalRewardPerEpisode(N_ENV)
     batch_steps_per_episode = BatchStepsPerEpisode(N_ENV)
     evaluation_result = []

@@ -112,19 +111,18 @@ function RLCore.Experiment(
         end,
         DoEveryNStep(EVALUATION_FREQ) do t, agent, env
             @info "evaluating agent at $t step..."
-            h = TotalBatchRewardPerEpisode(N_ENV)
+            h = TotalBatchOriginalRewardPerEpisode(N_ENV)
             s = @elapsed run(
                 agent.policy,
-                MultiThreadEnv([
-                    atari_env_factory(
-                        name,
-                        STATE_SIZE,
-                        N_FRAMES,
-                        MAX_EPISODE_STEPS_EVAL;
-                        repeat_action_probability = 0,
-                        seed = hash(seed + t + i),
-                    ) for i in 1:N_ENV
-                ]),
+                atari_env_factory(
+                    name,
+                    STATE_SIZE,
+                    N_FRAMES,
+                    MAX_EPISODE_STEPS_EVAL;
+                    repeat_action_probability = 0,
+                    seed = seed + t,
+                    n_replica = 4
+                ),
                 StopAfterStep(27_000; is_show_progress = false),
                 h,
             )
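
Note on seeding under the new `n_replica` path: the per-replica seeds this experiment used to compute itself (`hash(seed + i)`) are now derived inside `atari_env_factory`. A trivial sketch of that mapping, using a made-up base seed:

```julia
base_seed = 123
n_replica = 4
# What each replica's AtariEnv receives, per `init(hash(seed+i))` in atari.jl above.
replica_seeds = [hash(base_seed + i) for i in 1:n_replica]
```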
