using BenchmarkTools
using Optimisers
using Functors
using Zygote, Flux
# Return a vector of all trainable leaf arrays of `x`, collected by pushing
# onto an accumulator while `fmap`-walking the model's trainable structure.
function trainables1(x)
    # A bare numeric leaf is its own single parameter.
    Optimisers.isnumeric(x) && return [x]
    arrays = AbstractArray[]
    # Stop recursing at numeric leaves that Functors also considers leaves.
    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    # NOTE(review): the visible source was cut off here (diff hunk boundary);
    # returning the accumulator is the only sensible tail — confirm upstream.
    return arrays
end
using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk

# Walk type dispatched on by Functors' `execute`/`ExcludeWalk` machinery;
# its call overload visits only `Optimisers.trainable` children.
struct TrainableWalk2 <: AbstractWalk end
# Apply `recurse` to each trainable child of `x` (zipped with the trainable
# children of any extra arguments `ys`) and flatten the per-child results
# into one vector; `init = []` keeps nodes without children well-defined.
# (An earlier revision left a debug `@show` here — removed.)
function (walk::TrainableWalk2)(recurse, x, ys...)
    x_children = Optimisers.trainable(x)
    ys_children = map(Optimisers.trainable, ys)
    res = map(recurse, x_children, ys_children...)
    return reduce(vcat, values(res), init = [])
end
# Collect trainable leaves of `x` functionally via `ExcludeWalk`:
# each accepted leaf is wrapped as `[x]` so the walk can `vcat` the pieces.
function trainables2(x)
    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
    return execute(ExcludeWalk(TrainableWalk2(), x -> [x], exclude), x)
end
# Variant walk: same recursion over trainable children, but flattens with a
# splatted `vcat(values(res)...)` instead of `reduce(vcat, ...; init = [])`.
struct TrainableWalk3 <: AbstractWalk end
# Apply `recurse` to each trainable child of `x` (zipped with the trainable
# children of `ys`) and splat-concatenate the per-child results.
function (walk::TrainableWalk3)(recurse, x, ys...)
    x_children = Optimisers.trainable(x)
    ys_children = map(Optimisers.trainable, ys)
    res = map(recurse, x_children, ys_children...)
    return vcat(values(res)...)
end
# Like `trainables2`, but with a simpler exclude predicate (no `isleaf`
# check) and the splatting `TrainableWalk3`.
function trainables3(x)
    exclude(x) = Optimisers.isnumeric(x)
    return execute(ExcludeWalk(TrainableWalk3(), x -> [x], exclude), x)
end
# Scalar loss over a collection of parameter arrays: the grand total of all
# entries. Used below to exercise Zygote differentiation through the
# parameter collectors.
function floss(ps)
    # `sum(sum, ps)` gives the same value as `sum([sum(p) for p in ps])`
    # without materialising the intermediate vector of per-array sums.
    return sum(sum, ps)
end
31
54
32
55
using Flux
33
56
# Benchmark the three trainables implementations (and Zygote gradients of a
# loss over their outputs) on a small Flux model.
function perf()
    # NOTE(review): the original had two extra `Dense(128 => 128, relu)`
    # layers dangling *after* the closing paren of `Chain(...)` — a discarded
    # tuple, i.e. dead code, removed here. The layer sizes inside the chain
    # are also inconsistent (BatchNorm(128) feeding Dense(3 => 2)); harmless
    # for these benchmarks because the model is only walked for parameters,
    # never called — but confirm the intended architecture.
    m = Chain(Dense(128 => 128, relu),
              Dense(128 => 128, relu),
              BatchNorm(128),
              Dense(3 => 2),
              x -> x^2)

    println("trainables1")
    @btime trainables1($m)
    println("trainables2")
    @btime trainables2($m)
    println("trainables3")
    @btime trainables3($m)
    println()

    # gradient(m -> floss(trainables1(m)), m) # non-differentiable since it mutates an accumulator
    println("gradient trainables2")
    @btime gradient(m -> floss(trainables2(m)), $m)
    println("gradient trainables3")
    @btime gradient(m -> floss(trainables3(m)), $m)
end

# Refresh Zygote's generated adjoint code before timing, then run.
Zygote.refresh()
perf()