Skip to content

Commit d6285e5

Browse files
cl/trainables
1 parent d11834c commit d6285e5

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

src/trainables.jl

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,33 @@ using BenchmarkTools
22
using Optimisers
33
using Functors
44
using Zygote, Flux
5+
using ChainRulesCore
56

67
function trainables1(x)
7-
Optimisers.isnumeric(x) && return [x]
88
arrays = AbstractArray[]
9-
exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
9+
exclude(x) = Optimisers.isnumeric(x)
1010
fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
1111
push!(arrays, y)
1212
return y
1313
end
1414
return arrays
1515
end
1616

17+
function ∇trainables1(x, Δ)
18+
exclude(x) = Optimisers.isnumeric(x)
19+
i = 0
20+
return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
21+
return Δ[i+=1]
22+
end
23+
end
24+
25+
26+
function ChainRulesCore.rrule(::typeof(trainables1), x)
27+
y = trainables1(x)
28+
trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
29+
return y, trainables_back
30+
end
31+
1732
############
1833

1934
using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
@@ -49,33 +64,52 @@ end
4964

5065

5166
function floss(ps)
52-
sum([sum(p) for p in ps])
67+
sum([sum(abs2, p) for p in ps])
5368
end
5469

5570
using Flux
5671

5772
function perf()
5873
m = Chain(Dense(128 => 128, relu),
5974
Dense(128 => 128, relu),
60-
BatchNorm(128), Dense(3 => 2), x -> x^2)
75+
BatchNorm(128),
76+
x -> x^2,
6177
Dense(128 => 128, relu),
62-
Dense(128 => 128, relu)
78+
Dense(128 => 128, relu))
6379

6480
println("trainables1")
65-
@btime trainables1($m)
81+
@btime floss(trainables1($m))
6682
println("trainables2")
67-
@btime trainables2($m)
83+
@btime floss(trainables2($m))
6884
println("trainables3")
69-
@btime trainables3($m)
85+
@btime floss(trainables3($m))
7086
println()
7187

72-
73-
# gradient(m -> floss(trainables1(m)), #m) # non differentiable since mutating
88+
println("gradient trainables1")
89+
@btime gradient(m -> floss(trainables1(m)), $m)
7490
println("gradient trainables2")
7591
@btime gradient(m -> floss(trainables2(m)), $m)
7692
println("gradient trainables3")
7793
@btime gradient(m -> floss(trainables3(m)), $m)
94+
95+
nothing
7896
end
7997

8098
Zygote.refresh()
81-
perf()
99+
perf()
100+
101+
102+
m = Chain(Dense(128 => 128, relu),
103+
Dense(128 => 128, relu),
104+
BatchNorm(128),
105+
x -> x^2,
106+
Dense(128 => 128, relu),
107+
Dense(128 => 128, relu))
108+
109+
floss(trainables1(m))
110+
g1 = gradient(m -> floss(trainables1(m)), m)[1]
111+
g2 = gradient(m -> floss(trainables2(m)), m)[1]
112+
@test g1.layers[1].weight g2.layers[1].weight
113+
@test g1.layers[1].weight g2.layers[1].weight
114+
@test g1.layers[3].μ === nothing
115+
@test g2.layers[3].μ === nothing

0 commit comments

Comments
 (0)