|
| 1 | +using FiniteDifferences: rand_tangent |
| 2 | + |
| 3 | +@testset "generate_tangent" begin |
| 4 | + rng = MersenneTwister(123456) |
| 5 | + |
| 6 | + @testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [ |
| 7 | + |
| 8 | + # Things without sensible tangents. |
| 9 | + ("hi", DoesNotExist), |
| 10 | + ('a', DoesNotExist), |
| 11 | + (:a, DoesNotExist), |
| 12 | + (true, DoesNotExist), |
| 13 | + (4, DoesNotExist), |
| 14 | + |
| 15 | + # Numbers. |
| 16 | + (5.0, Float64), |
| 17 | + (5.0 + 0.4im, Complex{Float64}), |
| 18 | + |
| 19 | + # StridedArrays. |
| 20 | + (randn(Float32, 3), Vector{Float32}), |
| 21 | + (randn(Complex{Float64}, 2), Vector{Complex{Float64}}), |
| 22 | + (randn(5, 4), Matrix{Float64}), |
| 23 | + (randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}), |
| 24 | + ([randn(5, 4), 4.0], Vector{Any}), |
| 25 | + |
| 26 | + # Tuples. |
| 27 | + ((4.0, ), Composite{Tuple{Float64}}), |
| 28 | + ((5.0, randn(3)), Composite{Tuple{Float64, Vector{Float64}}}), |
| 29 | + |
| 30 | + # NamedTuples. |
| 31 | + ((a=4.0, ), Composite{NamedTuple{(:a,), Tuple{Float64}}}), |
| 32 | + ((a=5.0, b=1), Composite{NamedTuple{(:a, :b), Tuple{Float64, Int}}}), |
| 33 | + |
| 34 | + # structs. |
| 35 | + (sin, typeof(NO_FIELDS)), |
| 36 | + (Foo(5.0, 4, rand(rng, 3)), Composite{Foo}), |
| 37 | + (Foo(4.0, 3, Foo(5.0, 2, 4)), Composite{Foo}), |
| 38 | + |
| 39 | + # LinearAlgebra types (also just structs). |
| 40 | + ( |
| 41 | + UpperTriangular(randn(3, 3)), |
| 42 | + Composite{UpperTriangular{Float64, Matrix{Float64}}}, |
| 43 | + ), |
| 44 | + ( |
| 45 | + Diagonal(randn(2)), |
| 46 | + Composite{Diagonal{Float64, Vector{Float64}}}, |
| 47 | + ), |
| 48 | + ( |
| 49 | + SVector{2, Float64}(1.0, 2.0), |
| 50 | + Composite{typeof(SVector{2, Float64}(1.0, 2.0))}, |
| 51 | + ), |
| 52 | + ( |
| 53 | + SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0), |
| 54 | + Composite{typeof(SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0))}, |
| 55 | + ), |
| 56 | + ( |
| 57 | + Symmetric(randn(2, 2)), |
| 58 | + Composite{Symmetric{Float64, Matrix{Float64}}}, |
| 59 | + ), |
| 60 | + ( |
| 61 | + Hermitian(randn(ComplexF64, 1, 1)), |
| 62 | + Composite{Hermitian{ComplexF64, Matrix{ComplexF64}}}, |
| 63 | + ), |
| 64 | + ( |
| 65 | + Adjoint(randn(ComplexF64, 3, 3)), |
| 66 | + Composite{Adjoint{ComplexF64, Matrix{ComplexF64}}}, |
| 67 | + ), |
| 68 | + ( |
| 69 | + Transpose(randn(3)), |
| 70 | + Composite{Transpose{Float64, Vector{Float64}}}, |
| 71 | + ), |
| 72 | + ] |
| 73 | + @test rand_tangent(rng, x) isa T_tangent |
| 74 | + @test rand_tangent(x) isa T_tangent |
| 75 | + @test x + rand_tangent(rng, x) isa typeof(x) |
| 76 | + end |
| 77 | + |
| 78 | + # Ensure struct fallback errors for non-struct types. |
| 79 | + @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0) |
| 80 | +end |
0 commit comments