Skip to content

Commit 93b6058

Browse files
Implement rand_tangent and difference (#91)
* Implement rand_tangent and differnece * Bump docs build version * Extra tests * Test on latest version * Remove old Manifest * Add tests / remove untested feature * Moves Foo to runtests * Modify testsets * Regenerate Manifest * Fix typo * Add TODO Co-authored-by: Nick Robinson <npr251@gmail.com> * Fix docs hopefully * Don't export difference or rand_tangent * Bump patch * Fix tests Co-authored-by: Nick Robinson <npr251@gmail.com>
1 parent 947382a commit 93b6058

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

src/rand_tangent.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
rand_tangent([rng::AbstractRNG,] x)
3+
4+
Returns a randomly generated tangent vector appropriate for the primal value `x`.
5+
"""
6+
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)
7+
8+
rand_tangent(rng::AbstractRNG, x::Symbol) = DoesNotExist()
9+
rand_tangent(rng::AbstractRNG, x::AbstractChar) = DoesNotExist()
10+
rand_tangent(rng::AbstractRNG, x::AbstractString) = DoesNotExist()
11+
12+
rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()
13+
14+
rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)
15+
16+
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
17+
18+
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
19+
return Composite{T}(rand_tangent.(Ref(rng), x)...)
20+
end
21+
22+
function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
23+
return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...)
24+
end
25+
26+
function rand_tangent(rng::AbstractRNG, x::T) where {T}
27+
if !isstructtype(T)
28+
throw(ArgumentError("Non-struct types are not supported by this fallback."))
29+
end
30+
31+
field_names = fieldnames(T)
32+
if length(field_names) > 0
33+
tangents = map(field_names) do field_name
34+
rand_tangent(rng, getfield(x, field_name))
35+
end
36+
return Composite{T}(; NamedTuple{field_names}(tangents)...)
37+
else
38+
return NO_FIELDS
39+
end
40+
end

test/rand_tangent.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using FiniteDifferences: rand_tangent
2+
3+
@testset "generate_tangent" begin
4+
rng = MersenneTwister(123456)
5+
6+
@testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [
7+
8+
# Things without sensible tangents.
9+
("hi", DoesNotExist),
10+
('a', DoesNotExist),
11+
(:a, DoesNotExist),
12+
(true, DoesNotExist),
13+
(4, DoesNotExist),
14+
15+
# Numbers.
16+
(5.0, Float64),
17+
(5.0 + 0.4im, Complex{Float64}),
18+
19+
# StridedArrays.
20+
(randn(Float32, 3), Vector{Float32}),
21+
(randn(Complex{Float64}, 2), Vector{Complex{Float64}}),
22+
(randn(5, 4), Matrix{Float64}),
23+
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
24+
([randn(5, 4), 4.0], Vector{Any}),
25+
26+
# Tuples.
27+
((4.0, ), Composite{Tuple{Float64}}),
28+
((5.0, randn(3)), Composite{Tuple{Float64, Vector{Float64}}}),
29+
30+
# NamedTuples.
31+
((a=4.0, ), Composite{NamedTuple{(:a,), Tuple{Float64}}}),
32+
((a=5.0, b=1), Composite{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
33+
34+
# structs.
35+
(sin, typeof(NO_FIELDS)),
36+
(Foo(5.0, 4, rand(rng, 3)), Composite{Foo}),
37+
(Foo(4.0, 3, Foo(5.0, 2, 4)), Composite{Foo}),
38+
39+
# LinearAlgebra types (also just structs).
40+
(
41+
UpperTriangular(randn(3, 3)),
42+
Composite{UpperTriangular{Float64, Matrix{Float64}}},
43+
),
44+
(
45+
Diagonal(randn(2)),
46+
Composite{Diagonal{Float64, Vector{Float64}}},
47+
),
48+
(
49+
SVector{2, Float64}(1.0, 2.0),
50+
Composite{typeof(SVector{2, Float64}(1.0, 2.0))},
51+
),
52+
(
53+
SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0),
54+
Composite{typeof(SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0))},
55+
),
56+
(
57+
Symmetric(randn(2, 2)),
58+
Composite{Symmetric{Float64, Matrix{Float64}}},
59+
),
60+
(
61+
Hermitian(randn(ComplexF64, 1, 1)),
62+
Composite{Hermitian{ComplexF64, Matrix{ComplexF64}}},
63+
),
64+
(
65+
Adjoint(randn(ComplexF64, 3, 3)),
66+
Composite{Adjoint{ComplexF64, Matrix{ComplexF64}}},
67+
),
68+
(
69+
Transpose(randn(3)),
70+
Composite{Transpose{Float64, Vector{Float64}}},
71+
),
72+
]
73+
@test rand_tangent(rng, x) isa T_tangent
74+
@test rand_tangent(x) isa T_tangent
75+
@test x + rand_tangent(rng, x) isa typeof(x)
76+
end
77+
78+
# Ensure struct fallback errors for non-struct types.
79+
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
80+
end

0 commit comments

Comments
 (0)