
Commit 5311de4

add total
1 parent c94abf6 commit 5311de4

File tree: 4 files changed, +95 -7 lines

  docs/src/api.md
  src/Optimisers.jl
  src/destructure.jl
  test/destructure.jl


docs/src/api.md

Lines changed: 3 additions & 1 deletion
@@ -47,11 +47,13 @@ Optimisers.isnumeric
 Optimisers.trainable
 ```
 
-Such restrictions are also obeyed by this function for flattening a model:
+Such restrictions are also obeyed by this function for flattening a model,
+and one for applying a function to every parameter:
 
 ```@docs
 Optimisers.destructure
 Optimisers.Restructure
+Optimisers.total
 ```
 
 ## Rule Definition
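
For orientation (not part of the diff), a minimal sketch of how the two documented functions relate, using the same toy model as the new docstring: `destructure` flattens only the trainable numeric arrays, and `total` folds a function over exactly those same parameters.

```julia
using Optimisers, LinearAlgebra

# Only the two float arrays count as trainable numeric parameters;
# `sin` and the plain integers are skipped.
m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7))

flat, re = destructure(m)    # flat == [3.0, 4.0, 5.0]
total(length, m)             # 3, same as length(flat)
total(norm, m)               # norm([3, 4]) + norm([5]) == 10.0
```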

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ export AbstractRule
 include("adjust.jl")
 
 include("destructure.jl")
-export destructure
+export destructure, total
 
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
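
With the export in place, `using Optimisers` brings `total` into scope next to `destructure`. A hypothetical use, not taken from this commit, is a parameter count or an L2 penalty computed without flattening the model:

```julia
using Optimisers

model = (W = randn(2, 3), b = zeros(2))      # hypothetical toy model
nparams = total(length, model)               # 8 parameters in total
penalty = total(w -> sum(abs2, w), model)    # sum of squares over all trainable arrays
```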

src/destructure.jl

Lines changed: 68 additions & 5 deletions
@@ -1,6 +1,6 @@
 
-using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
-const NoT = NoTangent()
+using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
+const NoT = ChainRulesCore.NoTangent()
 
 """
     destructure(model) -> vector, reconstructor
@@ -116,9 +116,11 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
-  if p isa ProjectTo # e.g. Array, NamedTuple
-    p(y)
-  else # p === identity for unknown structs
+  # if p isa ProjectTo # e.g. Array, NamedTuple
+  #   p(y)  # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+  if x isa Union{Number, AbstractArray} # these don't use Tangent
+    ProjectTo(x)(unthunk(y))
+  else
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
@@ -166,3 +168,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
   @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
   nothing, _ -> (NoT,)
 end
+
+"""
+    total(f, model)
+
+Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
+the model, and returns the sum. Differentiable. Counts shared weights once.
+
+# Examples
+```jldoctest
+julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
+
+julia> total(sum, m)
+12.0
+
+julia> total(norm, m)
+10.0
+
+julia> total(length, m) == length(destructure(m)[1])
+true
+```
+"""
+function total(f, x)
+  values = []
+  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
+  sum(values)
+end
+
+function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
+  z, backs = _total_hobbit(config, f, x)
+  total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
+  z, total_back
+end
+
+function _total_hobbit(config::RuleConfig, f, x)
+  values = []
+  backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
+    val, back = rrule_via_ad(config, f, y)
+    push!(values, val)
+    back
+  end
+  sum(values), backs
+end
+
+function _total_grad(dz, x, backs)
+  dfs = []
+  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
+    df, dy = b(dz)
+    push!(dfs, df)
+    dy
+  end
+  sum(dfs), dx
+end
+
+function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
+  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
+  function grad_back((df, dx))
+    df isa Zero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
+    (NoT, total(dx), NoT, NoT)
+  end
+  _total_grad(dz, x, backs), grad_back
+end
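
A rough reading of the new reverse rule, with a hand-written check that is not part of the commit: `_total_hobbit` does the forward walk, calling `rrule_via_ad` on `f` at each parameter and stashing the pullbacks in a matching tree, while `_total_grad` re-walks the model with `_Tangent_biwalk` to assemble a structural gradient (plus a gradient for `f` itself, which is what lets closures over a scalar differentiate). Assuming Zygote supplies the `RuleConfig`, something along these lines should hold:

```julia
using Optimisers, Zygote, LinearAlgebra

m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7))

# Forward value goes through _total_hobbit, the pullback through _total_grad.
val, back = Zygote.pullback(m -> total(norm, m), m)
val            # 10.0, matching the docstring example
back(1.0)[1]   # roughly (x = [0.6, 0.8], y = (nothing, [1.0]), z = nothing)
```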

test/destructure.jl

Lines changed: 23 additions & 0 deletions
@@ -241,3 +241,26 @@ tmp1
   y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4))
   @test bk(1.0) == (nothing,)
 end
+
+@testset "total" begin
+  @test total(sum, m1) == sum(1:3)
+  @test total(prod, m2) == prod(1:3) + prod(4:6)
+  @test total(sum, m3) == sum(1:6)
+  @test total(sum, m4) == sum(1:6)  # shared only counts once
+  @test total(sum, m6) == 6 + 4 + im
+
+  @test gradient(m -> total(sum, m), m1) == ([1,1,1],)
+  @test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
+  @test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
+  g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
+  @test g6.a isa Vector{Float64}
+
+  @test gradient(λ -> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
+  @test gradient(λ -> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)
+
+  @testset "second derivatives" begin
+    f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
+    @test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) ≈ Zygote.hessian_dual(f3, [1,2,3.0])
+    # typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
+  end
+end
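
The fixtures `m1` through `m6` are defined earlier in the test file and are not shown in this diff; a hypothetical stand-in with the same shape as the `m3`-style checks would behave like this (assuming Zygote is loaded, as elsewhere in the test suite):

```julia
using Optimisers, Zygote

m = (x = [1.0, 2.0, 3.0], y = sin, z = [4.0, 5.0, 6.0])   # hypothetical fixture

total(sum, m)                                    # expected 21.0 == sum(1:6)
gradient(m -> total(sum, m), m)[1]               # expected (x = [1,1,1], y = nothing, z = [1,1,1])
gradient(λ -> total(x -> sum(x .* λ), m), 1.0)   # expected (21.0,)
```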
