
Commit 9242bb2

test second order derivatives
1 parent 9f47a66 commit 9242bb2

File tree

2 files changed: +26 -3 lines changed


src/trainables.jl

Lines changed: 10 additions & 3 deletions
@@ -5,9 +5,9 @@
 Return an iterable over all the trainable parameters in `x`, that is all the numerical
 arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).
 
-Parameters appearing multiple times in the model will be present only once in the output.
+Parameters appearing multiple times in the model (tied weights) will be present only once in the output.
 
-See also [`destructure`](@ref).
+See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.
 
 # Examples
 
@@ -26,6 +26,13 @@ julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);
 julia> trainables(x)
 1-element Vector{AbstractArray}:
  [1.0, 2.0, 3.0]
+
+julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
+
+julia> trainables(x) # collects nested parameters
+2-element Vector{AbstractArray}:
+ [1.0, 2.0]
+ [3.0]
 """
 function trainables(x)
     arrays = AbstractArray[]
@@ -40,7 +47,7 @@ end
 function ∇trainables(x, Δ)
     exclude(x) = Optimisers.isnumeric(x)
     i = 0
-    return fmapstructure(x; exclude, walk = Optimisers.TrainableStructWalk()) do _
+    return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _
         return Δ[i+=1]
     end
 end
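
The updated docstring stresses that tied weights are collected only once. A minimal sketch of what that means in practice, using a hypothetical `Tied` struct (not part of this commit) whose two fields alias the same array:

using Optimisers, Functors

# Hypothetical layer in which one array is reachable through two fields (tied weights).
struct Tied
    w
    wt
end
Functors.@functor Tied

w = rand(2, 2)
m = Tied(w, w)
ps = trainables(m)
length(ps)  # expected to be 1: the shared array should appear only once in the output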

test/trainables.jl

Lines changed: 16 additions & 0 deletions
@@ -97,3 +97,19 @@ end
     @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]])
 end
 
+@testset "second order derivatives" begin
+    struct DenseLayer
+        w
+        b
+    end
+
+    Functors.@functor DenseLayer
+
+    loss(m) = sum([sum(abs2, p) for p in trainables(m)])
+
+    model = DenseLayer([1. 2.; 3. 4.], [0., 0.])
+
+    g = gradient(m -> loss(gradient(loss, m)), model)[1]
+    @test g.w == [8.0 16.0; 24.0 32.0]
+    @test g.b == [0.0, 0.0]
+end
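
The expected values in the new test follow from a short hand derivation; the snippet below is a sketch of that arithmetic, not code from the commit:

# loss(m) sums the squares of all trainable parameters, so for this model
# loss(m) = Σ wᵢ² + Σ bᵢ² and gradient(loss, m) gives (w = 2w, b = 2b).
# Feeding that gradient back into loss yields 4Σwᵢ² + 4Σbᵢ², whose gradient
# with respect to the original parameters is (w = 8w, b = 8b).
w = [1. 2.; 3. 4.]
8 .* w              # == [8.0 16.0; 24.0 32.0], matching the @test on g.w
8 .* [0., 0.]       # == [0.0, 0.0], matching the @test on g.b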
