Skip to content

Commit bc2cf6d

Browse files
committed
start working on spdiag rrule
1 parent 11c230c commit bc2cf6d

File tree

4 files changed

+73
-30
lines changed

4 files changed

+73
-30
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.49.0"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
89
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
910
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1011
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,33 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
4949

5050
return (I, V), findnz_pullback
5151
end
52+
53+
function _spdiagm_back(p, ȳ)
54+
k, v = p
55+
d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix
56+
return Tangent{typeof(p)}(second = d)
57+
end
58+
59+
function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)
60+
function diagm_pullback(Δ)
61+
_, ȳ = unthunk(Δ)
62+
return (NoTangent(), NoTangent(), NoTangent(), _spdiagm_back.(kv, Ref(ȳ))...)
63+
end
64+
return spdiagm(m, n, kv...), diagm_pullback
65+
end
66+
67+
function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
68+
function diagm_pullback(Δ)
69+
_, ȳ = unthunk(Δ)
70+
return (NoTangent(), _spdiagm_back.(kv, Ref(ȳ))...)
71+
end
72+
return spdiagm(kv...), diagm_pullback
73+
end
74+
75+
function rrule(::typeof(spdiagm), v::AbstractVector)
76+
function diagm_pullback(Δ)
77+
_, ȳ = unthunk(Δ)
78+
return (NoTangent(), diag(ȳ))
79+
end
80+
return spdiagm(v), diagm_pullback
81+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@ end
1818
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
1919
end
2020

21+
@testset "spdiagm" begin
22+
@test 1 == 1
23+
m = 5
24+
n = 4
25+
v1 = ones(m)
26+
v2 = ones(n)
27+
test_rrule(spdiagm, m, n, 0 => v2)
28+
29+
# test_rrule(spdiagm, 0 => v1)
30+
# test_rrule(spdiagm, v1)
31+
end
32+
2133
@testset "findnz" begin
2234
A = sprand(5, 5, 0.5)
2335
dA = similar(A)

test/runtests.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,41 +52,41 @@ end
5252

5353
test_method_tables() # Check the global method tables are consistent
5454

55-
# Each file puts all tests inside one or more @testset blocks
56-
include_test("rulesets/Base/base.jl")
57-
include_test("rulesets/Base/fastmath_able.jl")
58-
include_test("rulesets/Base/evalpoly.jl")
59-
include_test("rulesets/Base/array.jl")
60-
include_test("rulesets/Base/arraymath.jl")
61-
include_test("rulesets/Base/indexing.jl")
62-
include_test("rulesets/Base/mapreduce.jl")
63-
include_test("rulesets/Base/sort.jl")
64-
include_test("rulesets/Base/broadcast.jl")
65-
66-
include_test("unzipped.jl") # used primarily for broadcast
55+
# # Each file puts all tests inside one or more @testset blocks
56+
# include_test("rulesets/Base/base.jl")
57+
# include_test("rulesets/Base/fastmath_able.jl")
58+
# include_test("rulesets/Base/evalpoly.jl")
59+
# include_test("rulesets/Base/array.jl")
60+
# include_test("rulesets/Base/arraymath.jl")
61+
# include_test("rulesets/Base/indexing.jl")
62+
# include_test("rulesets/Base/mapreduce.jl")
63+
# include_test("rulesets/Base/sort.jl")
64+
# include_test("rulesets/Base/broadcast.jl")
65+
66+
# include_test("unzipped.jl") # used primarily for broadcast
67+
68+
# println()
69+
70+
# include_test("rulesets/Statistics/statistics.jl")
71+
72+
# println()
73+
74+
# include_test("rulesets/LinearAlgebra/dense.jl")
75+
# include_test("rulesets/LinearAlgebra/norm.jl")
76+
# include_test("rulesets/LinearAlgebra/matfun.jl")
77+
# include_test("rulesets/LinearAlgebra/structured.jl")
78+
# include_test("rulesets/LinearAlgebra/symmetric.jl")
79+
# include_test("rulesets/LinearAlgebra/factorization.jl")
80+
# include_test("rulesets/LinearAlgebra/blas.jl")
81+
# include_test("rulesets/LinearAlgebra/lapack.jl")
82+
# include_test("rulesets/LinearAlgebra/uniformscaling.jl")
6783

6884
println()
6985

70-
include_test("rulesets/Statistics/statistics.jl")
86+
include("rulesets/SparseArrays/sparsematrix.jl")
7187

7288
println()
7389

74-
include_test("rulesets/LinearAlgebra/dense.jl")
75-
include_test("rulesets/LinearAlgebra/norm.jl")
76-
include_test("rulesets/LinearAlgebra/matfun.jl")
77-
include_test("rulesets/LinearAlgebra/structured.jl")
78-
include_test("rulesets/LinearAlgebra/symmetric.jl")
79-
include_test("rulesets/LinearAlgebra/factorization.jl")
80-
include_test("rulesets/LinearAlgebra/blas.jl")
81-
include_test("rulesets/LinearAlgebra/lapack.jl")
82-
include_test("rulesets/LinearAlgebra/uniformscaling.jl")
83-
84-
println()
85-
86-
include_test("rulesets/SparseArrays/sparsematrix.jl")
87-
88-
println()
89-
90-
include_test("rulesets/Random/random.jl")
90+
# include_test("rulesets/Random/random.jl")
9191
println()
9292
end

0 commit comments

Comments
 (0)