Skip to content

Commit 87f44c8

Browse files
simeonschaub authored and
wesselb committed
Relax type constraints to allow complex jvp (#43)
* relax type constraints to allow complex jvp
* fix docstrings, allow to_vec(::Vector{<:Number})
* another round of fixing docstrings
* add tests
* tests for to_vec(::Complex), correct to_vec for Adjoint
* delete stray comment
* use `@testset "bla" for ...`
* remove accidental indents
1 parent 08960c9 commit 87f44c8

File tree

2 files changed

+66
-55
lines changed

2 files changed

+66
-55
lines changed

src/grad.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
export grad, jacobian, jvp, j′vp, to_vec
22

33
"""
4-
grad(fdm, f, x::AbstractVector)
4+
grad(fdm, f, x::Vector{<:Number})
55
66
Approximate the gradient of `f` at `x` using `fdm`. Assumes that `f(x)` is scalar.
77
"""
8-
function grad(fdm, f, x::Vector{T}) where T<:Real
8+
function grad(fdm, f, x::Vector{T}) where T<:Number
99
v, dx, tmp = fill(zero(T), size(x)), similar(x), similar(x)
1010
for n in eachindex(x)
1111
v[n] = one(T)
@@ -21,34 +21,34 @@ function grad(fdm, f, x::Vector{T}) where T<:Real
2121
end
2222

2323
"""
24-
jacobian(fdm, f, x::AbstractVector{<:Real}, D::Int)
25-
jacobian(fdm, f, x::AbstractVector{<:Real})
24+
jacobian(fdm, f, x::Vector{<:Number}, D::Int)
25+
jacobian(fdm, f, x::Vector{<:Number})
2626
2727
Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `D` vector. If
2828
`D` is not provided, then `f(x)` is computed once to determine the output size.
2929
"""
30-
function jacobian(fdm, f, x::Vector{T}, D::Int) where {T<:Real}
30+
function jacobian(fdm, f, x::Vector{T}, D::Int) where {T<:Number}
3131
J = Matrix{T}(undef, D, length(x))
3232
for d in 1:D
3333
J[d, :] = grad(fdm, x->f(x)[d], x)
3434
end
3535
return J
3636
end
37-
jacobian(fdm, f, x::Vector{<:Real}) = jacobian(fdm, f, x, length(f(x)))
37+
jacobian(fdm, f, x::Vector{<:Number}) = jacobian(fdm, f, x, length(f(x)))
3838

3939
"""
40-
_jvp(fdm, f, x::Vector{<:Real}, ẋ::AbstractVector{<:Real})
40+
_jvp(fdm, f, x::Vector{<:Number}, ẋ::AbstractVector{<:Number})
4141
4242
Convenience function to compute `jacobian(f, x) * ẋ`.
4343
"""
44-
_jvp(fdm, f, x::Vector{<:Real}, ẋ::AV{<:Real}) = fdm(ε -> f(x .+ ε .* ẋ), zero(eltype(x)))
44+
_jvp(fdm, f, x::Vector{<:Number}, ẋ::AV{<:Number}) = fdm(ε -> f(x .+ ε .* ẋ), zero(eltype(x)))
4545

4646
"""
47-
_j′vp(fdm, f, ȳ::AbstractVector{<:Real}, x::Vector{<:Real})
47+
_j′vp(fdm, f, ȳ::AbstractVector{<:Number}, x::Vector{<:Number})
4848
49-
Convenience function to compute `jacobian(f, x)' * ȳ`.
49+
Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
5050
"""
51-
_j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x, length(ȳ))' * ȳ
51+
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x, length(ȳ))) * ȳ
5252

5353
"""
5454
jvp(fdm, f, x, ẋ)
@@ -83,10 +83,10 @@ j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)
8383
8484
Transform `x` into a `Vector`, and return a closure which inverts the transformation.
8585
"""
86-
to_vec(x::Real) = ([x], first)
86+
to_vec(x::Number) = ([x], first)
8787

8888
# Vectors
89-
to_vec(x::Vector{<:Real}) = (x, identity)
89+
to_vec(x::Vector{<:Number}) = (x, identity)
9090
function to_vec(x::Vector)
9191
x_vecs_and_backs = map(to_vec, x)
9292
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
@@ -97,7 +97,7 @@ function to_vec(x::Vector)
9797
end
9898

