Skip to content

Commit b1daa7a

Browse files
authored
Merge pull request #590 from CarloLucibello/cl/findnz2
rrule for findnz
2 parents aab91c6 + 5d5741c commit b1daa7a

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.26.1"
3+
version = "1.27.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,42 @@ function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::Abstra
1010
return sparse(I, J, V, m, n, combine), sparse_pullback
1111
end
1212

13-
function rrule(::Type{T}, A::AbstractMatrix) where T <: SparseMatrixCSC
13+
function rrule(::Type{T}, A::AbstractMatrix) where T <: AbstractSparseMatrix
1414
function sparse_pullback(Ω̄)
1515
return NoTangent(), Ω̄
1616
end
1717
return T(A), sparse_pullback
1818
end
1919

20-
function rrule(::Type{T}, v::AbstractVector) where T <: SparseVector
20+
function rrule(::Type{T}, v::AbstractVector) where T <: AbstractSparseVector
2121
function sparse_pullback(Ω̄)
2222
return NoTangent(), Ω̄
2323
end
2424
return T(v), sparse_pullback
2525
end
26+
27+
function rrule(::typeof(findnz), A::AbstractSparseMatrix)
28+
I, J, V = findnz(A)
29+
m, n = size(A)
30+
31+
function findnz_pullback(Δ)
32+
_, _, V̄ = unthunk(Δ)
33+
isa AbstractZero && return (NoTangent(), V̄)
34+
return NoTangent(), sparse(I, J, V̄, m, n)
35+
end
36+
37+
return (I, J, V), findnz_pullback
38+
end
39+
40+
function rrule(::typeof(findnz), v::AbstractSparseVector)
41+
I, V = findnz(v)
42+
n = length(v)
43+
44+
function findnz_pullback(Δ)
45+
_, V̄ = unthunk(Δ)
46+
isa AbstractZero && return (NoTangent(), V̄)
47+
return NoTangent(), sparsevec(I, V̄, n)
48+
end
49+
50+
return (I, V), findnz_pullback
51+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,27 @@ end
99
@testset "SparseMatrixCSC(A)" begin
1010
A = rand(5, 3)
1111
test_rrule(SparseMatrixCSC, A)
12-
test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-5)
12+
test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-4)
1313
end
1414

1515
@testset "SparseVector(v)" begin
1616
v = rand(5)
1717
test_rrule(SparseVector, v)
18-
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-5)
18+
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
19+
end
20+
21+
@testset "findnz" begin
22+
A = sprand(5, 5, 0.5)
23+
dA = similar(A)
24+
rand!(dA.nzval)
25+
I, J, V = findnz(A)
26+
= rand!(similar(V))
27+
test_rrule(findnz, A dA, output_tangent=(zeros(length(I)), zeros(length(J)), V̄))
28+
29+
v = sprand(5, 0.5)
30+
dv = similar(v)
31+
rand!(dv.nzval)
32+
I, V = findnz(v)
33+
= rand!(similar(V))
34+
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
1935
end

0 commit comments

Comments
 (0)