Skip to content

Commit f0dd5c9

Browse files
author
Will Kimmerer
authored
mul/ewise rules for basic arithmetic semiring (#26)
* arithmetic groundwork * arithmetic rules for mul and elwise 1st pass * tests and a few fixes * Add mask function, fix eadd(PLUS) * correct mul rrules * test folder structure * mask and vector transpose v1 * Broken constructor rules * arithmetic groundwork * arithmetic rules for mul and elwise 1st pass * tests and a few fixes * Add mask function, fix eadd(PLUS) * correct mul rrules * test folder structure * Broken constructor rules * Move out constructor rules for now * compat * rm constructorrule includes
1 parent 952e7a0 commit f0dd5c9

18 files changed

+305
-28
lines changed

Project.toml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ version = "0.4.0"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
810
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
11+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
912
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1013
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1114
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -15,8 +18,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1518
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1619

1720
[compat]
18-
SSGraphBLAS_jll = "5.1.2"
21+
CEnum = "0.4"
22+
ContextVariablesX = "0.1"
23+
MacroTools = "0.5"
24+
SSGraphBLAS_jll = "5.1"
1925
julia = "1.6"
20-
CEnum = "0.4.1"
21-
ContextVariablesX = "0.1.1"
22-
MacroTools = "0.5.6"
26+
ChainRulesCore = "0.10"
27+
ChainRulesTestUtils = "0.7"
28+
FiniteDifferences = "0.12"

src/SuiteSparseGraphBLAS.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,14 @@ include("operations/kronecker.jl")
8787
include("print.jl")
8888
include("import.jl")
8989
include("export.jl")
90-
91-
#EXPERIMENTAL
9290
include("options.jl")
91+
#EXPERIMENTAL
92+
include("chainrules/chainruleutils.jl")
93+
include("chainrules/mulrules.jl")
94+
include("chainrules/ewiserules.jl")
95+
include("chainrules/maprules.jl")
96+
include("chainrules/reducerules.jl")
97+
include("chainrules/selectrules.jl")
9398
#include("random.jl")
9499
include("misc.jl")
95100
export libgb

src/chainrules/chainruleutils.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import FiniteDifferences
2+
import LinearAlgebra
3+
import ChainRulesCore: frule, rrule
4+
using ChainRulesCore
5+
const RealOrComplex = Union{Real, Complex}
6+
7+
#Required for ChainRulesTestUtils
8+
# Flatten a GBMatrix for FiniteDifferences.
# Returns the values vector from `findnz(M)` together with a closure that rebuilds
# a GBMatrix on the same (row, column) indices and with the same dimensions from
# a replacement value vector.
function FiniteDifferences.to_vec(M::GBMatrix)
    rows, cols, vals = findnz(M)
    rebuild = xvec -> GBMatrix(rows, cols, xvec; nrows = size(M, 1), ncols = size(M, 2))
    return vals, rebuild
end
15+
16+
# Flatten a GBVector for FiniteDifferences.
# Returns the values vector from `findnz(v)` together with a closure that rebuilds
# a GBVector on the same indices and with the same length from a replacement
# value vector.
function FiniteDifferences.to_vec(v::GBVector)
    idx, vals = findnz(v)
    rebuild = xvec -> GBVector(idx, xvec; nrows = size(v, 1))
    return vals, rebuild
end
23+
24+
# Random tangent for a GBMatrix with floating-point or complex elements.
# The result has the same index pattern and dimensions as `x`; its `nnz(x)`
# values are sampled from the grid -9:0.01:9 using `rng`.
function FiniteDifferences.rand_tangent(
    rng::AbstractRNG, x::GBMatrix{T}
) where {T <: Union{AbstractFloat, Complex}}
    vals = rand(rng, -9:0.01:9, nnz(x))
    rows, cols, _ = findnz(x)
    return GBMatrix(rows, cols, vals; nrows = size(x, 1), ncols = size(x, 2))
end
33+
34+
# Random tangent for a GBVector with floating-point or complex elements.
# The result has the same index pattern and length as `x`; its `nnz(x)` values
# are sampled from the grid -9:0.01:9 using `rng`.
function FiniteDifferences.rand_tangent(
    rng::AbstractRNG, x::GBVector{T}
) where {T <: Union{AbstractFloat, Complex}}
    vals = rand(rng, -9:0.01:9, nnz(x))
    idx, _ = findnz(x)
    return GBVector(idx, vals; nrows = size(x, 1))
end
43+
44+
# Operators always receive a NoTangent() tangent; `rng` is unused.
function FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp)
    return NoTangent()