9999
# Arrays
100-
to_vec(x::Array{<:Real}) = vec(x), x_vec->reshape(x_vec, size(x))
100+
to_vec(x::Array{<:Number}) = vec(x), x_vec->reshape(x_vec, size(x))
101101
function to_vec(x::Array)
102102
x_vec, back = to_vec(reshape(x, :))
103103
return x_vec, x_vec->reshape(back(x_vec), size(x))
@@ -111,9 +111,11 @@ end
111111
to_vec(x::Symmetric) = vec(Matrix(x)), x_vec->Symmetric(reshape(x_vec, size(x)))
112112
to_vec(X::Diagonal) = vec(Matrix(X)), x_vec->Diagonal(reshape(x_vec, size(X)...))
113113

114-
function to_vec(X::T) where T<:Union{Adjoint,Transpose}
115-
U = T.name.wrapper
116-
return vec(Matrix(X)), x_vec->U(permutedims(reshape(x_vec, size(X))))
114+
function to_vec(X::Transpose)
115+
return vec(Matrix(X)), x_vec->Transpose(permutedims(reshape(x_vec, size(X))))
116+
end
117+
function to_vec(X::Adjoint)
118+
return vec(Matrix(X)), x_vec->Adjoint(conj!(permutedims(reshape(x_vec, size(X)))))
117119
end
118120

119121
# Non-array data structures

test/grad.jl

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ Base.length(x::DummyType) = size(x.X, 1)
1616

1717
@testset "grad" begin
1818

19-
@testset "grad" begin
19+
@testset "grad(::$T)" for T in (Float64, ComplexF64)
2020
rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
21-
x = randn(rng, 2)
21+
x = randn(rng, T, 2)
2222
xc = copy(x)
2323
@test grad(fdm, x->sin(x[1]) + cos(x[2]), x) ≈ [cos(x[1]), -sin(x[2])]
2424
@test xc == x
@@ -29,13 +29,13 @@ Base.length(x::DummyType) = size(x.X, 1)
2929
@test jacobian(fdm, f, x, length(ȳ)) ≈ J_exact
3030
@test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
3131
@test _jvp(fdm, f, x, ẋ) ≈ J_exact * ẋ
32-
@test _j′vp(fdm, f, ȳ, x) ≈ J_exact' * ȳ
32+
@test _j′vp(fdm, f, ȳ, x) ≈ transpose(J_exact) * ȳ
3333
@test xc == x
3434
end
3535

36-
@testset "jacobian / _jvp / _j′vp" begin
36+
@testset "jacobian / _jvp / _j′vp (::$T)" for T in (Float64, ComplexF64)
3737
rng, P, Q, fdm = MersenneTwister(123456), 3, 2, central_fdm(5, 1)
38-
ȳ, A, x, ẋ = randn(rng, P), randn(rng, P, Q), randn(rng, Q), randn(rng, Q)
38+
ȳ, A, x, ẋ = randn(rng, T, P), randn(rng, T, P, Q), randn(rng, T, Q), randn(rng, T, Q)
3939
Ac = copy(A)
4040

4141
check_jac_and_jvp_and_j′vp(fdm, x->A * x, ȳ, x, ẋ, A)
@@ -51,45 +51,54 @@ Base.length(x::DummyType) = size(x.X, 1)
5151
return nothing
5252
end
5353

