18
18
test_rrule (SparseVector{Float32}, Float32 .(v), rtol= 1e-4 )
19
19
end
20
20
21
+ # copied over from test/rulesets/LinearAlgebra/structured
21
22
@testset " spdiagm" begin
22
- @test 1 == 1
23
- m = 5
24
- n = 4
25
- v1 = ones (m)
26
- v2 = ones (n)
27
- test_rrule (spdiagm, m, n, 0 => v2)
28
-
29
- # test_rrule(spdiagm, 0 => v1)
30
- # test_rrule(spdiagm, v1)
23
+ @testset " without size" begin
24
+ M, N = 7 , 9
25
+ s = (8 , 8 )
26
+ a, ā = randn (M), randn (M)
27
+ b, b̄ = randn (M), randn (M)
28
+ c, c̄ = randn (M - 1 ), randn (M - 1 )
29
+ ȳ = randn (s)
30
+ ps = (0 => a, 1 => b, 0 => c)
31
+ y, back = rrule (spdiagm, ps... )
32
+ @test y == spdiagm (ps... )
33
+ ∂self, ∂pa, ∂pb, ∂pc = back (ȳ)
34
+ @test ∂self === NoTangent ()
35
+ ∂a_fd, ∂b_fd, ∂c_fd = j′vp (_fdm, (a, b, c) -> spdiagm (0 => a, 1 => b, 0 => c), ȳ, a, b, c)
36
+ for (p, ∂px, ∂x_fd) in zip (ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
37
+ ∂px = unthunk (∂px)
38
+ @test ∂px isa Tangent{typeof (p)}
39
+ @test ∂px. first isa AbstractZero
40
+ @test ∂px. second ≈ ∂x_fd
41
+ end
42
+ end
43
+ @testset " with size" begin
44
+ M, N = 7 , 9
45
+ a, ā = randn (M), randn (M)
46
+ b, b̄ = randn (M), randn (M)
47
+ c, c̄ = randn (M - 1 ), randn (M - 1 )
48
+ ȳ = randn (M, N)
49
+ ps = (0 => a, 1 => b, 0 => c)
50
+ y, back = rrule (spdiagm, M, N, ps... )
51
+ @test y == spdiagm (M, N, ps... )
52
+ ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back (ȳ)
53
+ @test ∂self === NoTangent ()
54
+ @test ∂M === NoTangent ()
55
+ @test ∂N === NoTangent ()
56
+ ∂a_fd, ∂b_fd, ∂c_fd = j′vp (_fdm, (a, b, c) -> spdiagm (M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)
57
+ for (p, ∂px, ∂x_fd) in zip (ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
58
+ ∂px = unthunk (∂px)
59
+ @test ∂px isa Tangent{typeof (p)}
60
+ @test ∂px. first isa AbstractZero
61
+ @test ∂px. second ≈ ∂x_fd
62
+ end
63
+ end
31
64
end
32
65
33
66
@testset " findnz" begin
54
87
test_rrule (logabsdet, A)
55
88
test_rrule (logdet, A)
56
89
test_rrule (det, A)
57
- end
90
+ end
0 commit comments