@@ -20,14 +20,14 @@ Base.length(x::DummyType) = size(x.X, 1)
20
20
rng, fdm = MersenneTwister (123456 ), central_fdm (5 , 1 )
21
21
x = randn (rng, T, 2 )
22
22
xc = copy (x)
23
- @test grad (fdm, x-> sin (x[1 ]) + cos (x[2 ]), x) ≈ [cos (x[1 ]), - sin (x[2 ])]
23
+ @test grad (fdm, x-> sin (x[1 ]) + cos (x[2 ]), x)[ 1 ] ≈ [cos (x[1 ]), - sin (x[2 ])]
24
24
@test xc == x
25
25
end
26
26
27
27
function check_jac_and_jvp_and_j′vp (fdm, f, ȳ, x, ẋ, J_exact)
28
28
xc = copy (x)
29
- @test jacobian (fdm, f, x; len= length (ȳ)) ≈ J_exact
30
- @test jacobian (fdm, f, x) == jacobian (fdm, f, x; len= length (ȳ))
29
+ @test jacobian (fdm, f, x; len= length (ȳ))[ 1 ] ≈ J_exact
30
+ @test jacobian (fdm, f, x)[ 1 ] == jacobian (fdm, f, x; len= length (ȳ))[ 1 ]
31
31
@test _jvp (fdm, f, x, ẋ) ≈ J_exact * ẋ
32
32
@test _j′vp (fdm, f, ȳ, x) ≈ transpose (J_exact) * ȳ
33
33
@test xc == x
@@ -56,46 +56,46 @@ Base.length(x::DummyType) = size(x.X, 1)
56
56
@testset " check multiple matrices" begin
57
57
x, y = rand (rng, 3 , 3 ), rand (rng, 3 , 3 )
58
58
jac_xs = jacobian (fdm, f1, x, y)
59
- @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)
60
- @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)
59
+ @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)[ 1 ]
60
+ @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)[ 1 ]
61
61
end
62
62
63
63
@testset " check mixed scalar and matrices" begin
64
64
x, y = rand (3 , 3 ), 2
65
65
jac_xs = jacobian (fdm, f1, x, y)
66
- @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)
67
- @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)
66
+ @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)[ 1 ]
67
+ @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)[ 1 ]
68
68
end
69
69
end
70
70
71
71
@testset " grad" begin
72
72
@testset " check multiple matrices" begin
73
73
x, y = rand (rng, 3 , 3 ), rand (rng, 3 , 3 )
74
74
dxs = grad (fdm, f2, x, y)
75
- @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)
76
- @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)
75
+ @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)[ 1 ]
76
+ @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)[ 1 ]
77
77
end
78
78
79
79
@testset " check mixed scalar & matrices" begin
80
80
x, y = rand (rng, 3 , 3 ), 2
81
81
dxs = grad (fdm, f2, x, y)
82
- @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)
83
- @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)
82
+ @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)[ 1 ]
83
+ @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)[ 1 ]
84
84
end
85
85
86
86
@testset " check tuple" begin
87
87
x, y = rand (rng, 3 , 3 ), 2
88
- dxs = grad (fdm, f3, (x, y))
89
- @test dxs[1 ] ≈ grad (fdm, x-> f3 ((x, y)), x)
90
- @test dxs[2 ] ≈ grad (fdm, y-> f3 ((x, y)), y)
88
+ dxs = grad (fdm, f3, (x, y))[ 1 ]
89
+ @test dxs[1 ] ≈ grad (fdm, x-> f3 ((x, y)), x)[ 1 ]
90
+ @test dxs[2 ] ≈ grad (fdm, y-> f3 ((x, y)), y)[ 1 ]
91
91
end
92
92
93
93
@testset " check dict" begin
94
94
x, y = rand (rng, 3 , 3 ), 2
95
95
d = Dict (:x => x, :y => y)
96
- dxs = grad (fdm, f4, d)
97
- @test dxs[:x ] ≈ grad (fdm, x-> f3 ((x, y)), x)
98
- @test dxs[:y ] ≈ grad (fdm, y-> f3 ((x, y)), y)
96
+ dxs = grad (fdm, f4, d)[ 1 ]
97
+ @test dxs[:x ] ≈ grad (fdm, x-> f3 ((x, y)), x)[ 1 ]
98
+ @test dxs[:y ] ≈ grad (fdm, y-> f3 ((x, y)), y)[ 1 ]
99
99
end
100
100
end
101
101
end
@@ -168,8 +168,8 @@ Base.length(x::DummyType) = size(x.X, 1)
168
168
x, y = randn (rng, T, N), randn (rng, T, M)
169
169
z̄ = randn (rng, T, N + M)
170
170
xy = vcat (x, y)
171
- x̄ȳ_manual = j′vp (fdm, xy-> sin .(xy), z̄, xy)
172
- x̄ȳ_auto = j′vp (fdm, x-> sin .(vcat (x[1 ], x[2 ])), z̄, (x, y))
171
+ x̄ȳ_manual = j′vp (fdm, xy-> sin .(xy), z̄, xy)[ 1 ]
172
+ x̄ȳ_auto = j′vp (fdm, x-> sin .(vcat (x[1 ], x[2 ])), z̄, (x, y))[ 1 ]
173
173
x̄ȳ_multi = j′vp (fdm, (x, y)-> sin .(vcat (x, y)), z̄, x, y)
174
174
@test x̄ȳ_manual ≈ vcat (x̄ȳ_auto... )
175
175
@test x̄ȳ_manual ≈ vcat (x̄ȳ_multi... )
0 commit comments