end
45+
# The generic LinearAlgebra.norm cannot handle the `nothing`s a GBArray yields,
# so compute the p-norm (default p = 2) over `nonzeros(A)` instead.
# The inner call is qualified because this file only does `import LinearAlgebra`,
# which does not bring the bare name `norm` into scope by itself.
LinearAlgebra.norm(A::GBArray, p::Real=2) = LinearAlgebra.norm(nonzeros(A), p)

src/chainrules/ewiserules.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#emul TIMES
2+
# Forward-mode rule for `emul` with the TIMES operator.
# Product rule: the output tangent is emul(ΔA, B) + emul(ΔB, A).
# The tangents for `emul` itself and for the operator are ignored.
function frule(
    (_, ΔA, ΔB, _), ::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES)
)
    times = BinaryOps.TIMES
    primal = emul(A, B, times)
    tangent = emul(ΔA, B, times) + emul(ΔB, A, times)
    return primal, tangent
end
13+
# Forward-mode rule for `emul` called without an operator: forwards to the
# explicit TIMES rule, padding the tangent tuple with placeholders.
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
    tangents = (nothing, ΔA, ΔB, nothing)
    return frule(tangents, emul, A, B, BinaryOps.TIMES)
end
16+
17+
# Reverse-mode rule for `emul` with the TIMES operator.
# Returns the primal emul(A, B, TIMES) and a pullback producing four tangents:
# (emul itself, A, B, operator), where ∂A = emul(ΔΩ, B) and ∂B = emul(ΔΩ, A).
# NOTE(review): no conjugation is applied, so for complex element types this
# may not be the conventional cotangent — confirm.
function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
    function timespullback(ΔΩ)
        # AD systems may hand the cotangent over as a thunk; materialize it once
        # so it can be passed to `emul`.
        Ω̄ = ChainRulesCore.unthunk(ΔΩ)
        ∂A = emul(Ω̄, B)
        ∂B = emul(Ω̄, A)
        return NoTangent(), ∂A, ∂B, NoTangent()
    end
    return emul(A, B, BinaryOps.TIMES), timespullback
end
25+
26+
# Reverse-mode rule for `emul` called without an operator.
# Reuses the explicit TIMES rule and drops the trailing operator tangent so the
# pullback returns three tangents (emul itself, A, B).
function rrule(::typeof(emul), A::GBArray, B::GBArray)
    Ω, timespb = rrule(emul, A, B, BinaryOps.TIMES)
    function emulpb(ΔΩ)
        ∂self, ∂A, ∂B, _ = timespb(ΔΩ)
        return (∂self, ∂A, ∂B)
    end
    return Ω, emulpb
end
31+
32+
############
33+
# eadd rules
34+
############
35+
36+
# PLUS
37+
######
38+
39+
# Forward-mode rule for `eadd` with the PLUS operator.
# Addition is linear, so the output tangent is eadd(ΔA, ΔB, PLUS).
# The tangents for `eadd` itself and for the operator are ignored.
function frule(
    (_, ΔA, ΔB, _), ::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)
)
    plus = BinaryOps.PLUS
    primal = eadd(A, B, plus)
    tangent = eadd(ΔA, ΔB, plus)
    return primal, tangent
end
50+
# Forward-mode rule for `eadd` called without an operator: forwards to the
# explicit PLUS rule, padding the tangent tuple with placeholders.
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
    tangents = (nothing, ΔA, ΔB, nothing)
    return frule(tangents, eadd, A, B, BinaryOps.PLUS)
end
53+
54+
# Reverse-mode rule for `eadd` with the PLUS operator.
# Returns the primal eadd(A, B, PLUS) and a pullback producing four tangents:
# (eadd itself, A, B, operator). The tangents for A and B are the output
# cotangent structurally masked by A and by B respectively.
function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
    function pluspullback(ΔΩ)
        # AD systems may hand the cotangent over as a thunk; materialize it once
        # so it can be passed to `mask`.
        Ω̄ = ChainRulesCore.unthunk(ΔΩ)
        return (
            NoTangent(),
            mask(Ω̄, A; structural = true),
            mask(Ω̄, B; structural = true),
            NoTangent()
        )
    end
    return eadd(A, B, BinaryOps.PLUS), pluspullback
