This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit be4bb82

add experiment of snake game (#77)
1 parent 74c6e5a commit be4bb82


2 files changed: 113 additions and 0 deletions


src/ReinforcementLearningZoo.jl

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ function __init__()
     @require ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" begin
         include("experiments/rl_envs.jl")
         @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari.jl")
+        @require SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429" include("experiments/snake.jl")
     end
 end
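The added `@require` line (from Requires.jl) means `experiments/snake.jl` is only included once SnakeGames.jl is loaded in the session. A minimal sketch of triggering it, assuming the three packages are installed (the load order shown is only illustrative; Requires.jl fires the callback whenever the dependency becomes available):

using ReinforcementLearningZoo
using ReinforcementLearningEnvironments  # activates the outer @require block in __init__
using SnakeGames                         # triggers include("experiments/snake.jl")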

src/experiments/snake.jl

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DQN},
    ::Val{:SnakeGame},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    rng = Random.MersenneTwister(seed)

    SHAPE = (8, 8)
    inner_env = SnakeGameEnv(; action_style = FULL_ACTION_SET, shape = SHAPE, rng = rng)

    board_size = size(get_state(inner_env))
    N_FRAMES = 4

    # Stack the last N_FRAMES observations and cache the resulting state.
    env = inner_env |>
        StateOverriddenEnv(
            StackFrames(board_size..., N_FRAMES),
        ) |>
        StateCachedEnv

    N_ACTIONS = length(get_actions(env))

    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "SnakeGame_$(t)")
    end

    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)

    init = glorot_uniform(rng)

    update_freq = 4

    # Two convolutional layers followed by two dense layers; the stacked
    # frames are folded into the channel dimension before the first CrossCor.
    create_model() =
        Chain(
            x -> reshape(x, SHAPE..., :, size(x, ndims(x))),
            CrossCor((3, 3), board_size[end] * N_FRAMES => 16, relu; stride = 1, pad = 1, init = init),
            CrossCor((3, 3), 16 => 32, relu; stride = 1, pad = 1, init = init),
            x -> reshape(x, :, size(x, ndims(x))),
            Dense(8 * 8 * 32, 256, relu; initW = init),
            Dense(256, N_ACTIONS; initW = init),
        ) |> cpu

    agent = Agent(
        policy = QBasedPolicy(
            learner = DQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = create_model(),
                    optimizer = ADAM(0.001),
                ),
                target_approximator = NeuralNetworkApproximator(model = create_model()),
                update_freq = update_freq,
                γ = 0.99f0,
                update_horizon = 1,
                batch_size = 32,
                stack_size = nothing,
                min_replay_history = 20_000,
                loss_func = huber_loss,
                target_update_freq = 8_000,
                rng = rng,
            ),
            explorer = EpsilonGreedyExplorer(
                ϵ_init = 1.0,
                ϵ_stable = 0.01,
                decay_steps = 250_000,
                kind = :linear,
                rng = rng,
            ),
        ),
        trajectory = CircularCompactSALRTSALTrajectory(
            capacity = 500_000,
            state_type = Float32,
            state_size = (board_size..., N_FRAMES),
            legal_actions_mask_size = (N_ACTIONS,),
        ),
    )

    # Defined but not used further in this experiment.
    evaluation_result = []
    EVALUATION_FREQ = 100_000
    N_CHECKPOINTS = 3

    total_reward_per_episode = TotalRewardPerEpisode()
    time_per_step = TimePerStep()
    steps_per_episode = StepsPerEpisode()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        steps_per_episode,
        # Log the training loss every `update_freq` steps.
        DoEveryNStep(update_freq) do t, agent, env
            with_logger(lg) do
                @info "training" loss = agent.policy.learner.loss log_step_increment = update_freq
            end
        end,
        # Log episode length and reward at the end of every episode.
        DoEveryNEpisode() do t, agent, env
            with_logger(lg) do
                @info "training" episode_length = steps_per_episode.steps[end] reward =
                    total_reward_per_episode.rewards[end] log_step_increment = 0
            end
        end,
    )

    N_TRAINING_STEPS = 1_000_000
    stop_condition = StopAfterStep(N_TRAINING_STEPS)
    description = """
    # Play Single Agent SnakeGame with DQN

    You can view the tensorboard logs with `tensorboard --logdir $(joinpath(save_dir, "tb_log"))`
    """
    Experiment(agent, env, stop_condition, hook, description)
end
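With the method above defined, the experiment can be constructed and run directly. A minimal usage sketch, assuming `Experiment` and `run` are available from ReinforcementLearningCore (as used elsewhere in this repository) and the default keyword arguments:

using ReinforcementLearningCore, ReinforcementLearningZoo
using ReinforcementLearningEnvironments, SnakeGames

# Build the :SnakeGame DQN experiment defined above; seed and save_dir are keyword arguments.
ex = Experiment(Val(:JuliaRL), Val(:DQN), Val(:SnakeGame), nothing; seed = 123)

# Trains for N_TRAINING_STEPS (1_000_000) steps; TensorBoard logs are written under save_dir/tb_log.
run(ex)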
