Skip to content

Commit 8d6c8c8

Browse files
Rogerluowilltebbutt
authored andcommitted
change return type of grad/jacobian to match args (#51)
* change return type * finish this * revert jvp * Update Project.toml
1 parent 22e248a commit 8d6c8c8

File tree

3 files changed

+34
-34
lines changed

3 files changed

+34
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.8.0"
3+
version = "0.9.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/grad.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ function grad(fdm, f, x::AbstractArray{T}) where T <: Number
1818
return f(tmp)
1919
end
2020
end
21-
return dx
21+
return (dx, )
2222
end
2323

24-
grad(fdm, f, x::Real) = fdm(f, x)
25-
grad(fdm, f, x::Tuple) = grad(fdm, (xs...)->f(xs), x...)
24+
grad(fdm, f, x::Real) = (fdm(f, x), )
25+
grad(fdm, f, x::Tuple) = (grad(fdm, (xs...)->f(xs), x...), )
2626

2727
function grad(fdm, f, d::Dict{K, V}) where {K, V}
2828
dd = Dict{K, V}()
@@ -32,19 +32,19 @@ function grad(fdm, f, d::Dict{K, V}) where {K, V}
3232
tmp[k] = x
3333
return f(tmp)
3434
end
35-
dd[k] = grad(fdm, f′, v)
35+
dd[k] = grad(fdm, f′, v)[1]
3636
end
37-
return dd
37+
return (dd, )
3838
end
3939

4040
function grad(fdm, f, x)
4141
v, back = to_vec(x)
42-
return back(grad(fdm, x->f(back(v)), v))
42+
return (back(grad(fdm, x->f(back(v)), v)), )
4343
end
4444

4545
function grad(fdm, f, xs...)
4646
return ntuple(length(xs)) do k
47-
grad(fdm, x->f(replace_arg(x, xs, k)...), xs[k])
47+
grad(fdm, x->f(replace_arg(x, xs, k)...), xs[k])[1]
4848
end
4949
end
5050

@@ -57,17 +57,17 @@ Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `len
5757
function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; len::Int=length(f(x))) where {T <: Number}
5858
J = Matrix{float(T)}(undef, len, length(x))
5959
for d in 1:len
60-
gs = grad(fdm, x->f(x)[d], x)
60+
gs = grad(fdm, x->f(x)[d], x)[1]
6161
for k in 1:length(x)
6262
J[d, k] = gs[k]
6363
end
6464
end
65-
return J
65+
return (J, )
6666
end
6767

6868
function jacobian(fdm, f, xs...; len::Int=length(f(xs...)))
6969
return ntuple(length(xs)) do k
70-
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; len=len)
70+
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; len=len)[1]
7171
end
7272
end
7373

@@ -83,7 +83,7 @@ _jvp(fdm, f, x::Vector{<:Number}, ẋ::AV{<:Number}) = fdm(ε -> f(x .+ ε .* x
8383
8484
Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
8585
"""
86-
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x; len=length(ȳ))) *
86+
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x; len=length(ȳ))[1]) *
8787

8888
"""
8989
jvp(fdm, f, x, ẋ)
@@ -98,7 +98,7 @@ function jvp(fdm, f, (x, ẋ)::Tuple{Any, Any})
9898
end
9999
function jvp(fdm, f, xẋs::Tuple{Any, Any}...)
100100
x, ẋ = collect(zip(xẋs...))
101-
return jvp(fdm, xs->f(xs...), (x, ẋ))
101+
return jvp(fdm, xs->f(xs...)[1], (x, ẋ))
102102
end
103103

104104
"""
@@ -109,9 +109,9 @@ Compute an adjoint with any types of arguments for which [`to_vec`](@ref) is def
109109
function j′vp(fdm, f, ȳ, x)
110110
x_vec, vec_to_x = to_vec(x)
111111
ȳ_vec, _ = to_vec(ȳ)
112-
return vec_to_x(_j′vp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], ȳ_vec, x_vec))
112+
return (vec_to_x(_j′vp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], ȳ_vec, x_vec)), )
113113
end
114-
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)
114+
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1]
115115

116116
"""
117117
to_vec(x) -> Tuple{<:AbstractVector, <:Function}

test/grad.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ Base.length(x::DummyType) = size(x.X, 1)
2020
rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
2121
x = randn(rng, T, 2)
2222
xc = copy(x)
23-
@test grad(fdm, x->sin(x[1]) + cos(x[2]), x) [cos(x[1]), -sin(x[2])]
23+
@test grad(fdm, x->sin(x[1]) + cos(x[2]), x)[1] [cos(x[1]), -sin(x[2])]
2424
@test xc == x
2525
end
2626

2727
function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
2828
xc = copy(x)
29-
@test jacobian(fdm, f, x; len=length(ȳ)) J_exact
30-
@test jacobian(fdm, f, x) == jacobian(fdm, f, x; len=length(ȳ))
29+
@test jacobian(fdm, f, x; len=length(ȳ))[1] J_exact
30+
@test jacobian(fdm, f, x)[1] == jacobian(fdm, f, x; len=length(ȳ))[1]
3131
@test _jvp(fdm, f, x, ẋ) J_exact *
3232
@test _j′vp(fdm, f, ȳ, x) transpose(J_exact) *
3333
@test xc == x
@@ -56,46 +56,46 @@ Base.length(x::DummyType) = size(x.X, 1)
5656
@testset "check multiple matrices" begin
5757
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
5858
jac_xs = jacobian(fdm, f1, x, y)
59-
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
60-
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
59+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)[1]
60+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)[1]
6161
end
6262

6363
@testset "check mixed scalar and matrices" begin
6464
x, y = rand(3, 3), 2
6565
jac_xs = jacobian(fdm, f1, x, y)
66-
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
67-
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
66+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)[1]
67+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)[1]
6868
end
6969
end
7070

7171
@testset "grad" begin
7272
@testset "check multiple matrices" begin
7373
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
7474
dxs = grad(fdm, f2, x, y)
75-
@test dxs[1] grad(fdm, x->f2(x, y), x)
76-
@test dxs[2] grad(fdm, y->f2(x, y), y)
75+
@test dxs[1] grad(fdm, x->f2(x, y), x)[1]
76+
@test dxs[2] grad(fdm, y->f2(x, y), y)[1]
7777
end
7878

7979
@testset "check mixed scalar & matrices" begin
8080
x, y = rand(rng, 3, 3), 2
8181
dxs = grad(fdm, f2, x, y)
82-
@test dxs[1] grad(fdm, x->f2(x, y), x)
83-
@test dxs[2] grad(fdm, y->f2(x, y), y)
82+
@test dxs[1] grad(fdm, x->f2(x, y), x)[1]
83+
@test dxs[2] grad(fdm, y->f2(x, y), y)[1]
8484
end
8585

8686
@testset "check tuple" begin
8787
x, y = rand(rng, 3, 3), 2
88-
dxs = grad(fdm, f3, (x, y))
89-
@test dxs[1] grad(fdm, x->f3((x, y)), x)
90-
@test dxs[2] grad(fdm, y->f3((x, y)), y)
88+
dxs = grad(fdm, f3, (x, y))[1]
89+
@test dxs[1] grad(fdm, x->f3((x, y)), x)[1]
90+
@test dxs[2] grad(fdm, y->f3((x, y)), y)[1]
9191
end
9292

9393
@testset "check dict" begin
9494
x, y = rand(rng, 3, 3), 2
9595
d = Dict(:x=>x, :y=>y)
96-
dxs = grad(fdm, f4, d)
97-
@test dxs[:x] grad(fdm, x->f3((x, y)), x)
98-
@test dxs[:y] grad(fdm, y->f3((x, y)), y)
96+
dxs = grad(fdm, f4, d)[1]
97+
@test dxs[:x] grad(fdm, x->f3((x, y)), x)[1]
98+
@test dxs[:y] grad(fdm, y->f3((x, y)), y)[1]
9999
end
100100
end
101101
end
@@ -168,8 +168,8 @@ Base.length(x::DummyType) = size(x.X, 1)
168168
x, y = randn(rng, T, N), randn(rng, T, M)
169169
= randn(rng, T, N + M)
170170
xy = vcat(x, y)
171-
x̄ȳ_manual = j′vp(fdm, xy->sin.(xy), z̄, xy)
172-
x̄ȳ_auto = j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y))
171+
x̄ȳ_manual = j′vp(fdm, xy->sin.(xy), z̄, xy)[1]
172+
x̄ȳ_auto = j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y))[1]
173173
x̄ȳ_multi = j′vp(fdm, (x, y)->sin.(vcat(x, y)), z̄, x, y)
174174
@test x̄ȳ_manual vcat(x̄ȳ_auto...)
175175
@test x̄ȳ_manual vcat(x̄ȳ_multi...)

0 commit comments

Comments
 (0)