Skip to content

Commit 492364c

Browse files
authored
Merge pull request #189 from JuliaDiff/ox/rand_tangent
Bring back rand_tangent
2 parents 814d2bb + fc3fb1e commit 492364c

File tree

5 files changed

+173
-4
lines changed

5 files changed

+173
-4
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,16 @@ using LinearAlgebra
99
using Random
1010
using Test
1111

12-
import FiniteDifferences: rand_tangent
13-
1412
export TestIterator
1513
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
16-
export
14+
export , rand_tangent
1715
export @maybe_inferred
1816

1917
__init__() = init_test_inferred_setting!()
2018

2119
include("global_config.jl")
2220

21+
include("rand_tangent.jl")
2322
include("generate_tangent.jl")
2423
include("data_generation.jl")
2524
include("iterator.jl")

src/rand_tangent.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
rand_tangent([rng::AbstractRNG,] x)
3+
4+
Returns a arbitary tangent vector _appropriate_ for the primal value `x`.
5+
Note that despite the name, no promises on the statistical randomness are made.
6+
Rather it is an arbitary value, that is generated using the `rng`.
7+
"""
8+
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)
9+
10+
rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
11+
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
12+
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()
13+
14+
rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()
15+
16+
# Try and make nice numbers with short decimal representations for good error messages
17+
# while also not biasing the sample space too much
18+
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Number}
19+
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
20+
return round(9 * randn(rng, T), sigdigits=5, base=2)
21+
end
22+
rand_tangent(rng::AbstractRNG, x::Float64) = rand(rng, -9:0.01:9)
23+
function rand_tangent(rng::AbstractRNG, x::ComplexF64)
24+
return ComplexF64(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9))
25+
end
26+
27+
#BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should
28+
29+
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
30+
rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), sigdigits=5, base=2)
31+
32+
rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(x[1]))
33+
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
34+
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
35+
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))
36+
37+
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
38+
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
39+
end
40+
41+
function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
42+
return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...)
43+
end
44+
45+
function rand_tangent(rng::AbstractRNG, x::T) where {T}
46+
if !isstructtype(T)
47+
throw(ArgumentError("Non-struct types are not supported by this fallback."))
48+
end
49+
50+
field_names = fieldnames(T)
51+
tangents = map(field_names) do field_name
52+
rand_tangent(rng, getfield(x, field_name))
53+
end
54+
if all(tangent isa NoTangent for tangent in tangents)
55+
# if none of my fields can be perturbed then I can't be perturbed
56+
return NoTangent()
57+
else
58+
Tangent{T}(; NamedTuple{field_names}(tangents)...)
59+
end
60+
end
61+
62+
rand_tangent(rng::AbstractRNG, ::Type) = NoTangent()
63+
rand_tangent(rng::AbstractRNG, ::Module) = NoTangent()

test/iterator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
@testset "rand_tangent" begin
8989
data = randn(2, 3, 4)
9090
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
91-
∂iter = FiniteDifferences.rand_tangent(iter)
91+
∂iter = rand_tangent(iter)
9292
@test ∂iter isa typeof(iter)
9393
@test size(∂iter.data) == size(iter.data)
9494
@test eltype(∂iter.data) === eltype(iter.data)

test/rand_tangent.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Test struct for `rand_tangent` and `difference`.
2+
struct Bar
3+
a::Float64
4+
b::Int
5+
c::Any
6+
end
7+
@testset "rand_tangent" begin
8+
rng = MersenneTwister(123456)
9+
10+
@testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [
11+
12+
# Things without sensible tangents.
13+
("hi", NoTangent),
14+
('a', NoTangent),
15+
(:a, NoTangent),
16+
(true, NoTangent),
17+
(4, NoTangent),
18+
(FiniteDifferences, NoTangent), # Module object
19+
# Types (not instances of type)
20+
(Bar, NoTangent),
21+
(Union{Int, Bar}, NoTangent),
22+
(Union{Int, Bar}, NoTangent),
23+
(Vector, NoTangent),
24+
(Vector{Float64}, NoTangent),
25+
(Integer, NoTangent),
26+
(Type{<:Real}, NoTangent),
27+
28+
# Numbers.
29+
(5.0, Float64),
30+
(5.0 + 0.4im, Complex{Float64}),
31+
(big(5.0), BigFloat),
32+
33+
# StridedArrays.
34+
(fill(randn(Float32)), Array{Float32, 0}),
35+
(fill(randn(Float64)), Array{Float64, 0}),
36+
(randn(Float32, 3), Vector{Float32}),
37+
(randn(Complex{Float64}, 2), Vector{Complex{Float64}}),
38+
(randn(5, 4), Matrix{Float64}),
39+
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
40+
([randn(5, 4), 4.0], Vector{Any}),
41+
42+
# Wrapper Arrays
43+
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
44+
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),
45+
46+
47+
# Tuples.
48+
((4.0, ), Tangent{Tuple{Float64}}),
49+
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
50+
51+
# NamedTuples.
52+
((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}),
53+
((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
54+
55+
# structs.
56+
(Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}),
57+
(Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}),
58+
(sin, NoTangent),
59+
# all fields NoTangent implies NoTangent
60+
(Pair(:a, "b"), NoTangent),
61+
(1:10, NoTangent),
62+
(1:2:10, NoTangent),
63+
64+
# LinearAlgebra types (also just structs).
65+
(
66+
UpperTriangular(randn(3, 3)),
67+
Tangent{UpperTriangular{Float64, Matrix{Float64}}},
68+
),
69+
(
70+
Diagonal(randn(2)),
71+
Tangent{Diagonal{Float64, Vector{Float64}}},
72+
),
73+
(
74+
Symmetric(randn(2, 2)),
75+
Tangent{Symmetric{Float64, Matrix{Float64}}},
76+
),
77+
(
78+
Hermitian(randn(ComplexF64, 1, 1)),
79+
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
80+
),
81+
]
82+
@test rand_tangent(rng, x) isa T_tangent
83+
@test rand_tangent(x) isa T_tangent
84+
end
85+
86+
@testset "erroring cases" begin
87+
# Ensure struct fallback errors for non-struct types.
88+
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
89+
end
90+
91+
@testset "compsition of addition" begin
92+
x = Bar(1.5, 2, Bar(1.1, 3, [1.7, 1.4, 0.9]))
93+
@test x + rand_tangent(x) isa typeof(x)
94+
@test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x)
95+
end
96+
97+
# Julia 1.6 changed to using Ryu printing algorithm and seems better at printing short
98+
VERSION > v"1.6" && @testset "niceness of printing" begin
99+
for i in 1:50
100+
@test length(string(rand_tangent(1.0))) <= 6
101+
@test length(string(rand_tangent(1.0 + 1.0im))) <= 12
102+
@test length(string(rand_tangent(1f0))) <= 12
103+
@test length(string(rand_tangent(big"1.0"))) <= 12
104+
end
105+
end
106+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ChainRulesTestUtils.TEST_INFERRED[] = true
1414
include("check_result.jl")
1515
include("testers.jl")
1616
include("data_generation.jl")
17+
include("rand_tangent.jl")
1718

1819
include("deprecated.jl")
1920
end

0 commit comments

Comments
 (0)