Skip to content

Commit 17e4a29

Browse files
authored
Merge pull request #26 from Roger-luo/roger/multiargs
add multi arg jacobian
2 parents 593ade3 + 66987ef commit 17e4a29

File tree

3 files changed

+117
-23
lines changed

3 files changed

+117
-23
lines changed

src/FiniteDifferences.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@ module FiniteDifferences
77
include("methods.jl")
88
include("numerics.jl")
99
include("grad.jl")
10+
11+
12+
@deprecate jacobian(fdm, f, x::Vector, D::Int) jacobian(fdm, f, x; len=D)
1013
end

src/grad.jl

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,75 @@
11
export grad, jacobian, jvp, j′vp, to_vec
2+
replace_arg(x, xs::Tuple, k::Int) = ntuple(p -> p == k ? x : xs[p], length(xs))
23

34
"""
4-
grad(fdm, f, x::Vector{<:Number})
5+
grad(fdm, f, xs...)
56
6-
Approximate the gradient of `f` at `x` using `fdm`. Assumes that `f(x)` is scalar.
7+
Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)` is scalar.
78
"""
8-
function grad(fdm, f, x::Vector{T}) where T<:Number
9-
v, dx, tmp = fill(zero(T), size(x)), similar(x), similar(x)
10-
for n in eachindex(x)
11-
v[n] = one(T)
12-
dx[n] = fdm(function)
13-
tmp .= x .+ ϵ .* v
14-
return f(tmp)
15-
end,
16-
zero(T),
17-
)
18-
v[n] = zero(T)
9+
function grad end
10+
11+
function grad(fdm, f, x::AbstractArray{T}) where T <: Number
12+
dx = similar(x)
13+
tmp = similar(x)
14+
for k in eachindex(x)
15+
dx[k] = fdm(zero(T)) do ϵ
16+
tmp .= x
17+
tmp[k] += ϵ
18+
return f(tmp)
19+
end
1920
end
2021
return dx
2122
end
2223

24+
grad(fdm, f, x::Real) = fdm(f, x)
25+
grad(fdm, f, x::Tuple) = grad(fdm, (xs...)->f(xs), x...)
26+
27+
function grad(fdm, f, d::Dict{K, V}) where {K, V}
28+
dd = Dict{K, V}()
29+
for (k, v) in d
30+
function f′(x)
31+
tmp = copy(d)
32+
tmp[k] = x
33+
return f(tmp)
34+
end
35+
dd[k] = grad(fdm, f′, v)
36+
end
37+
return dd
38+
end
39+
40+
function grad(fdm, f, x)
41+
v, back = to_vec(x)
42+
return back(grad(fdm, x->f(back(v)), v))
43+
end
44+
45+
function grad(fdm, f, xs...)
46+
return ntuple(length(xs)) do k
47+
grad(fdm, x->f(replace_arg(x, xs, k)...), xs[k])
48+
end
49+
end
50+
2351
"""
24-
jacobian(fdm, f, x::Vector{<:Number}, D::Int)
25-
jacobian(fdm, f, x::Vector{<:Number})
52+
jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}; len::Int=length(f(x)))
2653
2754
Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `D` vector. If
2855
`D` is not provided, then `f(x)` is computed once to determine the output size.
2956
"""
30-
function jacobian(fdm, f, x::Vector{T}, D::Int) where {T<:Number}
31-
J = Matrix{T}(undef, D, length(x))
32-
for d in 1:D
33-
J[d, :] = grad(fdm, x->f(x)[d], x)
57+
function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; len::Int=length(f(x))) where {T <: Number}
58+
J = Matrix{float(T)}(undef, len, length(x))
59+
for d in 1:len
60+
gs = grad(fdm, x->f(x)[d], x)
61+
for k in 1:length(x)
62+
J[d, k] = gs[k]
63+
end
3464
end
3565
return J
3666
end
37-
jacobian(fdm, f, x::Vector{<:Number}) = jacobian(fdm, f, x, length(f(x)))
67+
68+
function jacobian(fdm, f, xs...; len::Int=length(f(xs...)))
69+
return ntuple(length(xs)) do k
70+
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; len=len)
71+
end
72+
end
3873

3974
"""
4075
_jvp(fdm, f, x::Vector{<:Number}, ẋ::AbstractVector{<:Number})
@@ -48,7 +83,7 @@ _jvp(fdm, f, x::Vector{<:Number}, ẋ::AV{<:Number}) = fdm(ε -> f(x .+ ε .* x
4883
4984
Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
5085
"""
51-
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x, length(ȳ))) *
86+
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x; len=length(ȳ))) *
5287

5388
"""
5489
jvp(fdm, f, x, ẋ)

test/grad.jl

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ Base.length(x::DummyType) = size(x.X, 1)
2626

2727
function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
2828
xc = copy(x)
29-
@test jacobian(fdm, f, x, length(ȳ)) J_exact
30-
@test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
29+
@test jacobian(fdm, f, x; len=length(ȳ)) J_exact
30+
@test jacobian(fdm, f, x) == jacobian(fdm, f, x; len=length(ȳ))
3131
@test _jvp(fdm, f, x, ẋ) J_exact *
3232
@test _j′vp(fdm, f, ȳ, x) transpose(J_exact) *
3333
@test xc == x
@@ -44,6 +44,62 @@ Base.length(x::DummyType) = size(x.X, 1)
4444
@test Ac == A
4545
end
4646

47+
@testset "multi vars jacobian/grad" begin
48+
rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
49+
50+
f1(x, y) = x * y + x
51+
f2(x, y) = sum(x * y + x)
52+
f3(x::Tuple) = sum(x[1]) + x[2]
53+
f4(d::Dict) = sum(d[:x]) + d[:y]
54+
55+
@testset "jacobian" begin
56+
@testset "check multiple matrices" begin
57+
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
58+
jac_xs = jacobian(fdm, f1, x, y)
59+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
60+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
61+
end
62+
63+
@testset "check mixed scalar and matrices" begin
64+
x, y = rand(3, 3), 2
65+
jac_xs = jacobian(fdm, f1, x, y)
66+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
67+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
68+
end
69+
end
70+
71+
@testset "grad" begin
72+
@testset "check multiple matrices" begin
73+
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
74+
dxs = grad(fdm, f2, x, y)
75+
@test dxs[1] grad(fdm, x->f2(x, y), x)
76+
@test dxs[2] grad(fdm, y->f2(x, y), y)
77+
end
78+
79+
@testset "check mixed scalar & matrices" begin
80+
x, y = rand(rng, 3, 3), 2
81+
dxs = grad(fdm, f2, x, y)
82+
@test dxs[1] grad(fdm, x->f2(x, y), x)
83+
@test dxs[2] grad(fdm, y->f2(x, y), y)
84+
end
85+
86+
@testset "check tuple" begin
87+
x, y = rand(rng, 3, 3), 2
88+
dxs = grad(fdm, f3, (x, y))
89+
@test dxs[1] grad(fdm, x->f3((x, y)), x)
90+
@test dxs[2] grad(fdm, y->f3((x, y)), y)
91+
end
92+
93+
@testset "check dict" begin
94+
x, y = rand(rng, 3, 3), 2
95+
d = Dict(:x=>x, :y=>y)
96+
dxs = grad(fdm, f4, d)
97+
@test dxs[:x] grad(fdm, x->f3((x, y)), x)
98+
@test dxs[:y] grad(fdm, y->f3((x, y)), y)
99+
end
100+
end
101+
end
102+
47103
function test_to_vec(x)
48104
x_vec, back = to_vec(x)
49105
@test x_vec isa Vector

0 commit comments

Comments
 (0)