54-
@testset "to_vec" begin
55-
test_to_vec(1.0)
56-
test_to_vec(1)
57-
test_to_vec(randn(3))
58-
test_to_vec(randn(5, 11))
59-
test_to_vec(randn(13, 17, 19))
60-
test_to_vec(randn(13, 0, 19))
61-
test_to_vec([1.0, randn(2), randn(1), 2.0])
62-
test_to_vec([randn(5, 4, 3), (5, 4, 3), 2.0])
63-
test_to_vec(reshape([1.0, randn(5, 4, 3), randn(4, 3), 2.0], 2, 2))
64-
test_to_vec(UpperTriangular(randn(13, 13)))
65-
test_to_vec(Symmetric(randn(11, 11)))
66-
test_to_vec(Diagonal(randn(7)))
67-
test_to_vec(DummyType(randn(2, 9)))
68-
69-
@testset "$T" for T in (Adjoint, Transpose)
70-
test_to_vec(T(randn(4, 4)))
71-
test_to_vec(T(randn(6)))
72-
test_to_vec(T(randn(2, 5)))
54+
@testset "to_vec(::$T)" for T in (Float64, ComplexF64)
55+
if T == Float64
56+
test_to_vec(1.0)
57+
test_to_vec(1)
58+
else
59+
test_to_vec(.7 + .8im)
60+
test_to_vec(1 + 2im)
7361
end
74-
62+
test_to_vec(randn(T, 3))
63+
test_to_vec(randn(T, 5, 11))
64+
test_to_vec(randn(T, 13, 17, 19))
65+
test_to_vec(randn(T, 13, 0, 19))
66+
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0])
67+
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0])
68+
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2))
69+
test_to_vec(UpperTriangular(randn(T, 13, 13)))
70+
test_to_vec(Symmetric(randn(T, 11, 11)))
71+
test_to_vec(Diagonal(randn(T, 7)))
72+
test_to_vec(DummyType(randn(T, 2, 9)))
73+
74+
@testset "$Op" for Op in (Adjoint, Transpose)
75+
test_to_vec(Op(randn(T, 4, 4)))
76+
test_to_vec(Op(randn(T, 6)))
77+
test_to_vec(Op(randn(T, 2, 5)))
78+
end
79+
7580
@testset "Tuples" begin
7681
test_to_vec((5, 4))
77-
test_to_vec((5, randn(5)))
78-
test_to_vec((randn(4), randn(4, 3, 2), 1))
79-
test_to_vec((5, randn(4, 3, 2), UpperTriangular(randn(4, 4)), 2.5))
80-
test_to_vec(((6, 5), 3, randn(3, 2, 0, 1)))
81-
test_to_vec((DummyType(randn(2, 7)), DummyType(randn(3, 9))))
82-
test_to_vec((DummyType(randn(3, 2)), randn(11, 8)))
82+
test_to_vec((5, randn(T, 5)))
83+
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1))
84+
test_to_vec((5, randn(T, 4, 3, 2), UpperTriangular(randn(T, 4, 4)), 2.5))
85+
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)))
86+
test_to_vec((DummyType(randn(T, 2, 7)), DummyType(randn(T, 3, 9))))
87+
test_to_vec((DummyType(randn(T, 3, 2)), randn(T, 11, 8)))
8388
end
8489
@testset "Dictionary" begin
85-
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)))
90+
if T == Float64
91+
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)))
92+
else
93+
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)))
94+
end
8695
end
8796
end
8897

89-
@testset "jvp" begin
98+
@testset "jvp(::$T)" for T in (Float64, ComplexF64)
9099
rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1)
91-
x, y = randn(rng, N), randn(rng, M)
92-
ẋ, ẏ = randn(rng, N), randn(rng, M)
100+
x, y = randn(rng, T, N), randn(rng, T, M)
101+
ẋ, ẏ = randn(rng, T, N), randn(rng, T, M)
93102
xy, ẋẏ = vcat(x, y), vcat(ẋ, ẏ)
94103
ż_manual = _jvp(fdm, (xy)->sum(sin, xy), xy, ẋẏ)[1]
95104
ż_auto = jvp(fdm, x->sum(sin, x[1]) + sum(sin, x[2]), ((x, y), (ẋ, ẏ)))
@@ -98,10 +107,10 @@ Base.length(x::DummyType) = size(x.X, 1)
98107
@test ż_manual ≈ ż_multi
99108
end
100109

101-
@testset "j′vp" begin
110+
@testset "j′vp(::$T)" for T in (Float64, ComplexF64)
102111
rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1)
103-
x, y = randn(rng, N), randn(rng, M)
104-
z̄ = randn(rng, N + M)
112+
x, y = randn(rng, T, N), randn(rng, T, M)
113+
z̄ = randn(rng, T, N + M)
105114
xy = vcat(x, y)
106115
x̄ȳ_manual = j′vp(fdm, xy->sin.(xy), z̄, xy)
107116
x̄ȳ_auto = j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y))

0 commit comments

Comments
 (0)