Commit 9f47a66

trainables
1 parent d6285e5 commit 9f47a66

6 files changed: +145 -102 lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ Such restrictions are also obeyed by this function for flattening a model:
 ```@docs
 Optimisers.destructure
 Optimisers.Restructure
+Optimisers.trainables
 ```

 ## Rule Definition

src/Optimisers.jl

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ include("adjust.jl")
 include("destructure.jl")
 export destructure

+include("trainables.jl")
+export trainables
+
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,

src/destructure.jl

Lines changed: 3 additions & 3 deletions
@@ -66,7 +66,7 @@ function _flatten(x)
   isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
-  off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+  off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
     push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
@@ -76,9 +76,9 @@ function _flatten(x)
   return reduce(vcat, arrays), off, len[]
 end

-struct _TrainableStructWalk <: AbstractWalk end
+struct TrainableStructWalk <: AbstractWalk end

-(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
+(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))

 _vec(x::Number) = LinRange(x,x,1)
 _vec(x::AbstractArray) = vec(x)
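
The renamed `TrainableStructWalk` is what makes `destructure` (and the new `trainables`) recurse only into the fields returned by `Optimisers.trainable`. A minimal sketch of that restriction, reusing the `MyLayer` example from the docstring added in this commit (the values here are illustrative, not part of the diff):

```julia
using Optimisers, Functors

struct MyLayer
    w
    b
end
Functors.@functor MyLayer

# Only `w` is declared trainable, so the trainable-aware walk skips `b`.
Optimisers.trainable(x::MyLayer) = (; w = x.w)

layer = MyLayer([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])

flat, re = destructure(layer)   # flat == [1.0, 2.0, 3.0]; `b` is not flattened
re(2 .* flat)                   # rebuilds a MyLayer with doubled `w`, `b` untouched
```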

src/trainables.jl

Lines changed: 35 additions & 98 deletions
@@ -1,115 +1,52 @@
-using BenchmarkTools
-using Optimisers
-using Functors
-using Zygote, Flux
-using ChainRulesCore

-function trainables1(x)
+"""
+    trainables(x)
+
+Return an iterable over all the trainable parameters in `x`, that is all the numerical
+arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).
+
+Parameters appearing multiple times in the model will be present only once in the output.
+
+See also [`destructure`](@ref).
+
+# Examples
+
+```jldoctest
+julia> struct MyLayer
+         w
+         b
+       end
+
+julia> Functors.@functor MyLayer
+
+julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,)  # only w is trainable in this example
+
+julia> x = MyLayer([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
+
+julia> trainables(x)
+1-element Vector{AbstractArray}:
+ [1.0, 2.0, 3.0]
+"""
+function trainables(x)
     arrays = AbstractArray[]
     exclude(x) = Optimisers.isnumeric(x)
-    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
+    fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y
         push!(arrays, y)
         return y
     end
     return arrays
 end

-function ∇trainables1(x, Δ)
+function ∇trainables(x, Δ)
     exclude(x) = Optimisers.isnumeric(x)
     i = 0
-    return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
+    return fmapstructure(x; exclude, walk = Optimisers.TrainableStructWalk()) do _
         return Δ[i+=1]
     end
 end

-
-function ChainRulesCore.rrule(::typeof(trainables1), x)
-    y = trainables1(x)
-    trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
+function ChainRulesCore.rrule(::typeof(trainables), x)
+    y = trainables(x)
+    trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
     return y, trainables_back
 end
-
-############
-
-using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
-
-struct TrainableWalk2 <: AbstractWalk end
-
-function (walk::TrainableWalk2)(recurse, x, ys...)
-    x_children = Optimisers.trainable(x)
-    ys_children = map(Optimisers.trainable, ys)
-    res = map(recurse, x_children, ys_children...)
-    return reduce(vcat, values(res), init=[])
-end
-
-function trainables2(x)
-    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
-    return execute(ExcludeWalk(TrainableWalk2(), x -> [x], exclude), x)
-end
-
-
-struct TrainableWalk3 <: AbstractWalk end
-
-function (walk::TrainableWalk3)(recurse, x, ys...)
-    x_children = Optimisers.trainable(x)
-    ys_children = map(Optimisers.trainable, ys)
-    res = map(recurse, x_children, ys_children...)
-    return vcat(values(res)...)
-end
-
-function trainables3(x)
-    exclude(x) = Optimisers.isnumeric(x)
-    return execute(ExcludeWalk(TrainableWalk3(), x -> [x], exclude), x)
-end
-
-
-function floss(ps)
-    sum([sum(abs2, p) for p in ps])
-end
-
-using Flux
-
-function perf()
-    m = Chain(Dense(128 => 128, relu),
-              Dense(128 => 128, relu),
-              BatchNorm(128),
-              x -> x^2,
-              Dense(128 => 128, relu),
-              Dense(128 => 128, relu))
-
-    println("trainables1")
-    @btime floss(trainables1($m))
-    println("trainables2")
-    @btime floss(trainables2($m))
-    println("trainables3")
-    @btime floss(trainables3($m))
-    println()
-
-    println("gradient trainables1")
-    @btime gradient(m -> floss(trainables1(m)), $m)
-    println("gradient trainables2")
-    @btime gradient(m -> floss(trainables2(m)), $m)
-    println("gradient trainables3")
-    @btime gradient(m -> floss(trainables3(m)), $m)
-
-    nothing
-end
-
-Zygote.refresh()
-perf()
-
-
-m = Chain(Dense(128 => 128, relu),
-          Dense(128 => 128, relu),
-          BatchNorm(128),
-          x -> x^2,
-          Dense(128 => 128, relu),
-          Dense(128 => 128, relu))
-
-floss(trainables1(m))
-g1 = gradient(m -> floss(trainables1(m)), m)[1]
-g2 = gradient(m -> floss(trainables2(m)), m)[1]
-@test g1.layers[1].weight ≈ g2.layers[1].weight
-@test g1.layers[1].weight ≈ g2.layers[1].weight
-@test g1.layers[3].μ === nothing
-@test g2.layers[3].μ === nothing
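
Since the commit also registers a `ChainRulesCore.rrule` for `trainables`, gradients flow through the returned parameter list, which is what the new tests below exercise. A small usage sketch, assuming Zygote is loaded as in the package's test suite (the `model` NamedTuple and its field names are illustrative, not part of this commit):

```julia
using Optimisers, Zygote

# A toy model: the numeric arrays are the trainable parameters, `tanh` is not.
model = (weight = [1.0, 2.0, 3.0], bias = [4.0, 5.0], activation = tanh)

# Same style of loss as in test/trainables.jl: sum of squares of every parameter.
loss(m) = sum([sum(abs2, p) for p in trainables(m)])

trainables(model)                     # 2-element Vector{AbstractArray}: the two arrays
g = Zygote.gradient(loss, model)[1]
# g.weight == [2.0, 4.0, 6.0], g.bias == [8.0, 10.0]; `activation` gets no gradient
```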

test/runtests.jl

Lines changed: 4 additions & 1 deletion
@@ -10,7 +10,7 @@ Random.seed!(1)

 struct Foo; x; y; end
 Functors.@functor Foo
-Optimisers.trainable(x::Foo) = (x.y, x.x)
+Optimisers.trainable(x::Foo) = (; x.y, x.x)

 struct TwoThirds a; b; c; end
 Functors.@functor TwoThirds (a, c)
@@ -539,6 +539,9 @@
 @testset verbose=true "Destructure" begin
   include("destructure.jl")
 end
+@testset verbose=true "Trainables" begin
+  include("trainables.jl")
+end
 @testset verbose=true "Optimisation Rules" begin
   include("rules.jl")
 end

test/trainables.jl

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+
+m1 = collect(1:3.0)
+m2 = (collect(1:3.0), collect(4:6.0))
+m3 = (x = m1, y = sin, z = collect(4:6.0))
+
+m4 = (x = m1, y = m1, z = collect(4:6.0))  # tied
+m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
+m6 = (a = m1, b = [4.0 + im], c = m1)
+
+m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
+m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
+
+mat = Float32[4 6; 5 7]
+m9 = (a = m1, b = mat, c = [mat, m1])
+
+@testset "trainables" begin
+    ps = trainables(m1)
+    @test ps isa Vector
+    @test length(ps) == 1
+    @test ps[1] == m1
+
+    ps = trainables(m2)
+    @test ps isa Vector
+    @test length(ps) == 2
+    @test ps[1] == m2[1]
+    @test ps[2] == m2[2]
+
+    ps = trainables(m3)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == 4:6
+
+    ps = trainables(m4)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == 4:6
+
+    ps = trainables(m5)
+    @test length(ps) == 3
+    @test ps[1] == 1:3
+    @test ps[2] == 4:6
+    @test ps[3] == 4:6
+
+    ps = trainables(m6)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == ComplexF64[4.0 + 1.0im]
+
+    ps = trainables(m7)
+    @test length(ps) == 1
+    @test ps[1] == [1.0, 2.0, 3.0]
+
+    ps = trainables(m8)
+    @test length(ps) == 3
+    @test ps[1] == 1:3
+    @test ps[2] == [4.0]
+    @test ps[3] == [5.0]
+
+    ps = trainables(m9)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == mat
+end
+
+@testset "gradient" begin
+    loss(m) = sum([sum(abs2, p) for p in trainables(m)])
+    g = gradient(loss, m1)[1]
+    @test g == [2.0, 4.0, 6.0]
+
+    g = gradient(loss, m2)[1]
+    @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0])
+
+    g = gradient(loss, m3)[1]
+    @test g.x == [2.0, 4.0, 6.0]
+    @test g.y === nothing
+    @test g.z == [8.0, 10.0, 12.0]
+
+    g = gradient(loss, m4)[1]
+    @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0])
+    g.x === g.y  # shared gradient for shared weights
+
+    g = gradient(loss, m5)[1]
+    @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing))
+
+    g = gradient(loss, m6)[1]
+    @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0])
+
+    g = gradient(loss, m7)[1]
+    @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing)
+
+    g = gradient(loss, m8)[1]
+    @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0])
+    @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing)
+    @test g[3] == [[10.0]]
+
+    g = gradient(loss, m9)[1]
+    @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]])
+end
+
