Skip to content

add trainables #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Such restrictions are also obeyed by this function for flattening a model:
```@docs
Optimisers.destructure
Optimisers.Restructure
Optimisers.trainables
```

## Rule Definition
Expand Down
26 changes: 26 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,29 @@ flat, re = destructure(params)
end
```

## Collecting all trainable parameters

Sometimes it is useful to collect all trainable parameters in a model,
similarly to what [`destructure`](@ref Optimisers.destructure) does but without
concatenating the arrays into a flat vector.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

```julia
julia> using Flux, Optimisers

julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));

julia> trainables(model)
6-element Vector{AbstractArray}:
Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061]
Float32[0.0, 0.0, 0.0]
Float32[0.0, 0.0, 0.0]
Float32[1.0, 1.0, 1.0]
Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
Float32[0.0, 0.0]

julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);

julia> g = gradient(l2reg, model)[1];
```
Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.
3 changes: 3 additions & 0 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ include("adjust.jl")
include("destructure.jl")
export destructure

include("trainables.jl")
export trainables

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
Expand Down
9 changes: 5 additions & 4 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ function _flatten(x)
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
push!(arrays, _vec(y))
o = len[]
len[] = o + length(y)
o
end
isempty(arrays) && return Bool[], off, 0
reduce(vcat, arrays), off, len[]
return reduce(vcat, arrays), off, len[]
end

struct _TrainableStructWalk <: AbstractWalk end
struct TrainableStructWalk <: AbstractWalk end

(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))

_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)
Expand Down Expand Up @@ -174,3 +174,4 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
nothing, _ -> (NoT,)
end

1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ and `trainable(x)` must contain a subset of these.
"""
trainable(x) = functor(x)[1]

# Like trainable(x), but returns all functor children, with `nothing` in place of the non-trainable ones
_trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
Expand Down
59 changes: 59 additions & 0 deletions src/trainables.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@

"""
    trainables(x)

Return a list over all the trainable parameters in `x`, that is all the numerical
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).

Parameters appearing multiple times in the model (tied weights) will be present only once in the output.

See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.

# Examples

```jldoctest
julia> struct MyLayer
         w
         b
       end

julia> Functors.@functor MyLayer

julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example

julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);

julia> trainables(x)
1-element Vector{AbstractArray}:
 [1.0, 2.0, 3.0]

julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);

julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
 [1.0, 2.0]
 [3.0]