end
65+
66+
# Reverse-mode rule for `eadd` called without an operator.
# Calling the explicit PLUS rule directly would yield four tangents where three
# are expected for this signature, so its pullback is wrapped and the trailing
# operator tangent is dropped.
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
    Ω, pluspb = rrule(eadd, A, B, BinaryOps.PLUS)
    function eaddpb(ΔΩ)
        ∂self, ∂A, ∂B, _ = pluspb(ΔΩ)
        return (∂self, ∂A, ∂B)
    end
    return Ω, eaddpb
end

src/chainrules/maprules.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
2+
# than AbstractOp.
3+
#function rrule(map, f, xs)
4+
# # Rather than 3 maps really want 1 multimap
5+
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
6+
# ys = map(first, ys_and_pullbacks)
7+
# pullbacks = map(last, ys_and_pullbacks)
8+
# function map_pullback(dys)
9+
# _call(f, x) = f(x)
10+
# dfs_and_dxs = map(_call, pullbacks, dys)
11+
# # but in your case you know it will be NoTangent() so can skip
12+
# df = sum(first, dfs_and_dxs)
13+
# dxs = map(last, dfs_and_dxs)
14+
# return NoTangent(), df, dxs
15+
# end
16+
# return ys, map_pullback
17+
#end

src/chainrules/mulrules.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Standard arithmetic mul:
2+
# Forward-mode rule for `mul` called without a semiring: forwards to the
# explicit PLUS_TIMES rule, padding the tangent tuple with placeholders.
function frule((_, ΔA, ΔB), ::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose)
    tangents = (nothing, ΔA, ΔB, nothing)
    return frule(tangents, mul, A, B, Semirings.PLUS_TIMES)
end
10+
# Forward-mode rule for `mul` with the PLUS_TIMES semiring.
# Product rule: the output tangent is mul(ΔA, B) + mul(A, ΔB).
# The tangents for `mul` itself and for the semiring are ignored.
function frule(
    (_, ΔA, ΔB, _),
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES),
)
    primal = mul(A, B)
    left = mul(ΔA, B)
    right = mul(A, ΔB)
    return primal, left + right
