@@ -2,18 +2,33 @@ using BenchmarkTools
using Optimisers
using Functors
using Zygote, Flux
+ using ChainRulesCore

function trainables1(x)
-     Optimisers.isnumeric(x) && return [x]
    arrays = AbstractArray[]
-     exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
+     exclude(x) = Optimisers.isnumeric(x)
    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    return arrays
end

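+ # Pullback helper for trainables1: walks the model with the same exclude/walk as the
+ # forward pass, consuming one entry of Δ per trainable array, so that the nested
+ # (structural) tangent expected by Zygote is rebuilt in the same order.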
+ function ∇trainables1(x, Δ)
+     exclude(x) = Optimisers.isnumeric(x)
+     i = 0
+     return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
+         return Δ[i += 1]
+     end
+ end
+
+
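+ # Custom chain rule: trainables1 mutates an array while collecting the parameters, so
+ # Zygote cannot differentiate it directly; the rrule delegates to ∇trainables1 instead.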
+ function ChainRulesCore.rrule(::typeof(trainables1), x)
+     y = trainables1(x)
+     trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
+     return y, trainables_back
+ end
+
############

using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
@@ -49,33 +64,52 @@


function floss(ps)
-     sum([sum(p) for p in ps])
+     sum([sum(abs2, p) for p in ps])
end

using Flux

function perf()
    m = Chain(Dense(128 => 128, relu),
              Dense(128 => 128, relu),
-               BatchNorm(128), Dense(3 => 2), x -> x^2)
+               BatchNorm(128),
+               x -> x^2,
              Dense(128 => 128, relu),
-               Dense(128 => 128, relu)
+               Dense(128 => 128, relu))

    println("trainables1")
-     @btime trainables1($m)
+     @btime floss(trainables1($m))
    println("trainables2")
-     @btime trainables2($m)
+     @btime floss(trainables2($m))
    println("trainables3")
-     @btime trainables3($m)
+     @btime floss(trainables3($m))
    println()

-
-     # gradient(m -> floss(trainables1(m)), #m) # non differentiable since mutating
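+     # with the rrule defined above, trainables1 is differentiable, so its gradient can be benchmarked too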
+     println("gradient trainables1")
+     @btime gradient(m -> floss(trainables1(m)), $m)
    println("gradient trainables2")
    @btime gradient(m -> floss(trainables2(m)), $m)
    println("gradient trainables3")
    @btime gradient(m -> floss(trainables3(m)), $m)
+
+     nothing
end

Zygote.refresh()
- perf()
+ perf()
+
+
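+ # Correctness check: gradients obtained through trainables1 and trainables2 should agree,
+ # and the non-trainable BatchNorm statistics (μ) should receive no gradient.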
+ m = Chain(Dense(128 => 128, relu),
+           Dense(128 => 128, relu),
+           BatchNorm(128),
+           x -> x^2,
+           Dense(128 => 128, relu),
+           Dense(128 => 128, relu))
+
+ using Test
+
+ floss(trainables1(m))
+ g1 = gradient(m -> floss(trainables1(m)), m)[1]
+ g2 = gradient(m -> floss(trainables2(m)), m)[1]
+ @test g1.layers[1].weight ≈ g2.layers[1].weight
+ @test g1.layers[2].weight ≈ g2.layers[2].weight
+ @test g1.layers[3].μ === nothing
+ @test g2.layers[3].μ === nothing