```
"""
function trainables(x)
    arrays = AbstractArray[]
    exclude(x) = Optimisers.isnumeric(x)
    # Walk only through trainable children (TrainableStructWalk), pushing each
    # numeric leaf; fmap's cache ensures tied arrays are visited only once.
    fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    return arrays
end

function ∇trainables(x, Δ)
    exclude(x) = Optimisers.isnumeric(x)
    # Re-walk the model with the same walk/exclude as `trainables`, so leaves
    # are visited in the same order, handing out the cotangent arrays from Δ
    # one by one to build a structural tangent for x.
    next = Ref(0)
    return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _
        return Δ[next[] += 1]
    end
end

function ChainRulesCore.rrule(::typeof(trainables), x)
    # Forward pass: the list of trainable arrays.
    ps = trainables(x)
    # Pullback: scatter the vector of cotangents back into a structural
    # tangent matching the shape of x.
    function trainables_back(Δ)
        return (NoTangent(), ∇trainables(x, unthunk(Δ)))
    end
    return ps, trainables_back
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Random.seed!(1)

struct Foo; x; y; end
Functors.@functor Foo
Optimisers.trainable(x::Foo) = (x.y, x.x)
Optimisers.trainable(x::Foo) = (; x.y, x.x)

struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Expand Down Expand Up @@ -539,6 +539,9 @@ end
@testset verbose=true "Destructure" begin
include("destructure.jl")
end
@testset verbose=true "Trainables" begin
include("trainables.jl")
end
@testset verbose=true "Optimisation Rules" begin
include("rules.jl")
end
Expand Down
115 changes: 115 additions & 0 deletions test/trainables.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@

# Shared fixtures for the trainables tests. Note the deliberate aliasing:
# m4, m6 and m8 reuse m1 by reference to exercise tied-weight handling.
m1 = collect(1:3.0)                     # a single Float64 vector
m2 = (collect(1:3.0), collect(4:6.0))   # tuple of two arrays
m3 = (x = m1, y = sin, z = collect(4:6.0))  # contains a non-numeric leaf (sin)

m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))  # nesting + non-numeric leaves + ties
m6 = (a = m1, b = [4.0 + im], c = m1)   # complex eltype, plus tied m1
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))  # only `a` and `c` are functor children; trainable is restricted further (see runtests.jl)
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]  # vector of heterogeneous containers with ties

mat = Float32[4 6; 5 7]
m9 = (a = m1, b = mat, c = [mat, m1])   # mixed eltypes, matrix tied into c

@testset "trainables" begin
# Single array: returned as a one-element list.
ps = trainables(m1)
@test ps isa Vector
@test length(ps) == 1
@test ps[1] == m1

# Tuple of arrays: one entry each, in traversal order.
ps = trainables(m2)
@test ps isa Vector
@test length(ps) == 2
@test ps[1] == m2[1]
@test ps[2] == m2[2]

# Non-numeric leaves (functions) are skipped.
ps = trainables(m3)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == 4:6

# Tied weights (m4.x === m4.y) appear only once in the output.
ps = trainables(m4)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == 4:6

# Nested containers with booleans and repeated ties: still de-duplicated.
ps = trainables(m5)
@test length(ps) == 3
@test ps[1] == 1:3
@test ps[2] == 4:6
@test ps[3] == 4:6

# Complex-valued arrays count as numeric; tied m1 collected once.
ps = trainables(m6)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == ComplexF64[4.0 + 1.0im]

# TwoThirds restricts functor children and trainable further:
# only one array is reachable through `trainable`.
ps = trainables(m7)
@test length(ps) == 1
@test ps[1] == [1.0, 2.0, 3.0]

# Vector of structs: Foo's tied fields collapse to one entry,
# plus the inner [4.0] and [[5.0]] leaves.
ps = trainables(m8)
@test length(ps) == 3
@test ps[1] == 1:3
@test ps[2] == [4.0]
@test ps[3] == [5.0]

# mat and m1 appear both directly and inside c; each collected once.
ps = trainables(m9)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == mat
end

@testset "gradient" begin
    # L2-style loss touching every trainable array, exercising the custom
    # rrule for `trainables` (gradient of sum(abs2, p) is 2p).
    loss(m) = sum([sum(abs2, p) for p in trainables(m)])

    g = gradient(loss, m1)[1]
    @test g == [2.0, 4.0, 6.0]

    g = gradient(loss, m2)[1]
    @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0])

    # Non-trainable leaf `y = sin` gets a `nothing` gradient.
    g = gradient(loss, m3)[1]
    @test g.x == [2.0, 4.0, 6.0]
    @test g.y === nothing
    @test g.z == [8.0, 10.0, 12.0]

    g = gradient(loss, m4)[1]
    @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0])
    # Was a bare expression with no `@test`, so the tied-gradient identity
    # was never actually checked.
    @test g.x === g.y # shared gradient for shared weights

    g = gradient(loss, m5)[1]
    @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing))

    # Complex entry: gradient of sum(abs2, z) is 2z (conjugate convention folded in).
    g = gradient(loss, m6)[1]
    @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0])

    # Only the single trainable array of TwoThirds receives a gradient.
    g = gradient(loss, m7)[1]
    @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing)

    g = gradient(loss, m8)[1]
    @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0])
    @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing)
    @test g[3] == [[10.0]]

    g = gradient(loss, m9)[1]
    @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]])
end

@testset "second order derivatives" begin
# Minimal dense-like layer; all fields are trainable by default.
struct DenseLayer
w
b
end

Functors.@functor DenseLayer

loss(m) = sum([sum(abs2, p) for p in trainables(m)])

model = DenseLayer([1. 2.; 3. 4.], [0., 0.])

# Differentiate through the gradient itself: checks that the rrule for
# `trainables` is itself differentiable. For loss(m) = Σ|p|², the inner
# gradient is 2p, so the outer gradient is d/dp Σ|2p|² = 8p.
g = gradient(m -> loss(gradient(loss, m)), model)[1]
@test g.w == [8.0 16.0; 24.0 32.0]
@test g.b == [0.0, 0.0]
end