Skip to content

Commit 86f6207

Browse files
willtebbutt, oxinabox, nickrobinson251, wesselb
authored
Refactor jacobian and grad (#66)
* Implements Jacobian via jvp
* Improves perf of jacobian
* Refactors grad
* Refactors jacobian
* Tidies a few things up
* Update src/grad.jl
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
* Changes composition for application.
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
* Clarifies jacobian algorithm
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
* Clarifies grad function
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
* Fixes to_vec docstring.
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
* Update src/to_vec.jl
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
* Moves exports to main file
* Fixes jpvp docstring
* Tidies up to_vec(::Vector)
* Tidies up to_vec(::Dict)
* Tidies up to_vec(::Dict) further
* Tidies up to_vec(::Tuple) further
* Deprecates len parameter
* Tweaks jacobian comment
* Tidies up docstring
* Tweaks depwarn
* Tweaks jacobian deprecation
Co-Authored-By: Nick Robinson <npr251@gmail.com>
* Fixes typo in docs
Co-Authored-By: Wessel <wessel.p.bruinsma@gmail.com>
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Wessel <wessel.p.bruinsma@gmail.com>
1 parent 74614e2 commit 86f6207

File tree

7 files changed

+238
-260
lines changed

7 files changed

+238
-260
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.9.2"
3+
version = "0.9.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FiniteDifferences.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ module FiniteDifferences
22

33
using Printf, LinearAlgebra
44

5-
const AV = AbstractVector
5+
export to_vec, grad, jacobian, jvp, j′vp
66

77
include("methods.jl")
88
include("numerics.jl")
9+
include("to_vec.jl")
910
include("grad.jl")
10-
11-
12-
@deprecate jacobian(fdm, f, x::Vector, D::Int) jacobian(fdm, f, x; len=D)
1311
end

src/grad.jl

Lines changed: 37 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,50 @@
1-
export grad, jacobian, jvp, j′vp, to_vec
2-
replace_arg(x, xs::Tuple, k::Int) = ntuple(p -> p == k ? x : xs[p], length(xs))
3-
41
"""
5-
grad(fdm, f, xs...)
2+
jacobian(fdm, f, x...)
63
7-
Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)` is scalar.
4+
Approximate the Jacobian of `f` at `x` using `fdm`. Results will be returned as a
5+
`Matrix{<:Number}` of `size(length(y_vec), length(x_vec))` where `x_vec` is the flattened
6+
version of `x`, and `y_vec` the flattened version of `f(x...)`. Flattening performed by
7+
[`to_vec`](@ref).
88
"""
9-
function grad end
10-
11-
function _grad(fdm, f, x::AbstractArray{T}) where T <: Number
12-
# x must be mutable, we will mutate it and then mutate it back.
13-
dx = similar(x)
14-
for k in eachindex(x)
15-
dx[k] = fdm(zero(T)) do ϵ
16-
xk = x[k]
17-
x[k] = xk + ϵ
18-
ret = f(x)
19-
x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
9+
function jacobian(fdm, f, x::Vector{<:Number}; len=nothing)
10+
len !== nothing && Base.depwarn(
11+
"`len` keyword argument to `jacobian` is no longer required " *
12+
"and will not be permitted in the future.",
13+
:jacobian
14+
)
15+
ẏs = map(eachindex(x)) do n
16+
return fdm(zero(eltype(x))) do ε
17+
xn = x[n]
18+
x[n] = xn + ε
19+
ret = first(to_vec(f(x)))
20+
x[n] = xn # Can't do `x[n] -= ϵ` as floating-point math is not associative
2021
return ret
2122
end
2223
end
23-
return (dx, )
24+
return (hcat(ẏs...), )
2425
end
2526

26-
grad(fdm, f, x::Array{<:Number}) = _grad(fdm, f, x)
27-
# Fallback for when we don't know `x` will be mutable:
28-
grad(fdm, f, x::AbstractArray{<:Number}) = _grad(fdm, f, similar(x).=x)
29-
30-
grad(fdm, f, x::Real) = (fdm(f, x), )
31-
grad(fdm, f, x::Tuple) = (grad(fdm, (xs...)->f(xs), x...), )
32-
33-
function grad(fdm, f, d::Dict{K, V}) where {K, V}
34-
∇d = Dict{K, V}()
35-
for (k, v) in d
36-
dk = d[k]
37-
function f′(x)
38-
d[k] = x
39-
return f(d)
40-
end
41-
∇d[k] = grad(fdm, f′, v)[1]
42-
d[k] = dk
43-
end
44-
return (∇d, )
27+
function jacobian(fdm, f, x; len=nothing)
28+
x_vec, from_vec = to_vec(x)
29+
return jacobian(fdm, f ∘ from_vec, x_vec; len=len)
4530
end
4631

47-
function grad(fdm, f, x)
48-
v, back = to_vec(x)
49-
return (back(grad(fdm, x->f(back(v)), v)), )
50-
end
51-
52-
function grad(fdm, f, xs...)
53-
return ntuple(length(xs)) do k
54-
grad(fdm, x->f(replace_arg(x, xs, k)...), xs[k])[1]
55-
end
56-
end
57-
58-
"""
59-
jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}; len::Int=length(f(x)))
60-
61-
Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `len` vector. If
62-
`len` is not provided, then `f(x)` is computed once to determine the output size.
63-
"""
64-
function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; len::Int=length(f(x))) where {T <: Number}
65-
J = Matrix{float(T)}(undef, len, length(x))
66-
for d in 1:len
67-
gs = grad(fdm, x->f(x)[d], x)[1]
68-
for k in 1:length(x)
69-
J[d, k] = gs[k]
70-
end
71-
end
72-
return (J, )
73-
end
74-
75-
function jacobian(fdm, f, xs...; len::Int=length(f(xs...)))
32+
function jacobian(fdm, f, xs...; len=nothing)
7633
return ntuple(length(xs)) do k
7734
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; len=len)[1]
7835
end
7936
end
8037

