1
- export grad, jacobian, jvp, j′vp, to_vec
2
- replace_arg (x, xs:: Tuple , k:: Int ) = ntuple (p -> p == k ? x : xs[p], length (xs))
3
-
4
1
"""
5
- grad (fdm, f, xs ...)
2
+ jacobian (fdm, f, x ...)
6
3
7
- Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)` is scalar.
4
+ Approximate the Jacobian of `f` at `x` using `fdm`. Results will be returned as a
5
+ `Matrix{<:Number}` of `size(length(y_vec), length(x_vec))` where `x_vec` is the flattened
6
+ version of `x`, and `y_vec` the flattened version of `f(x...)`. Flattening performed by
7
+ [`to_vec`](@ref).
8
8
"""
9
- function grad end
10
-
11
- function _grad (fdm, f, x:: AbstractArray{T} ) where T <: Number
12
- # x must be mutable, we will mutate it and then mutate it back.
13
- dx = similar (x)
14
- for k in eachindex (x)
15
- dx[k] = fdm (zero (T)) do ϵ
16
- xk = x[k]
17
- x[k] = xk + ϵ
18
- ret = f (x)
19
- x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
9
+ function jacobian (fdm, f, x:: Vector{<:Number} ; len= nothing )
10
+ len != = nothing && Base. depwarn (
11
+ " `len` keyword argument to `jacobian` is no longer required " *
12
+ " and will not be permitted in the future." ,
13
+ :jacobian
14
+ )
15
+ ẏs = map (eachindex (x)) do n
16
+ return fdm (zero (eltype (x))) do ε
17
+ xn = x[n]
18
+ x[n] = xn + ε
19
+ ret = first (to_vec (f (x)))
20
+ x[n] = xn # Can't do `x[n] -= ϵ` as floating-point math is not associative
20
21
return ret
21
22
end
22
23
end
23
- return (dx , )
24
+ return (hcat (ẏs ... ) , )
24
25
end
25
26
26
- grad (fdm, f, x:: Array{<:Number} ) = _grad (fdm, f, x)
27
- # Fallback for when we don't know `x` will be mutable:
28
- grad (fdm, f, x:: AbstractArray{<:Number} ) = _grad (fdm, f, similar (x).= x)
29
-
30
- grad (fdm, f, x:: Real ) = (fdm (f, x), )
31
- grad (fdm, f, x:: Tuple ) = (grad (fdm, (xs... )-> f (xs), x... ), )
32
-
33
- function grad (fdm, f, d:: Dict{K, V} ) where {K, V}
34
- ∇d = Dict {K, V} ()
35
- for (k, v) in d
36
- dk = d[k]
37
- function f′ (x)
38
- d[k] = x
39
- return f (d)
40
- end
41
- ∇d[k] = grad (fdm, f′, v)[1 ]
42
- d[k] = dk
43
- end
44
- return (∇d, )
27
+ function jacobian (fdm, f, x; len= nothing )
28
+ x_vec, from_vec = to_vec (x)
29
+ return jacobian (fdm, f ∘ from_vec, x_vec; len= len)
45
30
end
46
31
47
- function grad (fdm, f, x)
48
- v, back = to_vec (x)
49
- return (back (grad (fdm, x-> f (back (v)), v)), )
50
- end
51
-
52
- function grad (fdm, f, xs... )
53
- return ntuple (length (xs)) do k
54
- grad (fdm, x-> f (replace_arg (x, xs, k)... ), xs[k])[1 ]
55
- end
56
- end
57
-
58
- """
59
- jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}; len::Int=length(f(x)))
60
-
61
- Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `len` vector. If
62
- `len` is not provided, then `f(x)` is computed once to determine the output size.
63
- """
64
- function jacobian (fdm, f, x:: Union{T, AbstractArray{T}} ; len:: Int = length (f (x))) where {T <: Number }
65
- J = Matrix {float(T)} (undef, len, length (x))
66
- for d in 1 : len
67
- gs = grad (fdm, x-> f (x)[d], x)[1 ]
68
- for k in 1 : length (x)
69
- J[d, k] = gs[k]
70
- end
71
- end
72
- return (J, )
73
- end
74
-
75
- function jacobian (fdm, f, xs... ; len:: Int = length (f (xs... )))
32
+ function jacobian (fdm, f, xs... ; len= nothing )
76
33
return ntuple (length (xs)) do k
77
34
jacobian (fdm, x-> f (replace_arg (x, xs, k)... ), xs[k]; len= len)[1 ]
78
35
end
79
36
end
80
37
38
+ replace_arg (x, xs:: Tuple , k:: Int ) = ntuple (p -> p == k ? x : xs[p], length (xs))
39
+
81
40
"""
82
41
_jvp(fdm, f, x::Vector{<:Number}, ẋ::AbstractVector{<:Number})
83
42
84
43
Convenience function to compute `jacobian(f, x) * ẋ`.
85
44
"""
86
- _jvp (fdm, f, x:: Vector{<:Number} , ẋ:: AV{<:Number} ) = fdm (ε -> f (x .+ ε .* ẋ), zero (eltype (x)))
87
-
88
- """
89
- _j′vp(fdm, f, ȳ::AbstractVector{<:Number}, x::Vector{<:Number})
90
-
91
- Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
92
- """
93
- _j′vp (fdm, f, ȳ:: AV{<:Number} , x:: Vector{<:Number} ) = transpose (jacobian (fdm, f, x; len= length (ȳ))[1 ]) * ȳ
45
+ function _jvp (fdm, f, x:: Vector{<:Number} , ẋ:: Vector{<:Number} )
46
+ return fdm (ε -> f (x .+ ε .* ẋ), zero (eltype (x)))
47
+ end
94
48
95
49
"""
96
50
jvp(fdm, f, x, ẋ)
@@ -111,110 +65,23 @@ end
111
65
"""
112
66
j′vp(fdm, f, ȳ, x...)
113
67
114
- Compute an adjoint with any types of arguments for which [`to_vec`](@ref) is defined.
68
+ Compute an adjoint with any types of arguments `x` for which [`to_vec`](@ref) is defined.
115
69
"""
116
70
function j′vp (fdm, f, ȳ, x)
117
71
x_vec, vec_to_x = to_vec (x)
118
72
ȳ_vec, _ = to_vec (ȳ)
119
- return (vec_to_x (_j′vp (fdm, x_vec -> to_vec ( f ( vec_to_x (x_vec)))[ 1 ] , ȳ_vec, x_vec)), )
73
+ return (vec_to_x (_j′vp (fdm, first ∘ to_vec ∘ f ∘ vec_to_x, ȳ_vec, x_vec)), )
120
74
end
121
- j′vp (fdm, f, ȳ, xs... ) = j′vp (fdm, xs-> f (xs... ), ȳ, xs)[1 ]
122
75
123
- """
124
- to_vec(x) -> Tuple{<:AbstractVector, <:Function}
125
-
126
- Transform `x` into a `Vector`, and return a closure which inverts the transformation.
127
- """
128
- function to_vec (x:: Number )
129
- function Number_from_vec (x_vec)
130
- return first (x_vec)
131
- end
132
- return [x], Number_from_vec
133
- end
134
-
135
- # Vectors
136
- to_vec (x:: Vector{<:Number} ) = (x, identity)
137
- function to_vec (x:: Vector )
138
- x_vecs_and_backs = map (to_vec, x)
139
- x_vecs, backs = first .(x_vecs_and_backs), last .(x_vecs_and_backs)
140
- function Vector_from_vec (x_vec)
141
- sz = cumsum ([map (length, x_vecs)... ])
142
- return [backs[n](x_vec[sz[n]- length (x_vecs[n])+ 1 : sz[n]]) for n in eachindex (x)]
143
- end
144
- return vcat (x_vecs... ), Vector_from_vec
145
- end
146
-
147
- # Arrays
148
- function to_vec (x:: Array{<:Number} )
149
- function Array_from_vec (x_vec)
150
- return reshape (x_vec, size (x))
151
- end
152
- return vec (x), Array_from_vec
153
- end
154
-
155
- function to_vec (x:: Array )
156
- x_vec, back = to_vec (reshape (x, :))
157
- function Array_from_vec (x_vec)
158
- return reshape (back (x_vec), size (x))
159
- end
160
- return x_vec, Array_from_vec
161
- end
162
-
163
- # AbstractArrays
164
- function to_vec (x:: T ) where {T<: LinearAlgebra.AbstractTriangular }
165
- x_vec, back = to_vec (Matrix (x))
166
- function AbstractTriangular_from_vec (x_vec)
167
- return T (reshape (back (x_vec), size (x)))
168
- end
169
- return x_vec, AbstractTriangular_from_vec
170
- end
171
-
172
- function to_vec (x:: Symmetric )
173
- function Symmetric_from_vec (x_vec)
174
- return Symmetric (reshape (x_vec, size (x)))
175
- end
176
- return vec (Matrix (x)), Symmetric_from_vec
177
- end
178
-
179
- function to_vec (X:: Diagonal )
180
- function Diagonal_from_vec (x_vec)
181
- return Diagonal (reshape (x_vec, size (X)... ))
182
- end
183
- return vec (Matrix (X)), Diagonal_from_vec
184
- end
185
-
186
- function to_vec (X:: Transpose )
187
- function Transpose_from_vec (x_vec)
188
- return Transpose (permutedims (reshape (x_vec, size (X))))
189
- end
190
- return vec (Matrix (X)), Transpose_from_vec
191
- end
76
+ j′vp (fdm, f, ȳ, xs... ) = j′vp (fdm, xs-> f (xs... ), ȳ, xs)[1 ]
192
77
193
- function to_vec (X:: Adjoint )
194
- function Adjoint_from_vec (x_vec)
195
- return Adjoint (conj! (permutedims (reshape (x_vec, size (X)))))
196
- end
197
- return vec (Matrix (X)), Adjoint_from_vec
78
+ function _j′vp (fdm, f, ȳ:: Vector{<:Number} , x:: Vector{<:Number} )
79
+ return transpose (first (jacobian (fdm, f, x))) * ȳ
198
80
end
199
81
200
- # Non-array data structures
201
-
202
- function to_vec (x:: Tuple )
203
- x_vecs, x_backs = zip (map (to_vec, x)... )
204
- sz = cumsum ([map (length, x_vecs)... ])
205
- function Tuple_from_vec (v)
206
- return ntuple (n-> x_backs[n](v[sz[n]- length (x_vecs[n])+ 1 : sz[n]]), length (x))
207
- end
208
- return vcat (x_vecs... ), Tuple_from_vec
209
- end
82
+ """
83
+ grad(fdm, f, xs...)
210
84
211
- # Convert to a vector-of-vectors to make use of existing functionality.
212
- function to_vec (d:: Dict )
213
- d_vec_vec = [val for val in values (d)]
214
- d_vec, back = to_vec (d_vec_vec)
215
- function Dict_from_vec (v)
216
- v_vec_vec = back (v)
217
- return Dict ([(key, v_vec_vec[n]) for (n, key) in enumerate (keys (d))])
218
- end
219
- return d_vec, Dict_from_vec
220
- end
85
+ Compute the gradient of `f` for any `xs` for which [`to_vec`](@ref) is defined.
86
+ """
87
+ grad (fdm, f, xs... ) = j′vp (fdm, f, 1 , xs... ) # `j′vp` with seed of 1
0 commit comments