function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DQN},
    ::Val{:SnakeGame},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    rng = Random.MersenneTwister(seed)

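    # Build the SnakeGame environment and wrap it so the agent observes a stack
    # of the last N_FRAMES boards (StackFrames) served through StateOverriddenEnv,
    # with StateCachedEnv caching the stacked state between steps.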
    SHAPE = (8, 8)
    inner_env = SnakeGameEnv(; action_style = FULL_ACTION_SET, shape = SHAPE, rng = rng)

    board_size = size(get_state(inner_env))
    N_FRAMES = 4

    env = inner_env |>
        StateOverriddenEnv(
            StackFrames(board_size..., N_FRAMES),
        ) |>
        StateCachedEnv

    N_ACTIONS = length(get_actions(env))

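    # Default to a timestamped checkpoint directory and log training metrics to
    # TensorBoard under save_dir/tb_log.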
    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "SnakeGame_$(t)")
    end

    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)

    init = glorot_uniform(rng)

    update_freq = 4

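    # Q-network: the stacked observation of shape (8, 8, board_size[end], N_FRAMES)
    # is reshaped so that the object-type and frame dimensions merge into channels,
    # passed through two convolutional layers, flattened, and mapped to one Q-value
    # per action.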
    create_model() =
        Chain(
            x -> reshape(x, SHAPE..., :, size(x, ndims(x))),
            CrossCor((3, 3), board_size[end] * N_FRAMES => 16, relu; stride = 1, pad = 1, init = init),
            CrossCor((3, 3), 16 => 32, relu; stride = 1, pad = 1, init = init),
            x -> reshape(x, :, size(x, ndims(x))),
            Dense(8 * 8 * 32, 256, relu; initW = init),
            Dense(256, N_ACTIONS; initW = init),
        ) |> cpu

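    # DQN agent: an online and a target Q-network, an epsilon-greedy explorer with
    # linear decay, and a circular replay buffer that also stores the legal action
    # mask of each state.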
    agent = Agent(
        policy = QBasedPolicy(
            learner = DQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = create_model(),
                    optimizer = ADAM(0.001),
                ),
                target_approximator = NeuralNetworkApproximator(model = create_model()),
                update_freq = update_freq,
                γ = 0.99f0,
                update_horizon = 1,
                batch_size = 32,
                stack_size = nothing,
                min_replay_history = 20_000,
                loss_func = huber_loss,
                target_update_freq = 8_000,
                rng = rng,
            ),
            explorer = EpsilonGreedyExplorer(
                ϵ_init = 1.0,
                ϵ_stable = 0.01,
                decay_steps = 250_000,
                kind = :linear,
                rng = rng,
            ),
        ),
        trajectory = CircularCompactSALRTSALTrajectory(
            capacity = 500_000,
            state_type = Float32,
            state_size = (board_size..., N_FRAMES),
            legal_actions_mask_size = (N_ACTIONS,),
        ),
    )

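    # Evaluation/checkpoint bookkeeping (not referenced by the hooks defined below).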
    evaluation_result = []
    EVALUATION_FREQ = 100_000
    N_CHECKPOINTS = 3

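    # Hooks: track episode reward, wall-clock time per step, and episode length,
    # and push the training loss (every update_freq steps) plus per-episode
    # metrics to the TensorBoard logger.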
    total_reward_per_episode = TotalRewardPerEpisode()
    time_per_step = TimePerStep()
    steps_per_episode = StepsPerEpisode()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        steps_per_episode,
        DoEveryNStep(update_freq) do t, agent, env
            with_logger(lg) do
                @info "training" loss = agent.policy.learner.loss log_step_increment = update_freq
            end
        end,
        DoEveryNEpisode() do t, agent, env
            with_logger(lg) do
                @info "training" episode_length = steps_per_episode.steps[end] reward =
                    total_reward_per_episode.rewards[end] log_step_increment = 0
            end
        end,
    )

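    # Train for N_TRAINING_STEPS environment steps, then bundle everything into an
    # Experiment for the caller to run.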
    N_TRAINING_STEPS = 1_000_000
    stop_condition = StopAfterStep(N_TRAINING_STEPS)
    description = """
    # Play Single Agent SnakeGame with DQN

    You can view the tensorboard logs with `tensorboard --logdir $(joinpath(save_dir, "tb_log"))`
    """
    Experiment(agent, env, stop_condition, hook, description)
end
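
# A minimal usage sketch, assuming the surrounding package (e.g. ReinforcementLearningZoo)
# already loads ReinforcementLearning, Flux, TensorBoardLogger, Logging, Dates, Random,
# and a SnakeGameEnv provider, and that the usual Experiment/run workflow applies:
#
#     ex = RLCore.Experiment(Val(:JuliaRL), Val(:DQN), Val(:SnakeGame), nothing)
#     run(ex)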