Skip to content

Commit fc3fb1e

Browse files
committed
Fix up tests that don't work because of move
1 parent 30f4c03 commit fc3fb1e

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Test
1111

1212
export TestIterator
1313
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
14-
export
14+
export , rand_tangent
1515
export @maybe_inferred
1616

1717
__init__() = init_test_inferred_setting!()

test/iterator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
@testset "rand_tangent" begin
8989
data = randn(2, 3, 4)
9090
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
91-
∂iter = FiniteDifferences.rand_tangent(iter)
91+
∂iter = rand_tangent(iter)
9292
@test ∂iter isa typeof(iter)
9393
@test size(∂iter.data) == size(iter.data)
9494
@test eltype(∂iter.data) === eltype(iter.data)

test/rand_tangent.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
using FiniteDifferences: rand_tangent
2-
3-
@testset "generate_tangent" begin
1+
# Test struct for `rand_tangent` and `difference`.
2+
struct Bar
3+
a::Float64
4+
b::Int
5+
c::Any
6+
end
7+
@testset "rand_tangent" begin
48
rng = MersenneTwister(123456)
59

610
@testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [
@@ -13,9 +17,9 @@ using FiniteDifferences: rand_tangent
1317
(4, NoTangent),
1418
(FiniteDifferences, NoTangent), # Module object
1519
# Types (not instances of type)
16-
(Foo, NoTangent),
17-
(Union{Int, Foo}, NoTangent),
18-
(Union{Int, Foo}, NoTangent),
20+
(Bar, NoTangent),
21+
(Union{Int, Bar}, NoTangent),
22+
(Union{Int, Bar}, NoTangent),
1923
(Vector, NoTangent),
2024
(Vector{Float64}, NoTangent),
2125
(Integer, NoTangent),
@@ -49,8 +53,8 @@ using FiniteDifferences: rand_tangent
4953
((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
5054

5155
# structs.
52-
(Foo(5.0, 4, rand(rng, 3)), Tangent{Foo}),
53-
(Foo(4.0, 3, Foo(5.0, 2, 4)), Tangent{Foo}),
56+
(Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}),
57+
(Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}),
5458
(sin, NoTangent),
5559
# all fields NoTangent implies NoTangent
5660
(Pair(:a, "b"), NoTangent),
@@ -66,14 +70,6 @@ using FiniteDifferences: rand_tangent
6670
Diagonal(randn(2)),
6771
Tangent{Diagonal{Float64, Vector{Float64}}},
6872
),
69-
(
70-
SVector{2, Float64}(1.0, 2.0),
71-
Tangent{typeof(SVector{2, Float64}(1.0, 2.0))},
72-
),
73-
(
74-
SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0),
75-
Tangent{typeof(SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0))},
76-
),
7773
(
7874
Symmetric(randn(2, 2)),
7975
Tangent{Symmetric{Float64, Matrix{Float64}}},
@@ -93,7 +89,7 @@ using FiniteDifferences: rand_tangent
9389
end
9490

9591
@testset "compsition of addition" begin
96-
x = Foo(1.5, 2, Foo(1.1, 3, [1.7, 1.4, 0.9]))
92+
x = Bar(1.5, 2, Bar(1.1, 3, [1.7, 1.4, 0.9]))
9793
@test x + rand_tangent(x) isa typeof(x)
9894
@test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x)
9995
end

0 commit comments

Comments
 (0)