From 43e51a2f45783bf08ae5615224c99c3197096b70 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sat, 30 Mar 2024 16:16:09 +0100
Subject: [PATCH 1/9] trainables

---
 src/destructure.jl |  3 ++-
 src/interface.jl   |  1 +
 src/trainables.jl  | 35 +++++++++++++++++++++++++++++++++++
 3 files changed, 38 insertions(+), 1 deletion(-)
 create mode 100644 src/trainables.jl

diff --git a/src/destructure.jl b/src/destructure.jl
index a73e36a6..1a1d28b7 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -73,7 +73,7 @@ function _flatten(x)
     o
   end
   isempty(arrays) && return Bool[], off, 0
-  reduce(vcat, arrays), off, len[]
+  return reduce(vcat, arrays), off, len[]
 end

 struct _TrainableStructWalk <: AbstractWalk end
@@ -174,3 +174,4 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
   @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
   nothing, _ -> (NoT,)
 end
+
diff --git a/src/interface.jl b/src/interface.jl
index 29c1db60..aa5447c0 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -167,6 +167,7 @@ and `trainable(x)` must contain a subset of these.
 """
 trainable(x) = functor(x)[1]

+# like trainable(x), but also tries to output non-trainable children giving value nothing
 _trainable(x) = _trainable(functor(x)[1], trainable(x))
 _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
 _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
diff --git a/src/trainables.jl b/src/trainables.jl
new file mode 100644
index 00000000..f23685bd
--- /dev/null
+++ b/src/trainables.jl
@@ -0,0 +1,35 @@
+
+
+function trainables1(x)
+    isnumeric(x) && return [x]
+    arrays = AbstractArray[]
+    fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+        push!(arrays, y)
+        return y
+    end
+    return arrays
+end
+
+############
+
+using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
+
+struct TrainableWalk2 <: AbstractWalk end
+
+function (walk::TrainableWalk2)(recurse, x, ys...)
+    x_children = _values(Optimisers.trainable(x))
+    ys_children = map(Optimisers.trainable, ys)
+    res = _map(recurse, x_children, ys_children...)
+    @show _values(res)
+    return _values(res)
+end
+
+function trainables2(x)
+    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
+    return execute(ExcludeWalk(TrainableWalk2(), x -> x, exclude), x)
+end
+
+using Flux
+
+m = Chain(Dense(2 => 3, relu), BatchNorm(3), Dense(3 => 2))
+trainables2(m)
\ No newline at end of file

From a6e40a0900e0c3cb00eae99869ad9a207f829b08 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 1 Apr 2024 09:40:54 +0200
Subject: [PATCH 2/9] trainables

---
 src/trainables.jl | 66 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 56 insertions(+), 10 deletions(-)

diff --git a/src/trainables.jl b/src/trainables.jl
index f23685bd..9b5b2c86 100644
--- a/src/trainables.jl
+++ b/src/trainables.jl
@@ -1,9 +1,13 @@
-
+using BenchmarkTools
+using Optimisers
+using Functors
+using Zygote, Flux

 function trainables1(x)
-    isnumeric(x) && return [x]
+    Optimisers.isnumeric(x) && return [x]
     arrays = AbstractArray[]
-    fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
+    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
         push!(arrays, y)
         return y
     end
@@ -17,19 +21,61 @@ using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
 struct TrainableWalk2 <: AbstractWalk end

 function (walk::TrainableWalk2)(recurse, x, ys...)
-    x_children = _values(Optimisers.trainable(x))
+    x_children = Optimisers.trainable(x)
     ys_children = map(Optimisers.trainable, ys)
-    res = _map(recurse, x_children, ys_children...)
-    @show _values(res)
-    return _values(res)
+    res = map(recurse, x_children, ys_children...)
+    return reduce(vcat, values(res),init=[])
 end

 function trainables2(x)
     exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
-    return execute(ExcludeWalk(TrainableWalk2(), x -> x, exclude), x)
+    return execute(ExcludeWalk(TrainableWalk2(), x ->[x], exclude), x)
+end
+
+
+struct TrainableWalk3 <: AbstractWalk end
+
+function (walk::TrainableWalk3)(recurse, x, ys...)
+    x_children = Optimisers.trainable(x)
+    ys_children = map(Optimisers.trainable, ys)
+    res = map(recurse, x_children, ys_children...)
+    return vcat(values(res)...)
+end
+
+function trainables3(x)
+    exclude(x) = Optimisers.isnumeric(x)
+    return execute(ExcludeWalk(TrainableWalk3(), x ->[x], exclude), x)
 end

+
+function floss(ps)
+    sum([sum(p) for p in ps])
+end
+
 using Flux

-m = Chain(Dense(2 => 3, relu), BatchNorm(3), Dense(3 => 2))
-trainables2(m)
\ No newline at end of file
+function perf()
+    m = Chain(Dense(128 => 128, relu),
+              Dense(128 => 128, relu),
+              BatchNorm(128), Dense(3 => 2), x -> x^2)
+              Dense(128 => 128, relu),
+              Dense(128 => 128, relu)
+
+    println("trainables1")
+    @btime trainables1($m)
+    println("trainables2")
+    @btime trainables2($m)
+    println("trainables3")
+    @btime trainables3($m)
+    println()
+
+
+    # gradient(m -> floss(trainables1(m)), #m) # non differentiable since mutating
+    println("gradient trainables2")
+    @btime gradient(m -> floss(trainables2(m)), $m)
+    println("gradient trainables3")
+    @btime gradient(m -> floss(trainables3(m)), $m)
+end
+
+Zygote.refresh()
+perf()
\ No newline at end of file

From 53e06783396b86b813a39a1fd00e95843dee236f Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 1 Apr 2024 11:09:04 +0200
Subject: [PATCH 3/9] cl/trainables

---
 src/trainables.jl | 56 +++++++++++++++++++++++++++++++++++----------
 1 file changed, 45 insertions(+), 11 deletions(-)

diff --git a/src/trainables.jl b/src/trainables.jl
index 9b5b2c86..fcf940bb 100644
--- a/src/trainables.jl
+++ b/src/trainables.jl
@@ -2,11 +2,11 @@ using BenchmarkTools
 using Optimisers
 using Functors
 using Zygote, Flux
+using ChainRulesCore

 function trainables1(x)
-    Optimisers.isnumeric(x) && return [x]
     arrays = AbstractArray[]
-    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
+    exclude(x) = Optimisers.isnumeric(x)
     fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
         push!(arrays, y)
         return y
@@ -14,6 +14,21 @@ function trainables1(x)
     return arrays
 end

+function ∇trainables1(x, Δ)
+    exclude(x) = Optimisers.isnumeric(x)
+    i = 0
+    return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
+        return Δ[i+=1]
+    end
+end
+
+
+function ChainRulesCore.rrule(::typeof(trainables1), x)
+    y = trainables1(x)
+    trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
+    return y, trainables_back
+end
+
 ############

 using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
@@ -49,7 +64,7 @@ end


 function floss(ps)
-    sum([sum(p) for p in ps])
+    sum([sum(abs2, p) for p in ps])
 end

 using Flux
@@ -57,25 +72,44 @@ function perf()
     m = Chain(Dense(128 => 128, relu),
               Dense(128 => 128, relu),
-              BatchNorm(128), Dense(3 => 2), x -> x^2)
+              BatchNorm(128),
+              x -> x^2,
               Dense(128 => 128, relu),
-              Dense(128 => 128, relu)
+              Dense(128 => 128, relu))

     println("trainables1")
-    @btime trainables1($m)
+    @btime floss(trainables1($m))
     println("trainables2")
-    @btime trainables2($m)
+    @btime floss(trainables2($m))
     println("trainables3")
-    @btime trainables3($m)
+    @btime floss(trainables3($m))
     println()

-
-    # gradient(m -> floss(trainables1(m)), #m) # non differentiable since mutating
+    println("gradient trainables1")
+    @btime gradient(m -> floss(trainables1(m)), $m)
     println("gradient trainables2")
     @btime gradient(m -> floss(trainables2(m)), $m)
     println("gradient trainables3")
     @btime gradient(m -> floss(trainables3(m)), $m)
+
+    nothing
 end

 Zygote.refresh()
-perf()
\ No newline at end of file
+perf()
+
+
+m = Chain(Dense(128 => 128, relu),
+          Dense(128 => 128, relu),
+          BatchNorm(128),
+          x -> x^2,
+          Dense(128 => 128, relu),
+          Dense(128 => 128, relu))
+
+floss(trainables1(m))
+g1 = gradient(m -> floss(trainables1(m)), m)[1]
+g2 = gradient(m -> floss(trainables2(m)), m)[1]
+@test g1.layers[1].weight ≈ g2.layers[1].weight
+@test g1.layers[1].weight ≈ g2.layers[1].weight
+@test g1.layers[3].μ === nothing
+@test g2.layers[3].μ === nothing

From 9131192b65b23e442ae09f176899d2a1942b4480 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 2 Apr 2024 10:31:15 +0200
Subject: [PATCH 4/9] trainables

---
 docs/src/api.md    |   1 +
 src/Optimisers.jl  |   3 +
 src/destructure.jl |   6 +-
 src/trainables.jl  | 133 ++++++++++++---------------------------------
 test/runtests.jl   |   5 +-
 test/trainables.jl |  99 +++++++++++++++++++++++++++++++++++
 6 files changed, 145 insertions(+), 102 deletions(-)
 create mode 100644 test/trainables.jl

diff --git a/docs/src/api.md b/docs/src/api.md
index 661f83bc..f66884e0 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -58,6 +58,7 @@ Such restrictions are also obeyed by this function for flattening a model:
 ```@docs
 Optimisers.destructure
 Optimisers.Restructure
+Optimisers.trainables
 ```

 ## Rule Definition
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index efe28697..3cc98808 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -11,6 +11,9 @@ include("adjust.jl")
 include("destructure.jl")
 export destructure

+include("trainables.jl")
+export trainables
+
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
diff --git a/src/destructure.jl b/src/destructure.jl
index 1a1d28b7..f9950a92 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -66,7 +66,7 @@ function _flatten(x)
   isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
-  off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+  off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
     push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
@@ -76,9 +76,9 @@ function _flatten(x)
   return reduce(vcat, arrays), off, len[]
 end

-struct _TrainableStructWalk <: AbstractWalk end
+struct TrainableStructWalk <: AbstractWalk end

-(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
+(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))

 _vec(x::Number) = LinRange(x,x,1)
 _vec(x::AbstractArray) = vec(x)
diff --git a/src/trainables.jl b/src/trainables.jl
index fcf940bb..151e0aee 100644
--- a/src/trainables.jl
+++ b/src/trainables.jl
@@ -1,115 +1,52 @@
-using BenchmarkTools
-using Optimisers
-using Functors
-using Zygote, Flux
-using ChainRulesCore
-
-function trainables1(x)
+
+"""
+    trainables(x)
+
+Return an iterable over all the trainable parameters in `x`, that is all the numerical
+arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).
+
+Parameters appearing multiple times in the model will be present only once in the output.
+
+See also [`destructure`](@ref).
+
+# Examples
+
+```jldoctest
+julia> struct MyLayer
+         w
+         b
+       end
+
+julia> Functors.@functor MyLayer
+
+julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example
+
+julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);
+
+julia> trainables(x)
+1-element Vector{AbstractArray}:
+ [1.0, 2.0, 3.0]
+"""
+function trainables(x)
     arrays = AbstractArray[]
     exclude(x) = Optimisers.isnumeric(x)
-    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
+    fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y
         push!(arrays, y)
         return y
     end
     return arrays
 end

-function ∇trainables1(x, Δ)
+function ∇trainables(x, Δ)
     exclude(x) = Optimisers.isnumeric(x)
     i = 0
-    return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
+    return fmapstructure(x; exclude, walk = Optimisers.TrainableStructWalk()) do _
         return Δ[i+=1]
     end
 end

-
-function ChainRulesCore.rrule(::typeof(trainables1), x)
-    y = trainables1(x)
-    trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
+function ChainRulesCore.rrule(::typeof(trainables), x)
+    y = trainables(x)
+    trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
     return y, trainables_back
 end
-
-############
-
-using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
-
-struct TrainableWalk2 <: AbstractWalk end
-
-function (walk::TrainableWalk2)(recurse, x, ys...)
-    x_children = Optimisers.trainable(x)
-    ys_children = map(Optimisers.trainable, ys)
-    res = map(recurse, x_children, ys_children...)
-    return reduce(vcat, values(res),init=[])
-end
-
-function trainables2(x)
-    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
-    return execute(ExcludeWalk(TrainableWalk2(), x ->[x], exclude), x)
-end
-
-
-struct TrainableWalk3 <: AbstractWalk end
-
-function (walk::TrainableWalk3)(recurse, x, ys...)
-    x_children = Optimisers.trainable(x)
-    ys_children = map(Optimisers.trainable, ys)
-    res = map(recurse, x_children, ys_children...)
-    return vcat(values(res)...)
-end
-
-function trainables3(x)
-    exclude(x) = Optimisers.isnumeric(x)
-    return execute(ExcludeWalk(TrainableWalk3(), x ->[x], exclude), x)
-end
-
-
-function floss(ps)
-    sum([sum(abs2, p) for p in ps])
-end
-
-using Flux
-
-function perf()
-    m = Chain(Dense(128 => 128, relu),
-              Dense(128 => 128, relu),
-              BatchNorm(128),
-              x -> x^2,
-              Dense(128 => 128, relu),
-              Dense(128 => 128, relu))
-
-    println("trainables1")
-    @btime floss(trainables1($m))
-    println("trainables2")
-    @btime floss(trainables2($m))
-    println("trainables3")
-    @btime floss(trainables3($m))
-    println()
-
-    println("gradient trainables1")
-    @btime gradient(m -> floss(trainables1(m)), $m)
-    println("gradient trainables2")
-    @btime gradient(m -> floss(trainables2(m)), $m)
-    println("gradient trainables3")
-    @btime gradient(m -> floss(trainables3(m)), $m)
-
-    nothing
-end
-
-Zygote.refresh()
-perf()
-
-
-m = Chain(Dense(128 => 128, relu),
-          Dense(128 => 128, relu),
-          BatchNorm(128),
-          x -> x^2,
-          Dense(128 => 128, relu),
-          Dense(128 => 128, relu))
-
-floss(trainables1(m))
-g1 = gradient(m -> floss(trainables1(m)), m)[1]
-g2 = gradient(m -> floss(trainables2(m)), m)[1]
-@test g1.layers[1].weight ≈ g2.layers[1].weight
-@test g1.layers[1].weight ≈ g2.layers[1].weight
-@test g1.layers[3].μ === nothing
-@test g2.layers[3].μ === nothing
diff --git a/test/runtests.jl b/test/runtests.jl
index e4a14016..fc0fe57f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -10,7 +10,7 @@ Random.seed!(1)

 struct Foo; x; y; end
 Functors.@functor Foo
-Optimisers.trainable(x::Foo) = (x.y, x.x)
+Optimisers.trainable(x::Foo) = (; x.y, x.x)

 struct TwoThirds a; b; c; end
 Functors.@functor TwoThirds (a, c)
@@ -539,6 +539,9 @@ end
   @testset verbose=true "Destructure" begin
     include("destructure.jl")
   end
+  @testset verbose=true "Trainables" begin
+    include("trainables.jl")
+  end
   @testset verbose=true "Optimisation Rules" begin
     include("rules.jl")
   end
diff --git a/test/trainables.jl b/test/trainables.jl
new file mode 100644
index 00000000..b11fd7ec
--- /dev/null
+++ b/test/trainables.jl
@@ -0,0 +1,99 @@
+
+m1 = collect(1:3.0)
+m2 = (collect(1:3.0), collect(4:6.0))
+m3 = (x = m1, y = sin, z = collect(4:6.0))
+
+m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
+m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
+m6 = (a = m1, b = [4.0 + im], c = m1)
+
+m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
+m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
+
+mat = Float32[4 6; 5 7]
+m9 = (a = m1, b = mat, c = [mat, m1])
+
+@testset "trainables" begin
+    ps = trainables(m1)
+    @test ps isa Vector
+    @test length(ps) == 1
+    @test ps[1] == m1
+
+    ps = trainables(m2)
+    @test ps isa Vector
+    @test length(ps) == 2
+    @test ps[1] == m2[1]
+    @test ps[2] == m2[2]
+
+    ps = trainables(m3)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == 4:6
+
+    ps = trainables(m4)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == 4:6
+
+    ps = trainables(m5)
+    @test length(ps) == 3
+    @test ps[1] == 1:3
+    @test ps[2] == 4:6
+    @test ps[3] == 4:6
+
+    ps = trainables(m6)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == ComplexF64[4.0 + 1.0im]
+
+    ps = trainables(m7)
+    @test length(ps) == 1
+    @test ps[1] == [1.0, 2.0, 3.0]
+
+    ps = trainables(m8)
+    @test length(ps) == 3
+    @test ps[1] == 1:3
+    @test ps[2] == [4.0]
+    @test ps[3] == [5.0]
+
+    ps = trainables(m9)
+    @test length(ps) == 2
+    @test ps[1] == 1:3
+    @test ps[2] == mat
+end
+
+@testset "gradient" begin
+    loss(m) = sum([sum(abs2, p) for p in trainables(m)])
+    g = gradient(loss, m1)[1]
+    @test g == [2.0, 4.0, 6.0]
+
+    g = gradient(loss, m2)[1]
+    @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0])
+
+    g = gradient(loss, m3)[1]
+    @test g.x == [2.0, 4.0, 6.0]
+    @test g.y === nothing
+    @test g.z == [8.0, 10.0, 12.0]
+
+    g = gradient(loss, m4)[1]
+    @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0])
+    g.x === g.y # shared gradient for shared weights
+
+    g = gradient(loss, m5)[1]
+    @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing))
+
+    g = gradient(loss, m6)[1]
+    @test g = (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0])
+
+    g = gradient(loss, m7)[1]
+    @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing)
+
+    g = gradient(loss, m8)[1]
+    @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0])
+    @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing)
+    @test g[3] == [[10.0]]
+
+    g = gradient(loss, m9)[1]
+    @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]])
+end
+

From 06f786cdf8e96faa4d6f0093b7facd6f21aab8ff Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 2 Apr 2024 13:48:21 +0200
Subject: [PATCH 5/9] test second order derivatives

---
 src/trainables.jl  | 13 ++++++++++---
 test/trainables.jl | 16 ++++++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/trainables.jl b/src/trainables.jl
index 151e0aee..c4abe6b9 100644
--- a/src/trainables.jl
+++ b/src/trainables.jl
@@ -5,9 +5,9 @@
 Return an iterable over all the trainable parameters in `x`, that is all the numerical
 arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).

-Parameters appearing multiple times in the model will be present only once in the output.
+Parameters appearing multiple times in the model (tied weights) will be present only once in the output.

-See also [`destructure`](@ref).
+See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.

 # Examples

@@ -26,6 +26,13 @@ julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);
 julia> trainables(x)
 1-element Vector{AbstractArray}:
  [1.0, 2.0, 3.0]
+
+ julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
+
+ julia> trainables(x) # collects nested parameters
+ 2-element Vector{AbstractArray}:
+ [1.0, 2.0]
+ [3.0]
 """
 function trainables(x)
     arrays = AbstractArray[]
     exclude(x) = Optimisers.isnumeric(x)
@@ -40,7 +47,7 @@ end
 function ∇trainables(x, Δ)
     exclude(x) = Optimisers.isnumeric(x)
     i = 0
-    return fmapstructure(x; exclude, walk = Optimisers.TrainableStructWalk()) do _
+    return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _
         return Δ[i+=1]
     end
 end
diff --git a/test/trainables.jl b/test/trainables.jl
index b11fd7ec..5d9131cf 100644
--- a/test/trainables.jl
+++ b/test/trainables.jl
@@ -97,3 +97,19 @@ end
     @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]])
 end

+@testset "second order derivatives" begin
+    struct DenseLayer
+        w
+        b
+    end
+
+    Functors.@functor DenseLayer
+
+    loss(m) = sum([sum(abs2, p) for p in trainables(m)])
+
+    model = DenseLayer([1. 2.; 3. 4.], [0., 0.])
+
+    g = gradient(m -> loss(gradient(loss, m)), model)[1]
+    @test g.w == [8.0 16.0; 24.0 32.0]
+    @test g.b == [0.0, 0.0]
+end

From faea10f80857a410d2d3b6c88e21ca86a79c00a0 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 2 Apr 2024 14:02:36 +0200
Subject: [PATCH 6/9] add doc section

---
 docs/src/index.md | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/docs/src/index.md b/docs/src/index.md
index 38d7b93e..04809461 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -290,3 +290,29 @@ flat, re = destructure(params)
   end
 ```

+## Collecting all trainable parameters
+
+Sometimes it is useful to collect all trainable parameters in a model,
+similarly to what [`destructure`](@ref Optimisers.destructure) does but keeping
+the arrays separate.
+This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:
+
+```julia
+julia> using Flux, Optimisers
+
+julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));
+
+julia> trainables(model)
+6-element Vector{AbstractArray}:
+ Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061]
+ Float32[0.0, 0.0, 0.0]
+ Float32[0.0, 0.0, 0.0]
+ Float32[1.0, 1.0, 1.0]
+ Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
+ Float32[0.0, 0.0]
+
+julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);
+
+julia> g = gradient(l2reg, model)[1];
+```
+Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.

From 292a82d7900af7bcc61e8038e2669e1e380ad039 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 2 Apr 2024 14:07:06 +0200
Subject: [PATCH 7/9] fix test

---
 test/trainables.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/trainables.jl b/test/trainables.jl
index 5d9131cf..d4b93ce8 100644
--- a/test/trainables.jl
+++ b/test/trainables.jl
@@ -83,7 +83,7 @@ end
     @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing))

     g = gradient(loss, m6)[1]
-    @test g = (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0])
+    @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0])

     g = gradient(loss, m7)[1]
     @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing)

From 2694e9ea6c7793af14fcdb57107809b2330e2901 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 4 Apr 2024 15:24:39 +0200
Subject: [PATCH 8/9] Update src/trainables.jl

Co-authored-by: Kyle Daruwalla

---
 src/trainables.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/trainables.jl b/src/trainables.jl
index c4abe6b9..625c5659 100644
--- a/src/trainables.jl
+++ b/src/trainables.jl
@@ -2,7 +2,7 @@
 """
     trainables(x)

-Return an iterable over all the trainable parameters in `x`, that is all the numerical
+Return a list over all the trainable parameters in `x`, that is all the numerical
 arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).

 Parameters appearing multiple times in the model (tied weights) will be present only once in the output.
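The pullback added in these patches works because `fmap` and `fmapstructure` visit the trainable leaves in the same depth-first order, so the i-th cotangent can be written back into the slot the i-th array came from. A minimal sketch of that idea using only Functors.jl — with plain `isa AbstractArray` leaves standing in for `Optimisers.isnumeric` and the trainable-aware walk, so the helper names below are illustrative and not part of the patches:

```julia
using Functors

# Forward idea: collect every array leaf in the order fmap visits them.
function collect_arrays(x)
    out = AbstractArray[]
    fmap(x; exclude = y -> y isa AbstractArray) do y
        push!(out, y)
        y
    end
    return out
end

# Backward idea: walk the same structure again and hand cotangent i
# back to the i-th leaf, producing a nested (structural) gradient.
function scatter_back(x, Δs)
    i = 0
    fmapstructure(x; exclude = y -> y isa AbstractArray) do _
        Δs[i += 1]
    end
end

nt = (a = [1.0, 2.0], b = (c = [3.0],))
collect_arrays(nt)                         # [[1.0, 2.0], [3.0]]
scatter_back(nt, [[10.0, 20.0], [30.0]])   # (a = [10.0, 20.0], b = (c = [30.0],))
```

Since `fmap` caches visited objects, a tied array is collected only once, which is why `trainables` lists shared parameters a single time.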
From c42823d35d67dda6c80d41da0f5d34236e98ad4a Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 4 Apr 2024 15:24:49 +0200
Subject: [PATCH 9/9] Update docs/src/index.md

Co-authored-by: Kyle Daruwalla

---
 docs/src/index.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 04809461..30ef5c45 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -293,8 +293,8 @@ end
 ## Collecting all trainable parameters

 Sometimes it is useful to collect all trainable parameters in a model,
-similarly to what [`destructure`](@ref Optimisers.destructure) does but keeping
-the arrays separate.
+similarly to what [`destructure`](@ref Optimisers.destructure) does but without
+concatenating the arrays into a flat vector.
 This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

 ```julia
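As an end-to-end illustration of the API these patches add, here is a short usage sketch; it assumes Flux and Zygote are installed alongside this branch of Optimisers, and the layer sizes, data, and penalty weight are arbitrary placeholders:

```julia
using Flux, Optimisers, Zygote

model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2))

# Count trainable parameters without flattening the model into one vector.
n = sum(length, trainables(model))

# L2 penalty over the same arrays, written as in the docs so Zygote can differentiate it.
l2(m) = sum([sum(abs2, p) for p in trainables(m)])

x, y = randn(Float32, 2, 8), randn(Float32, 2, 8)
loss(m) = sum(abs2, m(x) .- y) / size(y, 2) + 1f-4 * l2(m)

g = gradient(loss, model)[1]   # structural gradient, with the penalty's contribution included
```

As in the docs example, only the `BatchNorm`'s `γ` and `β` enter the penalty; its running statistics are not trainable and are left out.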