@@ -8,16 +8,15 @@ function RLCore.Experiment(
 )
     rng = Random.MersenneTwister(seed)
 
-    SHAPE = (8,8)
-    inner_env = SnakeGameEnv(;action_style=FULL_ACTION_SET, shape=SHAPE, rng=rng)
+    SHAPE = (8, 8)
+    inner_env = SnakeGameEnv(; action_style = FULL_ACTION_SET, shape = SHAPE, rng = rng)
 
     board_size = size(get_state(inner_env))
     N_FRAMES = 4
 
-    env = inner_env |>
-        StateOverriddenEnv(
-            StackFrames(board_size..., N_FRAMES),
-        ) |>
+    env =
+        inner_env |>
+        StateOverriddenEnv(StackFrames(board_size..., N_FRAMES),) |>
         StateCachedEnv
 
     N_ACTIONS = length(get_actions(env))
@@ -36,7 +35,14 @@ function RLCore.Experiment(
     create_model() =
         Chain(
             x -> reshape(x, SHAPE..., :, size(x, ndims(x))),
-            CrossCor((3, 3), board_size[end] * N_FRAMES => 16, relu; stride = 1, pad = 1, init = init),
+            CrossCor(
+                (3, 3),
+                board_size[end] * N_FRAMES => 16,
+                relu;
+                stride = 1,
+                pad = 1,
+                init = init,
+            ),
             CrossCor((3, 3), 16 => 32, relu; stride = 1, pad = 1, init = init),
             x -> reshape(x, :, size(x, ndims(x))),
             Dense(8 * 8 * 32, 256, relu; initW = init),
@@ -45,35 +51,35 @@ function RLCore.Experiment(
 
     agent = Agent(
         policy = QBasedPolicy(
-                learner = DQNLearner(
-                    approximator = NeuralNetworkApproximator(
-                        model = create_model(),
-                        optimizer = ADAM(0.001),
-                    ),
-                    target_approximator = NeuralNetworkApproximator(model = create_model()),
-                    update_freq = update_freq,
-                    γ = 0.99f0,
-                    update_horizon = 1,
-                    batch_size = 32,
-                    stack_size = nothing,
-                    min_replay_history = 20_000,
-                    loss_func = huber_loss,
-                    target_update_freq = 8_000,
-                    rng = rng,
-                ),
-                explorer = EpsilonGreedyExplorer(
-                    ϵ_init = 1.0,
-                    ϵ_stable = 0.01,
-                    decay_steps = 250_000,
-                    kind = :linear,
-                    rng = rng,
+            learner = DQNLearner(
+                approximator = NeuralNetworkApproximator(
+                    model = create_model(),
+                    optimizer = ADAM(0.001),
                 ),
+                target_approximator = NeuralNetworkApproximator(model = create_model()),
+                update_freq = update_freq,
+                γ = 0.99f0,
+                update_horizon = 1,
+                batch_size = 32,
+                stack_size = nothing,
+                min_replay_history = 20_000,
+                loss_func = huber_loss,
+                target_update_freq = 8_000,
+                rng = rng,
             ),
+            explorer = EpsilonGreedyExplorer(
+                ϵ_init = 1.0,
+                ϵ_stable = 0.01,
+                decay_steps = 250_000,
+                kind = :linear,
+                rng = rng,
+            ),
+        ),
         trajectory = CircularCompactSALRTSALTrajectory(
             capacity = 500_000,
             state_type = Float32,
             state_size = (board_size..., N_FRAMES),
-            legal_actions_mask_size=(N_ACTIONS,)
+            legal_actions_mask_size = (N_ACTIONS,),
         ),
     )
 
@@ -90,7 +96,8 @@ function RLCore.Experiment(
         steps_per_episode,
         DoEveryNStep(update_freq) do t, agent, env
             with_logger(lg) do
-                @info "training" loss = agent.policy.learner.loss log_step_increment = update_freq
+                @info "training" loss = agent.policy.learner.loss log_step_increment =
+                    update_freq
             end
         end,
         DoEveryNEpisode() do t, agent, env
@@ -109,4 +116,4 @@ function RLCore.Experiment(
     You can view the tensorboard logs with `tensorboard --logdir $(joinpath(save_dir, "tb_log"))`
     """
     Experiment(agent, env, stop_condition, hook, description)
-end
+end
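
The one construct in this diff that is easy to misread is the wrapped environment, `inner_env |> StateOverriddenEnv(StackFrames(board_size..., N_FRAMES)) |> StateCachedEnv`, which turns the observation into a rolling stack of the last N_FRAMES boards rather than a single board. Below is a minimal, dependency-free Julia sketch of just that idea; the board dimensions and channel count are assumptions for illustration, and the real StackFrames wrapper in the ReinforcementLearning.jl packages is more general than this.

# Minimal sketch of frame stacking (plain Julia, no RL.jl dependencies).
# Assumption: the Snake board is an (8, 8, C) array; C = 3 is made up for the example.

const N_FRAMES = 4
board_size = (8, 8, 3)

# The stacked observation keeps the newest frame in the last slice.
stacked = zeros(Float32, board_size..., N_FRAMES)

function push_frame!(stacked, frame)
    # Drop the oldest frame, shift the rest forward, append the new one.
    stacked[:, :, :, 1:end-1] .= stacked[:, :, :, 2:end]
    stacked[:, :, :, end] .= frame
    return stacked
end

# Push six dummy boards; only the most recent N_FRAMES survive in the stack.
for t in 1:6
    push_frame!(stacked, fill(Float32(t), board_size))
end
@assert stacked[1, 1, 1, :] == Float32[3, 4, 5, 6]

This is also why `create_model()` reshapes its input to `(SHAPE..., :, batch)` and why the first `CrossCor` layer expects `board_size[end] * N_FRAMES` input channels: the stacked frames are folded into the channel dimension before the convolution.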
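For readers who want to try the experiment after this change: a file like this registers an experiment by adding a Val-tagged method to `RLCore.Experiment`, and the returned `Experiment` bundle (`agent`, `env`, `stop_condition`, `hook`, `description`) is then executed with `run`. The exact tag symbols and keyword arguments in the sketch below are assumptions inferred from how other experiments in this ecosystem are declared, not something this diff shows.

# Hypothetical invocation sketch; the Val symbols and keywords are assumptions.
using ReinforcementLearning  # assumed entry point that re-exports RLCore

ex = RLCore.Experiment(
    Val(:JuliaRL),
    Val(:DQN),
    Val(:SnakeGame),
    nothing;
    save_dir = "snake_dqn",  # hypothetical directory for checkpoints and the tb_log folder
)
run(ex)  # drives ex.agent in ex.env until ex.stop_condition fires, calling ex.hook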