38+
replace_arg(x, xs::Tuple, k::Int) = ntuple(p -> p == k ? x : xs[p], length(xs))
39+
8140
"""
8241
_jvp(fdm, f, x::Vector{<:Number}, ẋ::AbstractVector{<:Number})
8342
8443
Convenience function to compute `jacobian(f, x) * ẋ`.
8544
"""
86-
_jvp(fdm, f, x::Vector{<:Number}, ẋ::AV{<:Number}) = fdm(ε -> f(x .+ ε .* ẋ), zero(eltype(x)))
87-
88-
"""
89-
_j′vp(fdm, f, ȳ::AbstractVector{<:Number}, x::Vector{<:Number})
90-
91-
Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
92-
"""
93-
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x; len=length(ȳ))[1]) * ȳ
45+
function _jvp(fdm, f, x::Vector{<:Number}, ẋ::Vector{<:Number})
46+
return fdm(ε -> f(x .+ ε .* ẋ), zero(eltype(x)))
47+
end
9448

9549
"""
9650
jvp(fdm, f, x, ẋ)
@@ -111,110 +65,23 @@ end
11165
"""
11266
j′vp(fdm, f, ȳ, x...)
11367
114-
Compute an adjoint with any types of arguments for which [`to_vec`](@ref) is defined.
68+
Compute an adjoint with any types of arguments `x` for which [`to_vec`](@ref) is defined.
11569
"""
11670
function j′vp(fdm, f, ȳ, x)
11771
x_vec, vec_to_x = to_vec(x)
11872
ȳ_vec, _ = to_vec(ȳ)
119-
return (vec_to_x(_j′vp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], ȳ_vec, x_vec)), )
73+
return (vec_to_x(_j′vp(fdm, first ∘ to_vec ∘ f ∘ vec_to_x, ȳ_vec, x_vec)), )
12074
end
121-
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1]
12275

123-
"""
124-
to_vec(x) -> Tuple{<:AbstractVector, <:Function}
125-
126-
Transform `x` into a `Vector`, and return a closure which inverts the transformation.
127-
"""
128-
function to_vec(x::Number)
129-
function Number_from_vec(x_vec)
130-
return first(x_vec)
131-
end
132-
return [x], Number_from_vec
133-
end
134-
135-
# Vectors
136-
to_vec(x::Vector{<:Number}) = (x, identity)
137-
function to_vec(x::Vector)
138-
x_vecs_and_backs = map(to_vec, x)
139-
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
140-
function Vector_from_vec(x_vec)
141-
sz = cumsum([map(length, x_vecs)...])
142-
return [backs[n](x_vec[sz[n]-length(x_vecs[n])+1:sz[n]]) for n in eachindex(x)]
143-
end
144-
return vcat(x_vecs...), Vector_from_vec
145-
end
146-
147-
# Arrays
148-
function to_vec(x::Array{<:Number})
149-
function Array_from_vec(x_vec)
150-
return reshape(x_vec, size(x))
151-
end
152-
return vec(x), Array_from_vec
153-
end
154-
155-
function to_vec(x::Array)
156-
x_vec, back = to_vec(reshape(x, :))
157-
function Array_from_vec(x_vec)
158-
return reshape(back(x_vec), size(x))
159-
end
160-
return x_vec, Array_from_vec
161-
end
162-
163-
# AbstractArrays
164-
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
165-
x_vec, back = to_vec(Matrix(x))
166-
function AbstractTriangular_from_vec(x_vec)
167-
return T(reshape(back(x_vec), size(x)))
168-
end
169-
return x_vec, AbstractTriangular_from_vec
170-
end
171-
172-
function to_vec(x::Symmetric)
173-
function Symmetric_from_vec(x_vec)
174-
return Symmetric(reshape(x_vec, size(x)))
175-
end
176-
return vec(Matrix(x)), Symmetric_from_vec
177-
end
178-
179-
function to_vec(X::Diagonal)
180-
function Diagonal_from_vec(x_vec)
181-
return Diagonal(reshape(x_vec, size(X)...))
182-
end
183-
return vec(Matrix(X)), Diagonal_from_vec
184-
end
185-
186-
function to_vec(X::Transpose)
187-
function Transpose_from_vec(x_vec)
188-
return Transpose(permutedims(reshape(x_vec, size(X))))
189-
end
190-
return vec(Matrix(X)), Transpose_from_vec
191-
end
76+
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1]
19277

193-
function to_vec(X::Adjoint)
194-
function Adjoint_from_vec(x_vec)
195-
return Adjoint(conj!(permutedims(reshape(x_vec, size(X)))))
196-
end
197-
return vec(Matrix(X)), Adjoint_from_vec
78+
function _j′vp(fdm, f, ȳ::Vector{<:Number}, x::Vector{<:Number})
79+
return transpose(first(jacobian(fdm, f, x))) * ȳ
19880
end
19981

200-
# Non-array data structures
201-
202-
function to_vec(x::Tuple)
203-
x_vecs, x_backs = zip(map(to_vec, x)...)
204-
sz = cumsum([map(length, x_vecs)...])
205-
function Tuple_from_vec(v)
206-
return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x))
207-
end
208-
return vcat(x_vecs...), Tuple_from_vec
209-
end
82+
"""
83+
grad(fdm, f, xs...)
21084
211-
# Convert to a vector-of-vectors to make use of existing functionality.
212-
function to_vec(d::Dict)
213-
d_vec_vec = [val for val in values(d)]
214-
d_vec, back = to_vec(d_vec_vec)
215-
function Dict_from_vec(v)
216-
v_vec_vec = back(v)
217-
return Dict([(key, v_vec_vec[n]) for (n, key) in enumerate(keys(d))])
218-
end
219-
return d_vec, Dict_from_vec
220-
end
85+
Compute the gradient of `f` for any `xs` for which [`to_vec`](@ref) is defined.
86+
"""
87+
grad(fdm, f, xs...) = j′vp(fdm, f, 1, xs...) # `j′vp` with seed of 1

src/to_vec.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
to_vec(x)
3+
4+
Transform `x` into a `Vector`, and return the vector, and a closure which inverts the
5+
transformation.
6+
"""
7+
function to_vec(x::Number)
8+
function Number_from_vec(x_vec)
9+
return first(x_vec)
10+
end
11+
return [x], Number_from_vec
12+
end
13+
14+
# Vectors
15+
to_vec(x::Vector{<:Number}) = (x, identity)
16+
function to_vec(x::Vector)
17+
x_vecs_and_backs = map(to_vec, x)
18+
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
19+
function Vector_from_vec(x_vec)
20+
sz = cumsum(map(length, x_vecs))
21+
return [backs[n](x_vec[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
22+
end
23+
return vcat(x_vecs...), Vector_from_vec
24+
end
25+
26+
# Arrays
27+
function to_vec(x::Array{<:Number})
28+
function Array_from_vec(x_vec)
29+
return reshape(x_vec, size(x))
30+
end
31+
return vec(x), Array_from_vec
32+
end
33+
34+
function to_vec(x::Array)
35+
x_vec, back = to_vec(reshape(x, :))
36+
function Array_from_vec(x_vec)
37+
return reshape(back(x_vec), size(x))
38+
end
39+
return x_vec, Array_from_vec
40+
end
41+
42+
# AbstractArrays
43+
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
44+
x_vec, back = to_vec(Matrix(x))
45+
function AbstractTriangular_from_vec(x_vec)
46+
return T(reshape(back(x_vec), size(x)))
47+
end
48+
return x_vec, AbstractTriangular_from_vec
49+
end
50+
51+
function to_vec(x::Symmetric)
52+
function Symmetric_from_vec(x_vec)
53+
return Symmetric(reshape(x_vec, size(x)))
54+
end
55+
return vec(Matrix(x)), Symmetric_from_vec
56+
end
57+
58+
function to_vec(X::Diagonal)
59+
function Diagonal_from_vec(x_vec)
60+
return Diagonal(reshape(x_vec, size(X)...))
61+
end
62+
return vec(Matrix(X)), Diagonal_from_vec
63+
end
64+
65+
function to_vec(X::Transpose)
66+
function Transpose_from_vec(x_vec)
67+
return Transpose(permutedims(reshape(x_vec, size(X))))
68+
end
69+
return vec(Matrix(X)), Transpose_from_vec
70+
end
71+
72+
function to_vec(X::Adjoint)
73+
function Adjoint_from_vec(x_vec)
74+
return Adjoint(conj!(permutedims(reshape(x_vec, size(X)))))
75+
end
76+
return vec(Matrix(X)), Adjoint_from_vec
77+
end
78+
79+
# Non-array data structures
80+
81+
function to_vec(x::Tuple)
82+
x_vecs, x_backs = zip(map(to_vec, x)...)
83+
sz = cumsum(collect(map(length, x_vecs)))
84+
function Tuple_from_vec(v)
85+
return ntuple(length(x)) do n
86+
return x_backs[n](v[sz[n] - length(x_vecs[n]) + 1:sz[n]])
87+
end
88+
end
89+
return reduce(vcat, x_vecs), Tuple_from_vec
90+
end
91+
92+
# Convert to a vector-of-vectors to make use of existing functionality.
93+
function to_vec(d::Dict)
94+
d_vec, back = to_vec(collect(values(d)))
95+
function Dict_from_vec(v)
96+
v_vec_vec = back(v)
97+
return Dict(key => v_vec_vec[n] for (n, key) in enumerate(keys(d)))
98+
end
99+
return d_vec, Dict_from_vec
100+
end

0 commit comments

Comments
 (0)