Skip to content

Commit a6e40a0

Browse files
trainables
1 parent 43e51a2 commit a6e40a0

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

src/trainables.jl

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
1+
using BenchmarkTools
2+
using Optimisers
3+
using Functors
4+
using Zygote, Flux
25

36
function trainables1(x)
4-
isnumeric(x) && return [x]
7+
Optimisers.isnumeric(x) && return [x]
58
arrays = AbstractArray[]
6-
fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
9+
exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
10+
fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
711
push!(arrays, y)
812
return y
913
end
@@ -17,19 +21,61 @@ using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
1721
struct TrainableWalk2 <: AbstractWalk end
1822

1923
function (walk::TrainableWalk2)(recurse, x, ys...)
20-
x_children = _values(Optimisers.trainable(x))
24+
x_children = Optimisers.trainable(x)
2125
ys_children = map(Optimisers.trainable, ys)
22-
res = _map(recurse, x_children, ys_children...)
23-
@show _values(res)
24-
return _values(res)
26+
res = map(recurse, x_children, ys_children...)
27+
return reduce(vcat, values(res),init=[])
2528
end
2629

2730
function trainables2(x)
2831
exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
29-
return execute(ExcludeWalk(TrainableWalk2(), x -> x, exclude), x)
32+
return execute(ExcludeWalk(TrainableWalk2(), x ->[x], exclude), x)
33+
end
34+
35+
36+
struct TrainableWalk3 <: AbstractWalk end
37+
38+
function (walk::TrainableWalk3)(recurse, x, ys...)
39+
x_children = Optimisers.trainable(x)
40+
ys_children = map(Optimisers.trainable, ys)
41+
res = map(recurse, x_children, ys_children...)
42+
return vcat(values(res)...)
43+
end
44+
45+
function trainables3(x)
46+
exclude(x) = Optimisers.isnumeric(x)
47+
return execute(ExcludeWalk(TrainableWalk3(), x ->[x], exclude), x)
48+
end
49+
50+
51+
function floss(ps)
52+
sum([sum(p) for p in ps])
3053
end
3154

3255
using Flux
3356

34-
m = Chain(Dense(2 => 3, relu), BatchNorm(3), Dense(3 => 2))
35-
trainables2(m)
57+
function perf()
58+
m = Chain(Dense(128 => 128, relu),
59+
Dense(128 => 128, relu),
60+
BatchNorm(128), Dense(3 => 2), x -> x^2)
61+
Dense(128 => 128, relu),
62+
Dense(128 => 128, relu)
63+
64+
println("trainables1")
65+
@btime trainables1($m)
66+
println("trainables2")
67+
@btime trainables2($m)
68+
println("trainables3")
69+
@btime trainables3($m)
70+
println()
71+
72+
73+
# gradient(m -> floss(trainables1(m)), #m) # non differentiable since mutating
74+
println("gradient trainables2")
75+
@btime gradient(m -> floss(trainables2(m)), $m)
76+
println("gradient trainables3")
77+
@btime gradient(m -> floss(trainables3(m)), $m)
78+
end
79+
80+
Zygote.refresh()
81+
perf()

0 commit comments

Comments
 (0)