Skip to content

Commit 4195db5

Browse files
add path=true
1 parent a87ffd5 commit 4195db5

File tree

7 files changed

+224
-104
lines changed

7 files changed

+224
-104
lines changed

docs/src/api.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,14 @@ Optimisers.@lazy
7272
Optimisers.adjust(::AbstractRule, ::Real)
7373
Optimisers.@def
7474
```
75+
76+
## KeyPath
77+
78+
A `KeyPath` is a sequence of keys that can be used to access a value within a nested structure.
79+
It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.
80+
81+
```@docs
82+
Functors.KeyPath
83+
Functors.haskeypath
84+
Functors.getkeypath
85+
```

src/Optimisers.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module Optimisers
22

3-
using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
3+
using Functors: functor, fmap, fmap_with_path,
4+
KeyPath, haskeypath, getkeypath,
5+
isleaf, @functor, fmapstructure, children, AbstractWalk
46
using LinearAlgebra
57

68
include("interface.jl")
@@ -13,6 +15,7 @@ export destructure
1315

1416
include("trainables.jl")
1517
export trainables
18+
export KeyPath, haskeypath, getkeypath # from Functors.jl
1619

1720
include("rules.jl")
1821
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,

src/destructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
struct TrainableStructWalk <: AbstractWalk end
8080

81-
(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
81+
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))
8282

8383
_vec(x::Number) = LinRange(x,x,1)
8484
_vec(x::AbstractArray) = vec(x)

src/interface.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function _setup(rule, x; cache)
4545
cache[x] =
4646
end
4747
else
48-
valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
48+
mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
4949
end
5050
end
5151

@@ -82,7 +82,7 @@ function _update!(tree, x; grads, params)
8282
haskey(params, (tree,x)) && return params[(tree,x)]
8383
isbits(tree) && return x # means () is not cached, and also (((),),)
8484
x′, re = functor(x)
85-
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
85+
x′′ = re(mapvalue((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
8686
if ismutable(x′′)
8787
params[(tree,x)] = x′′
8888
else # no ties to preserve between immutable structs, right?
@@ -115,7 +115,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
115115
# functor(typeof(tree), base(x̄)), for things like Transpose
116116
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
117117
x′, _ = functor(typeof(x), x)
118-
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
118+
foreachvalue((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
119119
end
120120

121121
# default all rules to first order calls
@@ -172,16 +172,16 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x))
172172
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
173173
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
174174
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
175-
_trainable(ch::Dict, tr::Dict) = merge(valuemap(_ -> nothing, ch), tr)
175+
_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)
176176

177177
function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
178178
@warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
179179
map(c -> c in tr ? c : nothing, ch)
180180
end
181181

182182

183-
valuemap(f, x...) = map(f, x...)
184-
valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
183+
mapvalue(f, x...) = map(f, x...)
184+
mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
185185
valueforeach(f, x...) = foreach(f, x...)
186186
valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
187187
f(v, (get(y, k, nothing) for y in ys)...)

src/trainables.jl

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11

22
"""
3-
trainables(x)
3+
trainables(x, path = false)
44
5-
Return a list over all the trainable parameters in `x`, that is all the numerical
5+
Return an iterable over all the trainable parameters in `x`, that is all the numerical
66
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).
77
88
Parameters appearing multiple times in the model (tied weights) will be present only once in the output.
99
10+
If `path = false`, the output is a list of numerical arrays.
11+
12+
If `path = true`, the output is a list of `(KeyPath, AbstractArray)` pairs, where [`KeyPath`](@ref Functors.KeyPath) is a type
13+
representing the path to the array in the original structure.
14+
1015
See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.
1116
1217
# Examples
@@ -33,27 +38,87 @@ julia> trainables(x)
3338
2-element Vector{AbstractArray}:
3439
[1.0, 2.0]
3540
[3.0]
41+
```
42+
43+
```jldoctest
44+
julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));
45+
46+
julia> for (kp, y) in trainables(x, path = true)
47+
println(kp, " => ", y)
48+
end
49+
KeyPath(:a,) => [1.0, 2.0]
50+
KeyPath(:b, 1, "c") => [3.0, 4.0]
51+
KeyPath(:b, 2) => [6.0, 7.0]
52+
53+
julia> getkeypath(x, KeyPath(:b, 1, "c"))
54+
2-element Vector{Float64}:
55+
3.0
56+
4.0
57+
```
3658
"""
37-
function trainables(x)
59+
function trainables(x; path = false)
60+
if path
61+
return _trainables_with_path(x)
62+
else
63+
return _trainables(x)
64+
end
65+
end
66+
67+
68+
function _trainables(x)
3869
arrays = AbstractArray[]
39-
exclude(x) = Optimisers.isnumeric(x)
40-
fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y
70+
fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
4171
push!(arrays, y)
4272
return y
4373
end
4474
return arrays
4575
end
4676

4777
function ∇trainables(x, Δ)
48-
exclude(x) = Optimisers.isnumeric(x)
4978
i = 0
50-
return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _
79+
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
5180
return Δ[i+=1]
5281
end
5382
end
5483

55-
function ChainRulesCore.rrule(::typeof(trainables), x)
84+
function ChainRulesCore.rrule(::typeof(_trainables), x)
5685
y = trainables(x)
5786
trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
5887
return y, trainables_back
5988
end
89+
90+
function _trainables_with_path(x)
91+
named_params = []
92+
exclude(kp, x) = isnumeric(x)
93+
fmap_with_path(x; exclude, walk = TrainableStructWalkWithPath()) do kp, y
94+
push!(named_params, (kp, y))
95+
return y
96+
end
97+
return named_params
98+
end
99+
100+
struct TrainableStructWalkWithPath <: AbstractWalk end
101+
102+
function (::TrainableStructWalkWithPath)(recurse, kp::KeyPath, x)
103+
x_children = trainable(x)
104+
kps = mapkey(c -> KeyPath(kp, c), x_children)
105+
return mapvalue(recurse, kps, x_children)
106+
end
107+
108+
function ChainRulesCore.rrule(::typeof(_trainables_with_path), x)
109+
y = _trainables_with_path(x)
110+
trainables_with_path_back(Δ) = (NoTangent(), ∇trainables_with_path(x, unthunk(Δ)))
111+
return y, trainables_with_path_back
112+
end
113+
114+
function ∇trainables_with_path(x, Δ)
115+
i = 0
116+
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
117+
Δi = Δ[i+=1]
118+
if isnothing(Δi)
119+
return nothing
120+
else
121+
return Δi[2]
122+
end
123+
end
124+
end

src/utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
mapvalue(f, x...) = map(f, x...)
3+
mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
4+
5+
mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks))
6+
mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x))
7+
mapkey(f, x::Tuple) = ntuple(i -> f(i), length(x))
8+
mapkey(f, x::AbstractArray) = [f(i) for i=1:length(x)]
9+
10+
valueforeach(f, x...) = foreach(f, x...)
11+
12+
valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
13+
f(v, (get(y, k, nothing) for y in ys)...)
14+
end
15+

0 commit comments

Comments
 (0)