@@ -16,9 +16,9 @@ Base.length(x::DummyType) = size(x.X, 1)
16
16
17
17
@testset " grad" begin
18
18
19
- @testset " grad" begin
19
+ @testset " grad(:: $T ) " for T in (Float64, ComplexF64)
20
20
rng, fdm = MersenneTwister (123456 ), central_fdm (5 , 1 )
21
- x = randn (rng, 2 )
21
+ x = randn (rng, T, 2 )
22
22
xc = copy (x)
23
23
@test grad (fdm, x-> sin (x[1 ]) + cos (x[2 ]), x) ≈ [cos (x[1 ]), - sin (x[2 ])]
24
24
@test xc == x
@@ -29,13 +29,13 @@ Base.length(x::DummyType) = size(x.X, 1)
29
29
@test jacobian (fdm, f, x, length (ȳ)) ≈ J_exact
30
30
@test jacobian (fdm, f, x) == jacobian (fdm, f, x, length (ȳ))
31
31
@test _jvp (fdm, f, x, ẋ) ≈ J_exact * ẋ
32
- @test _j′vp (fdm, f, ȳ, x) ≈ J_exact' * ȳ
32
+ @test _j′vp (fdm, f, ȳ, x) ≈ transpose ( J_exact) * ȳ
33
33
@test xc == x
34
34
end
35
35
36
- @testset " jacobian / _jvp / _j′vp" begin
36
+ @testset " jacobian / _jvp / _j′vp (:: $T ) " for T in (Float64, ComplexF64)
37
37
rng, P, Q, fdm = MersenneTwister (123456 ), 3 , 2 , central_fdm (5 , 1 )
38
- ȳ, A, x, ẋ = randn (rng, P), randn (rng, P, Q), randn (rng, Q), randn (rng, Q)
38
+ ȳ, A, x, ẋ = randn (rng, T, P), randn (rng, T, P, Q), randn (rng, T, Q), randn (rng, T , Q)
39
39
Ac = copy (A)
40
40
41
41
check_jac_and_jvp_and_j′vp (fdm, x-> A * x, ȳ, x, ẋ, A)
@@ -51,45 +51,54 @@ Base.length(x::DummyType) = size(x.X, 1)
51
51
return nothing
52
52
end
53
53
54
- @testset " to_vec" begin
55
- test_to_vec (1.0 )
56
- test_to_vec (1 )
57
- test_to_vec (randn (3 ))
58
- test_to_vec (randn (5 , 11 ))
59
- test_to_vec (randn (13 , 17 , 19 ))
60
- test_to_vec (randn (13 , 0 , 19 ))
61
- test_to_vec ([1.0 , randn (2 ), randn (1 ), 2.0 ])
62
- test_to_vec ([randn (5 , 4 , 3 ), (5 , 4 , 3 ), 2.0 ])
63
- test_to_vec (reshape ([1.0 , randn (5 , 4 , 3 ), randn (4 , 3 ), 2.0 ], 2 , 2 ))
64
- test_to_vec (UpperTriangular (randn (13 , 13 )))
65
- test_to_vec (Symmetric (randn (11 , 11 )))
66
- test_to_vec (Diagonal (randn (7 )))
67
- test_to_vec (DummyType (randn (2 , 9 )))
68
-
69
- @testset " $T " for T in (Adjoint, Transpose)
70
- test_to_vec (T (randn (4 , 4 )))
71
- test_to_vec (T (randn (6 )))
72
- test_to_vec (T (randn (2 , 5 )))
54
+ @testset " to_vec(::$T )" for T in (Float64, ComplexF64)
55
+ if T == Float64
56
+ test_to_vec (1.0 )
57
+ test_to_vec (1 )
58
+ else
59
+ test_to_vec (.7 + .8im )
60
+ test_to_vec (1 + 2im )
73
61
end
74
-
62
+ test_to_vec (randn (T, 3 ))
63
+ test_to_vec (randn (T, 5 , 11 ))
64
+ test_to_vec (randn (T, 13 , 17 , 19 ))
65
+ test_to_vec (randn (T, 13 , 0 , 19 ))
66
+ test_to_vec ([1.0 , randn (T, 2 ), randn (T, 1 ), 2.0 ])
67
+ test_to_vec ([randn (T, 5 , 4 , 3 ), (5 , 4 , 3 ), 2.0 ])
68
+ test_to_vec (reshape ([1.0 , randn (T, 5 , 4 , 3 ), randn (T, 4 , 3 ), 2.0 ], 2 , 2 ))
69
+ test_to_vec (UpperTriangular (randn (T, 13 , 13 )))
70
+ test_to_vec (Symmetric (randn (T, 11 , 11 )))
71
+ test_to_vec (Diagonal (randn (T, 7 )))
72
+ test_to_vec (DummyType (randn (T, 2 , 9 )))
73
+
74
+ @testset " $Op " for Op in (Adjoint, Transpose)
75
+ test_to_vec (Op (randn (T, 4 , 4 )))
76
+ test_to_vec (Op (randn (T, 6 )))
77
+ test_to_vec (Op (randn (T, 2 , 5 )))
78
+ end
79
+
75
80
@testset " Tuples" begin
76
81
test_to_vec ((5 , 4 ))
77
- test_to_vec ((5 , randn (5 )))
78
- test_to_vec ((randn (4 ), randn (4 , 3 , 2 ), 1 ))
79
- test_to_vec ((5 , randn (4 , 3 , 2 ), UpperTriangular (randn (4 , 4 )), 2.5 ))
80
- test_to_vec (((6 , 5 ), 3 , randn (3 , 2 , 0 , 1 )))
81
- test_to_vec ((DummyType (randn (2 , 7 )), DummyType (randn (3 , 9 ))))
82
- test_to_vec ((DummyType (randn (3 , 2 )), randn (11 , 8 )))
82
+ test_to_vec ((5 , randn (T, 5 )))
83
+ test_to_vec ((randn (T, 4 ), randn (T, 4 , 3 , 2 ), 1 ))
84
+ test_to_vec ((5 , randn (T, 4 , 3 , 2 ), UpperTriangular (randn (T, 4 , 4 )), 2.5 ))
85
+ test_to_vec (((6 , 5 ), 3 , randn (T, 3 , 2 , 0 , 1 )))
86
+ test_to_vec ((DummyType (randn (T, 2 , 7 )), DummyType (randn (T, 3 , 9 ))))
87
+ test_to_vec ((DummyType (randn (T, 3 , 2 )), randn (T, 11 , 8 )))
83
88
end
84
89
@testset " Dictionary" begin
85
- test_to_vec (Dict (:a => 5 , :b => randn (10 , 11 ), :c => (5 , 4 , 3 )))
90
+ if T == Float64
91
+ test_to_vec (Dict (:a => 5 , :b => randn (10 , 11 ), :c => (5 , 4 , 3 )))
92
+ else
93
+ test_to_vec (Dict (:a => 3 + 2im , :b => randn (T, 10 , 11 ), :c => (5 + im, 2 - im, 1 + im)))
94
+ end
86
95
end
87
96
end
88
97
89
- @testset " jvp" begin
98
+ @testset " jvp(:: $T ) " for T in (Float64, ComplexF64)
90
99
rng, N, M, fdm = MersenneTwister (123456 ), 2 , 3 , central_fdm (5 , 1 )
91
- x, y = randn (rng, N), randn (rng, M)
92
- ẋ, ẏ = randn (rng, N), randn (rng, M)
100
+ x, y = randn (rng, T, N), randn (rng, T , M)
101
+ ẋ, ẏ = randn (rng, T, N), randn (rng, T , M)
93
102
xy, ẋẏ = vcat (x, y), vcat (ẋ, ẏ)
94
103
ż_manual = _jvp (fdm, (xy)-> sum (sin, xy), xy, ẋẏ)[1 ]
95
104
ż_auto = jvp (fdm, x-> sum (sin, x[1 ]) + sum (sin, x[2 ]), ((x, y), (ẋ, ẏ)))
@@ -98,10 +107,10 @@ Base.length(x::DummyType) = size(x.X, 1)
98
107
@test ż_manual ≈ ż_multi
99
108
end
100
109
101
- @testset " j′vp" begin
110
+ @testset " j′vp(:: $T ) " for T in (Float64, ComplexF64)
102
111
rng, N, M, fdm = MersenneTwister (123456 ), 2 , 3 , central_fdm (5 , 1 )
103
- x, y = randn (rng, N), randn (rng, M)
104
- z̄ = randn (rng, N + M)
112
+ x, y = randn (rng, T, N), randn (rng, T , M)
113
+ z̄ = randn (rng, T, N + M)
105
114
xy = vcat (x, y)
106
115
x̄ȳ_manual = j′vp (fdm, xy-> sin .(xy), z̄, xy)
107
116
x̄ȳ_auto = j′vp (fdm, x-> sin .(vcat (x[1 ], x[2 ])), z̄, (x, y))
0 commit comments