Skip to content

Commit 66987ef

Browse files
author
Roger-luo
committed
update tests
1 parent 2420ea5 commit 66987ef

File tree

1 file changed

+46
-27
lines changed

1 file changed

+46
-27
lines changed

test/grad.jl

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -46,39 +46,58 @@ Base.length(x::DummyType) = size(x.X, 1)
4646

4747
@testset "multi vars jacobian/grad" begin
4848
rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
49+
4950
f1(x, y) = x * y + x
50-
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
51-
jac_xs = jacobian(fdm, f1, x, y)
52-
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
53-
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
51+
f2(x, y) = sum(x * y + x)
52+
f3(x::Tuple) = sum(x[1]) + x[2]
53+
f4(d::Dict) = sum(d[:x]) + d[:y]
5454

55-
# mixed scalar and matrices
56-
x, y = rand(3, 3), 2
57-
jac_xs = jacobian(fdm, f1, x, y)
58-
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
59-
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
55+
@testset "jacobian" begin
56+
@testset "check multiple matrices" begin
57+
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
58+
jac_xs = jacobian(fdm, f1, x, y)
59+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
60+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
61+
end
6062

61-
f2(x, y) = sum(x * y + x)
62-
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
63-
dxs = grad(fdm, f2, x, y)
64-
@test dxs[1] grad(fdm, x->f2(x, y), x)
65-
@test dxs[2] grad(fdm, y->f2(x, y), y)
63+
@testset "check mixed scalar and matrices" begin
64+
x, y = rand(3, 3), 2
65+
jac_xs = jacobian(fdm, f1, x, y)
66+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
67+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
68+
end
69+
end
6670

67-
x, y = rand(rng, 3, 3), 2
68-
dxs = grad(fdm, f2, x, y)
69-
@test dxs[1] grad(fdm, x->f2(x, y), x)
70-
@test dxs[2] grad(fdm, y->f2(x, y), y)
71+
@testset "grad" begin
72+
@testset "check multiple matrices" begin
73+
x, y = rand(rng, 3, 3), rand(rng, 3, 3)
74+
dxs = grad(fdm, f2, x, y)
75+
@test dxs[1] grad(fdm, x->f2(x, y), x)
76+
@test dxs[2] grad(fdm, y->f2(x, y), y)
77+
end
7178

72-
f3(x::Tuple) = sum(x[1]) + x[2]
73-
dxs = grad(fdm, f3, (x, y))
74-
@test dxs[1] grad(fdm, x->f3((x, y)), x)
75-
@test dxs[2] grad(fdm, y->f3((x, y)), y)
79+
@testset "check mixed scalar & matrices" begin
80+
x, y = rand(rng, 3, 3), 2
81+
dxs = grad(fdm, f2, x, y)
82+
@test dxs[1] grad(fdm, x->f2(x, y), x)
83+
@test dxs[2] grad(fdm, y->f2(x, y), y)
84+
end
7685

77-
f4(d::Dict) = sum(d[:x]) + d[:y]
78-
d = Dict(:x=>x, :y=>y)
79-
dxs = grad(fdm, f4, d)
80-
@test dxs[:x] grad(fdm, x->f3((x, y)), x)
81-
@test dxs[:y] grad(fdm, y->f3((x, y)), y)
86+
@testset "check tuple" begin
87+
x, y = rand(rng, 3, 3), 2
88+
dxs = grad(fdm, f3, (x, y))
89+
@test dxs[1] grad(fdm, x->f3((x, y)), x)
90+
@test dxs[2] grad(fdm, y->f3((x, y)), y)
91+
end
92+
93+
@testset "check dict" begin
94+
x, y = rand(rng, 3, 3), 2
95+
d = Dict(:x=>x, :y=>y)
96+
dxs = grad(fdm, f4, d)
97+
@test dxs[:x] grad(fdm, x->f3((x, y)), x)
98+
@test dxs[:y] grad(fdm, y->f3((x, y)), y)
99+
end
100+
end
82101
end
83102

84103
function test_to_vec(x)

0 commit comments

Comments
 (0)