From b4920f7ae745651147e12d7ec169c4871d02034e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 2 Apr 2024 16:20:05 +0200 Subject: [PATCH 1/5] fix broken documentation (#172) * fix docs * cleanup --- .gitignore | 2 ++ docs/.DS_Store | Bin 0 -> 6148 bytes docs/make.jl | 2 +- docs/src/api.md | 2 ++ 4 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 docs/.DS_Store diff --git a/.gitignore b/.gitignore index e4fecbbe..952f7cea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ Manifest.toml .vscode/ +docs/build/ +.DS_Store diff --git a/docs/.DS_Store b/docs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e14553107e0e9f4dfeb62498c2852eba47071535 GIT binary patch literal 6148 zcmeHK&2G~`5T4Bs-B5vY0JXj41KofO4E}}4VMBy|WhKUXrV%hL_GC=RH4k1h+h6IZ5?^`%b;%vOJaVaX*TK&$7T#>7C zKRD9UVB8x|vUYE@N251dC*e)M7ru@A({A&@OP!8;aoQi5ggEMB$h+M*jr6pwCux+K z+`u|yLpHk2wb|_XvzGF<<}Eezo^Eco)RyPX=MDMr@sk&?yC0H4sy|yE1%7Zw?m4`H zFW4zOP;%%62YN8HvERRzy#w9pxX)F7+&_P`es!_^b42u;h~ZITh5N5)p~Cvd!6;31 zdWcrqaTPKc!9i(P;*+}d6nAuTepX+zJ92pgeY=}QZ)qpu6ene22ABb6U^xTs2$ieL z-E3tBn1TP80XiQPDxvQ%w`h(IY)lD&SVOZB>{Bg4Im)5$Ft>;v6k$^lZK`l3hOp^q zS1!(Xm|L{zAYAbw+{(gLD8j6c^D7e$!nepRGr$a#8K~M}lkWeWU)TTTBpxvX%)mdz zfT(qXP8%P|-K`4`M|Z74eUC~)ak<4GDcCVrF~-tWypC!F?TR{xzQf!iT2S~$z|g=A JGw??lxCZy@gAM=y literal 0 HcmV?d00001 diff --git a/docs/make.jl b/docs/make.jl index 8d0504b2..47d64137 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,7 +2,7 @@ using Documenter, Optimisers, Zygote, StaticArrays, Functors DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true) -makedocs(modules = [Optimisers, Zygote, StaticArrays, Functors], +makedocs(modules = [Optimisers], doctest = false, sitename = "Optimisers.jl", pages = ["Home" => "index.md", diff --git a/docs/src/api.md b/docs/src/api.md index 661f83bc..5648d167 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -51,6 +51,7 @@ To further restrict this by ignoring some fields of a layer type, define `traina ```@docs Optimisers.trainable Optimisers.isnumeric +Optimisers.maywrite ``` Such restrictions are also obeyed by this function for flattening a model: @@ -68,4 +69,5 @@ Optimisers.init Optimisers.@.. 
Optimisers.@lazy Optimisers.adjust(::AbstractRule, ::Real) +Optimisers.@def ``` From a87ffd57d2b8fedb4f367d7ae34fd226ab3aa95e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 Apr 2024 15:32:44 +0200 Subject: [PATCH 2/5] add `trainables` (#171) * trainables * trainables * cl/trainables * trainables * test second order derivatives * add doc section * fix test * Update src/trainables.jl --- docs/src/api.md | 1 + docs/src/index.md | 26 ++++++++++ src/Optimisers.jl | 3 ++ src/destructure.jl | 9 ++-- src/interface.jl | 1 + src/trainables.jl | 59 +++++++++++++++++++++++ test/runtests.jl | 5 +- test/trainables.jl | 115 +++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 214 insertions(+), 5 deletions(-) create mode 100644 src/trainables.jl create mode 100644 test/trainables.jl diff --git a/docs/src/api.md b/docs/src/api.md index 5648d167..5c203492 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -59,6 +59,7 @@ Such restrictions are also obeyed by this function for flattening a model: ```@docs Optimisers.destructure Optimisers.Restructure +Optimisers.trainables ``` ## Rule Definition diff --git a/docs/src/index.md b/docs/src/index.md index 38d7b93e..30ef5c45 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -290,3 +290,29 @@ flat, re = destructure(params) end ``` +## Collecting all trainable parameters + +Sometimes it is useful to collect all trainable parameters in a model, +similarly to what [`destructure`](@ref Optimisers.destructure) does but without +concatenating the arrays into a flat vector. +This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays: + +```julia +julia> using Flux, Optimisers + +julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2)); + +julia> trainables(model) +6-element Vector{AbstractArray}: + Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061] + Float32[0.0, 0.0, 0.0] + Float32[0.0, 0.0, 0.0] + Float32[1.0, 1.0, 1.0] + Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252] + Float32[0.0, 0.0] + +julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]); + +julia> g = gradient(l2reg, model)[1]; +``` +Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ ` and `σ²` buffers are not. 
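The documentation example above relies on Flux. As a complement, here is a minimal, Flux-free sketch of the same pattern using only a plain nested structure and Zygote for the gradient; the names `model` and `l2reg` below are illustrative, not part of the patch:

```julia
using Optimisers, Zygote

# A nested NamedTuple standing in for a model; `tanh` is not a numeric array
# and is therefore skipped by `trainables`.
model = (dense = (weight = randn(3, 2), bias = zeros(3), activation = tanh),
         scale = [1.0])

ps = trainables(model)    # 3 arrays: weight, bias, scale

l2reg(m) = sum([sum(abs2, p) for p in trainables(m)])

# Structural gradient, made possible by the rrule added later in this patch.
g = Zygote.gradient(l2reg, model)[1]
```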
diff --git a/src/Optimisers.jl b/src/Optimisers.jl index efe28697..3cc98808 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -11,6 +11,9 @@ include("adjust.jl") include("destructure.jl") export destructure +include("trainables.jl") +export trainables + include("rules.jl") export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, diff --git a/src/destructure.jl b/src/destructure.jl index a73e36a6..f9950a92 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -66,19 +66,19 @@ function _flatten(x) isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case arrays = AbstractVector[] len = Ref(0) - off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y + off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y push!(arrays, _vec(y)) o = len[] len[] = o + length(y) o end isempty(arrays) && return Bool[], off, 0 - reduce(vcat, arrays), off, len[] + return reduce(vcat, arrays), off, len[] end -struct _TrainableStructWalk <: AbstractWalk end +struct TrainableStructWalk <: AbstractWalk end -(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x)) +(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x)) _vec(x::Number) = LinRange(x,x,1) _vec(x::AbstractArray) = vec(x) @@ -174,3 +174,4 @@ function ChainRulesCore.rrule(::typeof(_maybewarn)) @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3 nothing, _ -> (NoT,) end + diff --git a/src/interface.jl b/src/interface.jl index 29c1db60..aa5447c0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -167,6 +167,7 @@ and `trainable(x)` must contain a subset of these. """ trainable(x) = functor(x)[1] +# like trainable(x), but also tries to output non-trainable children giving value nothing _trainable(x) = _trainable(functor(x)[1], trainable(x)) _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr diff --git a/src/trainables.jl b/src/trainables.jl new file mode 100644 index 00000000..625c5659 --- /dev/null +++ b/src/trainables.jl @@ -0,0 +1,59 @@ + +""" + trainables(x) + +Return a list over all the trainable parameters in `x`, that is all the numerical +arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable). + +Parameters appearing multiple times in the model (tied weights) will be present only once in the output. + +See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead. 
+ +# Examples + +```jldoctest +julia> struct MyLayer + w + b + end + +julia> Functors.@functor MyLayer + +julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example + +julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]); + +julia> trainables(x) +1-element Vector{AbstractArray}: + [1.0, 2.0, 3.0] + + julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]); + + julia> trainables(x) # collects nested parameters + 2-element Vector{AbstractArray}: + [1.0, 2.0] + [3.0] +""" +function trainables(x) + arrays = AbstractArray[] + exclude(x) = Optimisers.isnumeric(x) + fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y + push!(arrays, y) + return y + end + return arrays +end + +function ∇trainables(x, Δ) + exclude(x) = Optimisers.isnumeric(x) + i = 0 + return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _ + return Δ[i+=1] + end +end + +function ChainRulesCore.rrule(::typeof(trainables), x) + y = trainables(x) + trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ))) + return y, trainables_back +end diff --git a/test/runtests.jl b/test/runtests.jl index e4a14016..fc0fe57f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ Random.seed!(1) struct Foo; x; y; end Functors.@functor Foo -Optimisers.trainable(x::Foo) = (x.y, x.x) +Optimisers.trainable(x::Foo) = (; x.y, x.x) struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) @@ -539,6 +539,9 @@ end @testset verbose=true "Destructure" begin include("destructure.jl") end + @testset verbose=true "Trainables" begin + include("trainables.jl") + end @testset verbose=true "Optimisation Rules" begin include("rules.jl") end diff --git a/test/trainables.jl b/test/trainables.jl new file mode 100644 index 00000000..d4b93ce8 --- /dev/null +++ b/test/trainables.jl @@ -0,0 +1,115 @@ + +m1 = collect(1:3.0) +m2 = (collect(1:3.0), collect(4:6.0)) +m3 = (x = m1, y = sin, z = collect(4:6.0)) + +m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied +m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) +m6 = (a = m1, b = [4.0 + im], c = m1) + +m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) +m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] + +mat = Float32[4 6; 5 7] +m9 = (a = m1, b = mat, c = [mat, m1]) + +@testset "trainables" begin + ps = trainables(m1) + @test ps isa Vector + @test length(ps) == 1 + @test ps[1] == m1 + + ps = trainables(m2) + @test ps isa Vector + @test length(ps) == 2 + @test ps[1] == m2[1] + @test ps[2] == m2[2] + + ps = trainables(m3) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == 4:6 + + ps = trainables(m4) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == 4:6 + + ps = trainables(m5) + @test length(ps) == 3 + @test ps[1] == 1:3 + @test ps[2] == 4:6 + @test ps[3] == 4:6 + + ps = trainables(m6) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == ComplexF64[4.0 + 1.0im] + + ps = trainables(m7) + @test length(ps) == 1 + @test ps[1] == [1.0, 2.0, 3.0] + + ps = trainables(m8) + @test length(ps) == 3 + @test ps[1] == 1:3 + @test ps[2] == [4.0] + @test ps[3] == [5.0] + + ps = trainables(m9) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == mat +end + +@testset "gradient" begin + loss(m) = sum([sum(abs2, p) for p in trainables(m)]) + g = gradient(loss, m1)[1] + @test g == [2.0, 4.0, 6.0] + + g = gradient(loss, m2)[1] + @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0]) + + g = gradient(loss, m3)[1] + @test g.x == [2.0, 4.0, 6.0] + @test g.y === nothing + @test g.z == 
[8.0, 10.0, 12.0] + + g = gradient(loss, m4)[1] + @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]) + g.x === g.y # shared gradient for shared weights + + g = gradient(loss, m5)[1] + @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing)) + + g = gradient(loss, m6)[1] + @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0]) + + g = gradient(loss, m7)[1] + @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing) + + g = gradient(loss, m8)[1] + @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0]) + @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing) + @test g[3] == [[10.0]] + + g = gradient(loss, m9)[1] + @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]]) +end + +@testset "second order derivatives" begin + struct DenseLayer + w + b + end + + Functors.@functor DenseLayer + + loss(m) = sum([sum(abs2, p) for p in trainables(m)]) + + model = DenseLayer([1. 2.; 3. 4.], [0., 0.]) + + g = gradient(m -> loss(gradient(loss, m)), model)[1] + @test g.w == [8.0 16.0; 24.0 32.0] + @test g.b == [0.0, 0.0] +end From 4195db564b2a65e062b6a67e1823529e1308610e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 Apr 2024 16:37:02 +0200 Subject: [PATCH 3/5] add path=true --- docs/src/api.md | 11 +++ src/Optimisers.jl | 5 +- src/destructure.jl | 2 +- src/interface.jl | 12 +-- src/trainables.jl | 81 ++++++++++++++++-- src/utils.jl | 15 ++++ test/trainables.jl | 202 +++++++++++++++++++++++++-------------------- 7 files changed, 224 insertions(+), 104 deletions(-) create mode 100644 src/utils.jl diff --git a/docs/src/api.md b/docs/src/api.md index 5c203492..296057bf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -72,3 +72,14 @@ Optimisers.@lazy Optimisers.adjust(::AbstractRule, ::Real) Optimisers.@def ``` + +## KeyPath + +A `KeyPath` is a sequence of keys that can be used to access a value within a nested structure. +It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience. 
+ +```@docs +Functors.KeyPath +Functors.haskeypath +Functors.getkeypath +``` diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 3cc98808..aab220dd 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -1,6 +1,8 @@ module Optimisers -using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk +using Functors: functor, fmap, fmap_with_path, + KeyPath, haskeypath, getkeypath, + isleaf, @functor, fmapstructure, children, AbstractWalk using LinearAlgebra include("interface.jl") @@ -13,6 +15,7 @@ export destructure include("trainables.jl") export trainables +export KeyPath, haskeypath, getkeypath # from Functors.jl include("rules.jl") export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, diff --git a/src/destructure.jl b/src/destructure.jl index f9950a92..a6284522 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -78,7 +78,7 @@ end struct TrainableStructWalk <: AbstractWalk end -(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x)) +(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x)) _vec(x::Number) = LinRange(x,x,1) _vec(x::AbstractArray) = vec(x) diff --git a/src/interface.jl b/src/interface.jl index aa5447c0..d7b6dac3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -45,7 +45,7 @@ function _setup(rule, x; cache) cache[x] = ℓ end else - valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x)) + mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x)) end end @@ -82,7 +82,7 @@ function _update!(tree, x; grads, params) haskey(params, (tree,x)) && return params[(tree,x)] isbits(tree) && return x # means () is not cached, and also (((),),) x′, re = functor(x) - x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′)) + x′′ = re(mapvalue((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′)) if ismutable(x′′) params[(tree,x)] = x′′ else # no ties to preserve between immutable structs, right? @@ -115,7 +115,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...) # functor(typeof(tree), base(x̄)), for things like Transpose x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) x′, _ = functor(typeof(x), x) - valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) + foreachvalue((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) end # default all rules to first order calls @@ -172,7 +172,7 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x)) _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr _trainable(ch::AbstractArray, tr::AbstractArray) = tr -_trainable(ch::Dict, tr::Dict) = merge(valuemap(_ -> nothing, ch), tr) +_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr) function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3 @@ -180,8 +180,8 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu end -valuemap(f, x...) = map(f, x...) -valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x) +mapvalue(f, x...) = map(f, x...) +mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x) valueforeach(f, x...) = foreach(f, x...) valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v) f(v, (get(y, k, nothing) for y in ys)...) 
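The `KeyPath` section added to `docs/src/api.md` above only lists the docstrings. As a brief illustration, the sketch below shows how the re-exported helpers behave on a nested structure; `KeyPath`, `haskeypath` and `getkeypath` are Functors.jl functions (re-exported as of this patch) and behave as documented there, and the `model` value is illustrative only:

```julia
using Optimisers   # re-exports KeyPath, haskeypath, getkeypath from Functors.jl

model = (encoder = (w = [1.0 2.0], b = [0.0]), decoder = (w = [3.0 4.0],))

kp = KeyPath(:encoder, :w)   # a sequence of keys into the nested structure

haskeypath(model, kp)        # true
getkeypath(model, kp)        # [1.0 2.0]

haskeypath(model, KeyPath(:encoder, :missing_field))   # false
```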
diff --git a/src/trainables.jl b/src/trainables.jl index 625c5659..80059dd7 100644 --- a/src/trainables.jl +++ b/src/trainables.jl @@ -1,12 +1,17 @@ """ - trainables(x) + trainables(x, path = false) -Return a list over all the trainable parameters in `x`, that is all the numerical +Return an iterable over all the trainable parameters in `x`, that is all the numerical arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable). Parameters appearing multiple times in the model (tied weights) will be present only once in the output. +If `path = false`, the output is a list of numerical arrays. + +If `path = true`, the output is a list of `(KeyPath, AbstractArray)` pairs, where [`KeyPath`](@ref Functors.KeyPath) is a type +representing the path to the array in the original structure. + See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead. # Examples @@ -33,11 +38,36 @@ julia> trainables(x) 2-element Vector{AbstractArray}: [1.0, 2.0] [3.0] +``` + +```jldoctest +julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0])); + +julia> for (kp, y) in trainables(x, path = true) + println(kp, " => ", y) + end +KeyPath(:a,) => [1.0, 2.0] +KeyPath(:b, 1, "c") => [3.0, 4.0] +KeyPath(:b, 2) => [6.0, 7.0] + +julia> getkeypath(x, KeyPath(:b, 1, "c")) +2-element Vector{Float64}: + 3.0 + 4.0 +``` """ -function trainables(x) +function trainables(x; path = false) + if path + return _trainables_with_path(x) + else + return _trainables(x) + end +end + + +function _trainables(x) arrays = AbstractArray[] - exclude(x) = Optimisers.isnumeric(x) - fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y + fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y push!(arrays, y) return y end @@ -45,15 +75,50 @@ function trainables(x) end function ∇trainables(x, Δ) - exclude(x) = Optimisers.isnumeric(x) i = 0 - return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _ + return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _ return Δ[i+=1] end end -function ChainRulesCore.rrule(::typeof(trainables), x) +function ChainRulesCore.rrule(::typeof(_trainables), x) y = trainables(x) trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ))) return y, trainables_back end + +function _trainables_with_path(x) + named_params = [] + exclude(kp, x) = isnumeric(x) + fmap_with_path(x; exclude, walk = TrainableStructWalkWithPath()) do kp, y + push!(named_params, (kp, y)) + return y + end + return named_params +end + +struct TrainableStructWalkWithPath <: AbstractWalk end + +function (::TrainableStructWalkWithPath)(recurse, kp::KeyPath, x) + x_children = trainable(x) + kps = mapkey(c -> KeyPath(kp, c), x_children) + return mapvalue(recurse, kps, x_children) +end + +function ChainRulesCore.rrule(::typeof(_trainables_with_path), x) + y = _trainables_with_path(x) + trainables_with_path_back(Δ) = (NoTangent(), ∇trainables_with_path(x, unthunk(Δ))) + return y, trainables_with_path_back +end + +function ∇trainables_with_path(x, Δ) + i = 0 + return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _ + Δi = Δ[i+=1] + if isnothing(Δi) + return nothing + else + return Δi[2] + end + end +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..34fe0d60 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,15 @@ + +mapvalue(f, x...) = map(f, x...) +mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) 
for (k,v) in x) + +mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks)) +mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x)) +mapkey(f, x::Tuple) = ntuple(i -> f(i), length(x)) +mapkey(f, x::AbstractArray) = [f(i) for i=1:length(x)] + +valueforeach(f, x...) = foreach(f, x...) + +valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v) + f(v, (get(y, k, nothing) for y in ys)...) +end + diff --git a/test/trainables.jl b/test/trainables.jl index d4b93ce8..e1aa0115 100644 --- a/test/trainables.jl +++ b/test/trainables.jl @@ -14,102 +14,128 @@ mat = Float32[4 6; 5 7] m9 = (a = m1, b = mat, c = [mat, m1]) @testset "trainables" begin - ps = trainables(m1) - @test ps isa Vector - @test length(ps) == 1 - @test ps[1] == m1 - - ps = trainables(m2) - @test ps isa Vector - @test length(ps) == 2 - @test ps[1] == m2[1] - @test ps[2] == m2[2] - - ps = trainables(m3) - @test length(ps) == 2 - @test ps[1] == 1:3 - @test ps[2] == 4:6 - - ps = trainables(m4) - @test length(ps) == 2 - @test ps[1] == 1:3 - @test ps[2] == 4:6 - - ps = trainables(m5) - @test length(ps) == 3 - @test ps[1] == 1:3 - @test ps[2] == 4:6 - @test ps[3] == 4:6 - - ps = trainables(m6) - @test length(ps) == 2 - @test ps[1] == 1:3 - @test ps[2] == ComplexF64[4.0 + 1.0im] - - ps = trainables(m7) - @test length(ps) == 1 - @test ps[1] == [1.0, 2.0, 3.0] - - ps = trainables(m8) - @test length(ps) == 3 - @test ps[1] == 1:3 - @test ps[2] == [4.0] - @test ps[3] == [5.0] - - ps = trainables(m9) - @test length(ps) == 2 - @test ps[1] == 1:3 - @test ps[2] == mat + ps = trainables(m1) + @test ps isa Vector + @test length(ps) == 1 + @test ps[1] == m1 + + ps = trainables(m2) + @test ps isa Vector + @test length(ps) == 2 + @test ps[1] == m2[1] + @test ps[2] == m2[2] + + ps = trainables(m3) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == 4:6 + + ps = trainables(m4) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == 4:6 + + ps = trainables(m5) + @test length(ps) == 3 + @test ps[1] == 1:3 + @test ps[2] == 4:6 + @test ps[3] == 4:6 + + ps = trainables(m6) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == ComplexF64[4.0 + 1.0im] + + ps = trainables(m7) + @test length(ps) == 1 + @test ps[1] == [1.0, 2.0, 3.0] + + ps = trainables(m8) + @test length(ps) == 3 + @test ps[1] == 1:3 + @test ps[2] == [4.0] + @test ps[3] == [5.0] + + ps = trainables(m9) + @test length(ps) == 2 + @test ps[1] == 1:3 + @test ps[2] == mat end @testset "gradient" begin - loss(m) = sum([sum(abs2, p) for p in trainables(m)]) - g = gradient(loss, m1)[1] - @test g == [2.0, 4.0, 6.0] - - g = gradient(loss, m2)[1] - @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0]) - - g = gradient(loss, m3)[1] - @test g.x == [2.0, 4.0, 6.0] - @test g.y === nothing - @test g.z == [8.0, 10.0, 12.0] - - g = gradient(loss, m4)[1] - @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]) - g.x === g.y # shared gradient for shared weights - - g = gradient(loss, m5)[1] - @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing)) - - g = gradient(loss, m6)[1] - @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0]) - - g = gradient(loss, m7)[1] - @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing) - - g = gradient(loss, m8)[1] - @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0]) - @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing) - 
@test g[3] == [[10.0]] - - g = gradient(loss, m9)[1] - @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]]) + loss(m) = sum([sum(abs2, p) for p in trainables(m)]) + g = gradient(loss, m1)[1] + @test g == [2.0, 4.0, 6.0] + + g = gradient(loss, m2)[1] + @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0]) + + g = gradient(loss, m3)[1] + @test g.x == [2.0, 4.0, 6.0] + @test g.y === nothing + @test g.z == [8.0, 10.0, 12.0] + + g = gradient(loss, m4)[1] + @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]) + g.x === g.y # shared gradient for shared weights + + g = gradient(loss, m5)[1] + @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing)) + + g = gradient(loss, m6)[1] + @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0]) + + g = gradient(loss, m7)[1] + @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing) + + g = gradient(loss, m8)[1] + @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0]) + @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing) + @test g[3] == [[10.0]] + + g = gradient(loss, m9)[1] + @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]]) +end + +@testset "dict" begin + d = Dict(:a => rand(2), :b => ones(2)) + ps = trainables(d) + @test length(ps) == 2 + @test ps[1] == d[:a] + @test ps[2] == d[:b] + + g = gradient(d -> sum(trainables(d)[1].^2) /2 + sum(trainables(d)[2]), d)[1] + @test g[:a] == d[:a] + @test_broken g[:b] == [1.0, 1.0] end @testset "second order derivatives" begin - struct DenseLayer - w - b - end + struct DenseLayer + w + b + end + + Functors.@functor DenseLayer - Functors.@functor DenseLayer + loss(m) = sum([sum(abs2, p) for p in trainables(m)]) + + model = DenseLayer([1. 2.; 3. 4.], [0., 0.]) + + g = gradient(m -> loss(gradient(loss, m)), model)[1] + @test g.w == [8.0 16.0; 24.0 32.0] + @test g.b == [0.0, 0.0] +end - loss(m) = sum([sum(abs2, p) for p in trainables(m)]) +@testset "trainables(x, path=true)" begin + loss(m) = sum(abs2, trainables(m, path=true)[1][2]) - model = DenseLayer([1. 2.; 3. 
4.], [0., 0.]) + ps = trainables(m4, path=true) + @test length(ps) == 2 + @test ps[1] == (KeyPath(:x,), [1.0, 2.0, 3.0]) + @test ps[2] == (KeyPath(:z,), [4.0, 5.0, 6.0]) - g = gradient(m -> loss(gradient(loss, m)), model)[1] - @test g.w == [8.0 16.0; 24.0 32.0] - @test g.b == [0.0, 0.0] + g = gradient(loss, m4)[1] + @test g.x == [2.0, 4.0, 6.0] + @test g.y == [2.0, 4.0, 6.0] + @test g.z === nothing end From 0d2cd4aed6b010c9a33509be202e01d90d1f4cf2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 Apr 2024 16:37:50 +0200 Subject: [PATCH 4/5] fix --- Project.toml | 2 +- src/Optimisers.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7422f28c..4920eaa9 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Functors = "0.4" +Functors = "0.4.9" Statistics = "1" Zygote = "0.6.40" julia = "1.6" diff --git a/src/Optimisers.jl b/src/Optimisers.jl index aab220dd..2e115c40 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -8,6 +8,8 @@ using LinearAlgebra include("interface.jl") export AbstractRule +include("utils.jl") + include("adjust.jl") include("destructure.jl") From 97c65bcd3bd3dcbd7af99663997521b1e1ddf191 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 Apr 2024 17:01:59 +0200 Subject: [PATCH 5/5] fix --- src/interface.jl | 8 -------- src/utils.jl | 4 ++-- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d7b6dac3..ac9b90bc 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -180,14 +180,6 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu end -mapvalue(f, x...) = map(f, x...) -mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x) -valueforeach(f, x...) = foreach(f, x...) -valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v) - f(v, (get(y, k, nothing) for y in ys)...) -end - - ### ### rule definition helpers ### diff --git a/src/utils.jl b/src/utils.jl index 34fe0d60..7c6c95be 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -7,9 +7,9 @@ mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x)) mapkey(f, x::Tuple) = ntuple(i -> f(i), length(x)) mapkey(f, x::AbstractArray) = [f(i) for i=1:length(x)] -valueforeach(f, x...) = foreach(f, x...) +foreachvalue(f, x...) = foreach(f, x...) -valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v) +foreachvalue(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v) f(v, (get(y, k, nothing) for y in ys)...) end
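Taken together, the series adds `trainables`, its `path = true` variant, and the `KeyPath` re-exports. A rough end-to-end sketch under the assumption that this branch and Functors ≥ 0.4.9 (the compat bound raised in patch 4) are installed; the `model` below is illustrative only:

```julia
using Optimisers

model = (layer = (w = [1.0 2.0; 3.0 4.0], b = [0.0, 0.0]), act = tanh)

# Flat list of trainable arrays (patch 2); non-numeric leaves such as `tanh` are skipped.
for p in trainables(model)
    println(summary(p))
end

# The same arrays with their locations (patch 3); each path round-trips via `getkeypath`.
for (kp, p) in trainables(model, path = true)
    @assert getkeypath(model, kp) == p
    println(kp, " => ", size(p))
end
```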