end
21+
# Tests will not pass for this. For two reasons.
22+
# First is #25, the output inference is not type stable.
23+
# That's it's own issue.
24+
25+
# Second, to_vec currently works by mapping materialized values back and forth, i.e. it knows nothing about nothings.
26+
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.
27+
28+
# Reverse-mode rule for `mul` with the PLUS_TIMES semiring.
# Returns the primal mul(A, B) and a pullback producing four tangents:
# (mul itself, A, B, semiring), where
#   ∂A = mul(ΔΩ, B'; mask = A)   and   ∂B = mul(A', ΔΩ; mask = B).
# Each tangent is computed with its own primal argument passed as the mask.
function rrule(
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES)
)
    function mulpullback(ΔΩ)
        # AD systems may hand the cotangent over as a thunk; materialize it once
        # so it can be passed to `mul`.
        Ω̄ = ChainRulesCore.unthunk(ΔΩ)
        ∂A = mul(Ω̄, B'; mask=A)
        ∂B = mul(A', Ω̄; mask=B)
        return NoTangent(), ∂A, ∂B, NoTangent()
    end
    return mul(A, B), mulpullback
end
41+
42+
43+
# Reverse-mode rule for `mul` called without a semiring.
# Reuses the explicit PLUS_TIMES rule and drops the trailing semiring tangent so
# the pullback returns three tangents (mul itself, A, B).
function rrule(::typeof(mul), A::GBMatOrTranspose, B::GBMatOrTranspose)
    Ω, fullpb = rrule(mul, A, B, Semirings.PLUS_TIMES)
    function pullback(ΔΩ)
        ∂self, ∂A, ∂B, _ = fullpb(ΔΩ)
        return (∂self, ∂A, ∂B)
    end
    return Ω, pullback
end

src/chainrules/reducerules.jl

Whitespace-only changes.

src/chainrules/selectrules.jl

Whitespace-only changes.

src/lib/LibGraphBLAS.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,27 @@ macro wraperror(code)
2727
elseif info == GrB_NO_VALUE
2828
return nothing
2929
else
30-
if info == GrB_UNINITIALIZED_OBJECT
30+
if info == GrB_UNINITIALIZED_OBJECT
3131
throw(UninitializedObjectError)
32-
elseif info == GrB_INVALID_OBJECT
32+
elseif info == GrB_INVALID_OBJECT
3333
throw(InvalidObjectError)
34-
elseif info == GrB_NULL_POINTER
34+
elseif info == GrB_NULL_POINTER
3535
throw(NullPointerError)
36-
elseif info == GrB_INVALID_VALUE
36+
elseif info == GrB_INVALID_VALUE
3737
throw(InvalidValueError)
38-
elseif info == GrB_INVALID_INDEX
38+
elseif info == GrB_INVALID_INDEX
3939
throw(InvalidIndexError)
40-
elseif info == GrB_DOMAIN_MISMATCH
40+
elseif info == GrB_DOMAIN_MISMATCH
4141
throw(DomainError(nothing, "GraphBLAS Domain Mismatch"))
4242
elseif info == GrB_DIMENSION_MISMATCH
4343
throw(DimensionMismatch())
44-
elseif info == GrB_OUTPUT_NOT_EMPTY
44+
elseif info == GrB_OUTPUT_NOT_EMPTY
4545
throw(OutputNotEmptyError)
46-
elseif info == GrB_OUT_OF_MEMORY
46+
elseif info == GrB_OUT_OF_MEMORY
4747
throw(OutOfMemoryError())
48-
elseif info == GrB_INSUFFICIENT_SPACE
48+
elseif info == GrB_INSUFFICIENT_SPACE
4949
throw(InsufficientSpaceError)
50-
elseif info == GrB_INDEX_OUT_OF_BOUNDS
50+
elseif info == GrB_INDEX_OUT_OF_BOUNDS
5151
throw(BoundsError())
5252
elseif info == GrB_PANIC
5353
throw(PANIC)
@@ -843,7 +843,7 @@ for T ∈ valid_vec
843843
nvals = GrB_Vector_nvals(v)
844844
I = Vector{GrB_Index}(undef, nvals)
845845
X = Vector{$type}(undef, nvals)
846-
nvals = Ref{GrB_Index}()
846+
nvals = Ref{GrB_Index}(nvals)
847847
$func(I, X, nvals, v)
848848
nvals[] == length(I) == length(X) || throw(DimensionMismatch())
849849
return I .+ 1, X

src/matrix.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = X[k]. The dup funct
1818
to `|` for booleans and `+` for nonbooleans.
1919
"""
2020
function GBMatrix(
21-
I::Vector, J::Vector, X::Vector{T};
21+
I::AbstractVector, J::AbstractVector, X::AbstractVector{T};
2222
dup = BinaryOps.PLUS, nrows = maximum(I), ncols = maximum(J)
2323
) where {T}
2424
A = GBMatrix{T}(nrows, ncols)
@@ -33,14 +33,14 @@ Create an nrows x ncols GBMatrix M such that M[I[k], J[k]] = x.
3333
The resulting matrix is "iso-valued" such that it only stores `x` once rather than once for
3434
each index.
3535
"""
36-
function GBMatrix(I::Vector, J::Vector, x::T;
36+
function GBMatrix(
    I::AbstractVector, J::AbstractVector, x::T; nrows = maximum(I), ncols = maximum(J)
) where {T}
    # Allocate an empty nrows x ncols matrix, then fill every (I[k], J[k])
    # position with the single scalar `x`.
    M = GBMatrix{T}(nrows, ncols)
    build(M, I, J, x)
    return M
end
4242

43-
function build(A::GBMatrix{T}, I::Vector, J::Vector, x::T) where {T}
43+
function build(A::GBMatrix{T}, I::AbstractVector, J::AbstractVector, x::T) where {T}
4444
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build matrix with existing elements"))
4545
length(I) == length(J) || DimensionMismatch("I, J and X must have the same length")
4646
x = GBScalar(x)
@@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix)
158158
gxbprint(io, A)
159159
end
160160

161-
SparseArrays.nonzeros(A::GBArray) = findnz(A)[3]
161+
# Values of `A`: the last component of the tuple returned by `findnz`, whatever
# the number of index vectors that precede it.
function SparseArrays.nonzeros(A::GBArray)
    return last(findnz(A))
end
162+
162163

163164
# Indexing functions
164165
####################

src/operations/ewise.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ function emul!(
6161
desc = nothing
6262
)
6363
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
64-
6564
size(w) == size(u) == size(v) || throw(DimensionMismatch())
6665
op = getoperator(op, optype(u, v))
6766
accum = getoperator(accum, eltype(w))
@@ -275,6 +274,13 @@ function eadd(
275274
return eadd!(C, A, B, op; mask, accum, desc)
276275
end
277276

277+
# Elementwise addition of two GBArrays via `eadd`. Passing `nothing` as the
# operator leaves the choice of operator to `eadd`.
function Base.:+(A::GBArray, B::GBArray)
    return eadd(A, B, nothing)
end
280+
281+
# Elementwise subtraction of two GBArrays, implemented as `eadd` with the MINUS
# binary operator.
# NOTE(review): if `eadd` follows GraphBLAS eWiseAdd (set-union) semantics, an
# entry stored only in `B` would be copied through without negation — confirm
# whether that is intended for `A - B`.
function Base.:-(A::GBArray, B::GBArray)
    return eadd(A, B, BinaryOps.MINUS)
end
278284
#Elementwise Broadcasts
279285
#######################
280286

src/operations/mul.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ function LinearAlgebra.mul!(
5959
return w
6060
end
6161

62-
6362
"""
6463
mul(A::GBArray, B::GBArray; kwargs...)::GBArray
6564

src/operations/transpose.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ end
6868
mask!(C::GBArray, A::GBArray, mask::GBArray)
6969
7070
Apply a mask to matrix `A`, storing the results in C.
71+
7172
"""
7273
function mask!(C::GBArray, A::GBArray, mask::GBArray; structural = false, complement = false)
7374
desc = Descriptors.T0

src/vector.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...)
1414
1515
Create a GBVector from a vector of indices `I` and a vector of values `X`.
1616
"""
17-
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T}
18-
x = GBVector{T}(maximum(I))
17+
function GBVector(
    I::AbstractVector,
    X::AbstractVector{T};
    dup = BinaryOps.PLUS,
    nrows = maximum(I),
) where {T}
    # Allocate an empty vector of length `nrows`, then build it from the
    # (index, value) pairs, forwarding `dup` to `build`.
    vec = GBVector{T}(nrows)
    build(vec, I, X, dup = dup)
    return vec
end
@@ -27,14 +27,14 @@ Create an nrows length GBVector v such that M[I[k]] = x.
2727
The resulting vector is "iso-valued" such that it only stores `x` once rather than once for
2828
each index.
2929
"""
30-
function GBVector(I::Vector, x::T;
30+
function GBVector(I::AbstractVector, x::T; nrows = maximum(I)) where {T}
    # Allocate an empty vector of length `nrows`, then fill every index in `I`
    # with the single scalar `x`.
    vec = GBVector{T}(nrows)
    build(vec, I, x)
    return vec
end
3636

37-
function build(A::GBVector{T}, I::Vector, x::T) where {T}
37+
function build(A::GBVector{T}, I::AbstractVector, x::T) where {T}
3838
nnz(A) == 0 || throw(libgb.OutputNotEmptyError("Cannot build vector with existing elements"))
3939
x = GBScalar(x)
4040

test/chainrules/chainrulesutils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using FiniteDifferences
2+
# Check that `FiniteDifferences.to_vec` round-trips `x`.
# Asserts that the flattened form is a `Vector` of `Real`s and that applying the
# returned reconstruction function to it gives back a value equal to `x`.
# With `check_inferred = true`, both `to_vec` and the reconstruction function
# must also be type-inferrable.
function test_to_vec(x::T; check_inferred=true) where {T}
    if check_inferred
        @inferred FiniteDifferences.to_vec(x)
    end
    x_vec, back = FiniteDifferences.to_vec(x)
    @test x_vec isa Vector
    @test all(el -> el isa Real, x_vec)
    if check_inferred
        @inferred back(x_vec)
    end
    @test x == back(x_vec)
    return nothing
end
11+
12+
@testset "chainrulesutils" begin
    # Round-trip a random 10x10 sparse matrix and a random length-10 sparse
    # vector (both with density 0.5) through to_vec.
    test_to_vec(GBMatrix(sprand(10, 10, 0.5)))
    test_to_vec(GBVector(sprand(10, 0.5)))
end

0 commit comments

Comments
 (0)