1
- import CUDA, DEQExperiments, FluxMPI, Logging, Lux, OneHotArrays, Optimisers, PyCall,
2
- Random, Setfield, SimpleConfig, Statistics, Wandb
1
+ import CUDA,
2
+ DEQExperiments,
3
+ FluxMPI,
4
+ Logging,
5
+ Lux,
6
+ OneHotArrays,
7
+ Optimisers,
8
+ PyCall,
9
+ Random,
10
+ Setfield,
11
+ SimpleConfig,
12
+ Statistics,
13
+ Wandb
3
14
import Lux. Training
4
15
import ComponentArrays as CA
5
16
@@ -10,8 +21,10 @@ function get_dataloaders(; augment, data_root, eval_batchsize, train_batchsize)
10
21
11
22
tf. config. set_visible_devices ([], " GPU" )
12
23
13
- ds_train, ds_test = tfds. load (" cifar10" ; split= [" train" , " test" ], as_supervised= true ,
14
- data_dir= data_root)
24
+ ds_train, ds_test = tfds. load (" cifar10" ;
25
+ split= [" train" , " test" ],
26
+ as_supervised= true ,
27
+ data_dir= data_root)
15
28
16
29
image_mean = tf. constant ([[[0.4914f0 , 0.4822f0 , 0.4465f0 ]]])
17
30
image_std = tf. constant ([[[0.2023f0 , 0.1994f0 , 0.2010f0 ]]])
@@ -50,12 +63,12 @@ function get_dataloaders(; augment, data_root, eval_batchsize, train_batchsize)
50
63
ds_test = ds_test. prefetch (tf. data. AUTOTUNE). repeat (1 )
51
64
52
65
return (tfds. as_numpy (ds_train. batch (train_batchsize)),
53
- tfds. as_numpy (ds_test. batch (eval_batchsize)))
66
+ tfds. as_numpy (ds_test. batch (eval_batchsize)))
54
67
end
55
68
56
69
function _data_postprocess (image, label)
57
70
return (Lux. gpu (permutedims (image, (3 , 2 , 4 , 1 ))),
58
- Lux. gpu (OneHotArrays. onehotbatch (label, 0 : 9 )))
71
+ Lux. gpu (OneHotArrays. onehotbatch (label, 0 : 9 )))
59
72
end
60
73
61
74
function main (filename, args)
@@ -85,12 +98,18 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
85
98
end
86
99
vjp_rule = Training. ZygoteVJP ()
87
100
88
- DEQExperiments. warmup_model (loss_function, model, tstate. parameters, tstate. states, cfg;
89
- transform_input= Lux. gpu)
90
-
91
- ds_train, ds_test = get_dataloaders (; cfg. dataset. augment, cfg. dataset. data_root,
92
- cfg. dataset. eval_batchsize,
93
- cfg. dataset. train_batchsize)
101
+ DEQExperiments. warmup_model (loss_function,
102
+ model,
103
+ tstate. parameters,
104
+ tstate. states,
105
+ cfg;
106
+ transform_input= Lux. gpu)
107
+
108
+ ds_train, ds_test = get_dataloaders (;
109
+ cfg. dataset. augment,
110
+ cfg. dataset. data_root,
111
+ cfg. dataset. eval_batchsize,
112
+ cfg. dataset. train_batchsize)
94
113
_, ds_train_iter = iterate (ds_train)
95
114
96
115
# Setup
@@ -123,9 +142,11 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
123
142
end
124
143
125
144
# Setup Logging
126
- loggers = DEQExperiments. create_logger (log_dir, cfg. train. total_steps - initial_step,
127
- cfg. train. total_steps - initial_step, expt_name,
128
- SimpleConfig. flatten_configuration (cfg))
145
+ loggers = DEQExperiments. create_logger (log_dir,
146
+ cfg. train. total_steps - initial_step,
147
+ cfg. train. total_steps - initial_step,
148
+ expt_name,
149
+ SimpleConfig. flatten_configuration (cfg))
129
150
130
151
best_test_accuracy = 0
131
152
@@ -145,7 +166,7 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
145
166
# LR Update
146
167
lr_new = sched (step + 1 )
147
168
Setfield. @set! tstate. optimizer_state = Optimisers. adjust (tstate. optimizer_state,
148
- lr_new)
169
+ lr_new)
149
170
150
171
accuracy = DEQExperiments. accuracy (Lux. cpu (stats. y_pred), Lux. cpu (y))
151
172
residual = abs (Statistics. mean (stats. residual))
@@ -154,7 +175,8 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
154
175
loggers. progress_loggers. train. avg_meters. batch_time (data_time +
155
176
step_stats. fwd_time +
156
177
step_stats. bwd_time +
157
- step_stats. opt_time, bsize)
178
+ step_stats. opt_time,
179
+ bsize)
158
180
loggers. progress_loggers. train. avg_meters. data_time (data_time, bsize)
159
181
loggers. progress_loggers. train. avg_meters. fwd_time (step_stats. fwd_time, bsize)
160
182
loggers. progress_loggers. train. avg_meters. bwd_time (step_stats. bwd_time, bsize)
@@ -208,7 +230,7 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
208
230
loggers. progress_loggers. eval. avg_meters. ce_loss (stats. ce_loss, bsize)
209
231
loggers. progress_loggers. eval. avg_meters. skip_loss (stats. skip_loss, bsize)
210
232
loggers. progress_loggers. eval. avg_meters. residual (abs (Statistics. mean (stats. residual)),
211
- bsize)
233
+ bsize)
212
234
loggers. progress_loggers. eval. avg_meters. top1 (acc, bsize)
213
235
loggers. progress_loggers. eval. avg_meters. top5 (- 1 , bsize)
214
236
loggers. progress_loggers. eval. avg_meters. nfe (stats. nfe, bsize)
@@ -233,8 +255,9 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
233
255
end
234
256
235
257
ckpt = (tstate= tstate, step= initial_step)
236
- DEQExperiments. save_checkpoint (ckpt; is_best,
237
- filename= joinpath (ckpt_dir, " model_$(step) .jlso" ))
258
+ DEQExperiments. save_checkpoint (ckpt;
259
+ is_best,
260
+ filename= joinpath (ckpt_dir, " model_$(step) .jlso" ))
238
261
end
239
262
end
240
263
0 commit comments