Skip to content

Commit e19133f

Browse files
author
Wimmerer
committed
Fix boolean type infer, fix infer type stability
1 parent 8e45498 commit e19133f

File tree

10 files changed

+143
-42
lines changed

10 files changed

+143
-42
lines changed

src/SuiteSparseGraphBLAS.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ include("descriptors.jl")
3636
include("indexutils.jl")
3737

3838

39-
const GBVecOrMat = Union{GBVector, GBMatrix}
40-
const GBMatOrTranspose = Union{GBMatrix, Transpose{<:Any, <:GBMatrix}}
41-
const GBArray = Union{GBVector, GBMatOrTranspose}
39+
const GBVecOrMat{T} = Union{GBVector{T}, GBMatrix{T}}
40+
const GBMatOrTranspose{T} = Union{GBMatrix{T}, Transpose{<:Any, GBMatrix{T}}}
41+
const GBArray{T} = Union{GBVector{T}, GBMatOrTranspose{T}}
4242
const ptrtogbtype = Dict{Ptr, AbstractGBType}()
4343

4444
const GrBOp = Union{
@@ -125,4 +125,5 @@ function __init__()
125125
end
126126
end
127127

128+
include("operators/ztypes.jl")
128129
end #end of module

src/operations/ewise.jl

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,7 @@ function emul(
8989
desc = nothing
9090
)
9191
op = _handlectx(op, ctxop, BinaryOps.TIMES)
92-
if op isa GrBOp
93-
t = ztype(op)
94-
else
95-
t = optype(u, v)
96-
end
92+
t = inferoutputtype(u, v, op)
9793
w = GBVector{t}(size(u))
9894
return emul!(w, u, v, op; mask , accum, desc)
9995
end
@@ -136,11 +132,7 @@ function emul(
136132
desc = nothing
137133
)
138134
op = _handlectx(op, ctxop, BinaryOps.TIMES)
139-
if op isa GrBOp
140-
t = ztype(op)
141-
else
142-
t = optype(A, B)
143-
end
135+
t = inferoutputtype(A, B, op)
144136
C = GBMatrix{t}(size(A))
145137
return emul!(C, A, B, op; mask, accum, desc)
146138
end
@@ -235,11 +227,7 @@ function eadd(
235227
desc = nothing
236228
)
237229
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
238-
if op isa GrBOp
239-
t = ztype(op)
240-
else
241-
t = optype(eltype(u), eltype(v))
242-
end
230+
t = inferoutputtype(u, v, op)
243231
w = GBVector{t}(size(u))
244232
return eadd!(w, u, v, op; mask, accum, desc)
245233
end
@@ -282,11 +270,7 @@ function eadd(
282270
desc = nothing
283271
)
284272
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
285-
if op isa GrBOp
286-
t = ztype(op)
287-
else
288-
t = optype(A, B)
289-
end
273+
t = inferoutputtype(A, B, op)
290274
C = GBMatrix{t}(size(A))
291275
return eadd!(C, A, B, op; mask, accum, desc)
292276
end

src/operations/kronecker.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,7 @@ function LinearAlgebra.kron(
5353
desc = nothing
5454
)
5555
op = _handlectx(op, ctxop, BinaryOps.TIMES)
56-
if op isa GrBOp
57-
t = ztype(op)
58-
else
59-
t = optype(A, B)
60-
end
56+
t = inferoutputtype(A, B, op)
6157
C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2))
6258
kron!(C, A, B, op; mask, accum, desc)
6359
return C

src/operations/map.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ function Base.map(
2525
op::UnaryUnion, A::GBArray;
2626
mask = nothing, accum = nothing, desc = nothing
2727
)
28-
return map!(op, similar(A), A; mask, accum, desc)
28+
t = inferoutputtype(A, op)
29+
return map!(op, similar(A, t), A; mask, accum, desc)
2930
end
3031

3132
function Base.map!(
@@ -54,7 +55,8 @@ function Base.map(
5455
op::BinaryUnion, x, A::GBArray;
5556
mask = nothing, accum = nothing, desc = nothing
5657
)
57-
return map!(op, similar(A), x, A; mask, accum, desc)
58+
t = inferoutputtype(A, op)
59+
return map!(op, similar(A, t), x, A; mask, accum, desc)
5860
end
5961

6062
function Base.map!(
@@ -83,7 +85,8 @@ function Base.map(
8385
op::BinaryUnion, A::GBArray, x;
8486
mask = nothing, accum = nothing, desc = nothing
8587
)
86-
return map!(op, similar(A), A, x; mask, accum, desc)
88+
t = inferoutputtype(A, op)
89+
return map!(op, similar(A, t), A, x; mask, accum, desc)
8790
end
8891

8992
function Base.broadcasted(::typeof(+), u::GBArray, x::valid_union;

src/operations/mul.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,7 @@ function mul(
9393
desc = nothing
9494
)
9595
op = _handlectx(op, ctxop, Semirings.PLUS_TIMES)
96-
if op isa libgb.GrB_Semiring
97-
t = ztype(op)
98-
else
99-
t = optype(A, B)
100-
end
96+
t = inferoutputtype(A, B, op)
10197
if A isa GBVector && B isa GBMatOrTranspose
10298
C = GBVector{t}(size(B, 2))
10399
elseif A isa GBMatOrTranspose && B isa GBVector

src/operations/operationutils.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,21 @@ function optype(atype, btype)
1616
end
1717
end
1818

19-
optype(A::GBArray, B::GBArray) = optype(eltype(A), eltype(B))
19+
optype(::GBArray{T}, ::GBArray{U}) where {T, U} = optype(T, U)
2020

21+
function inferoutputtype(A::GBArray{T}, B::GBArray{U}, op::AbstractOp) where {T, U}
22+
t = optype(A, B)
23+
return ztype(op, t)
24+
end
25+
function inferoutputtype(::GBArray{T}, op::AbstractOp) where {T}
26+
return ztype(op, T)
27+
end
28+
function inferoutputtype(::GBArray{T}, op) where {T}
29+
return ztype(op)
30+
end
31+
function inferoutputtype(::GBArray{T}, ::GBArray{U}, op) where {T, U}
32+
return ztype(op)
33+
end
2134
function _handlectx(ctx, ctxvar, default = nothing)
2235
if ctx === nothing || ctx === missing
2336
ctx2 = get(ctxvar)

src/operators/binaryops.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,9 @@ end
379379

380380
Base.show(io::IO, ::MIME"text/plain", u::libgb.GrB_BinaryOp) = gxbprint(io, u)
381381

382-
xtype(op::BinaryUnion) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_xtype(op)])
383-
ytype(op::BinaryUnion) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ytype(op)])
384-
ztype(op::BinaryUnion) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ztype(op)])
382+
xtype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_xtype(op)])
383+
ytype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ytype(op)])
384+
ztype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ztype(op)])
385385

386386
"""
387387
First argument: `f(x::T,y::T)::T = x`

src/operators/operatorutils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ function getoperator(op, t)
2828

2929
if op isa AbstractOp
3030
return op[t]
31-
elseif op isa GrBOp
32-
return op
3331
else
3432
return op
3533
end

src/operators/ztypes.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
ztype(::AbstractOp, intype::DataType) = intype
2+
3+
#UnaryOps:
4+
ztype(::Types.ISINF_T, ::DataType) = Bool
5+
ztype(::Types.ISNAN_T, ::DataType) = Bool
6+
ztype(::Types.ISFINITE_T, ::DataType) = Bool
7+
8+
ztype(::Types.CONJ_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
9+
ztype(::Types.ABS_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
10+
ztype(::Types.CREAL_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
11+
ztype(::Types.CIMAG_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
12+
ztype(::Types.CARG_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
13+
14+
ztype(::Types.POSITIONI_T, ::DataType) = Int64
15+
ztype(::Types.POSITIONI1_T, ::DataType) = Int64
16+
ztype(::Types.POSITIONJ_T, ::DataType) = Int64
17+
ztype(::Types.POSITIONJ1_T, ::DataType) = Int64
18+
19+
#BinaryOps:
20+
ztype(::Types.EQ_T, ::DataType) = Bool
21+
ztype(::Types.NE_T, ::DataType) = Bool
22+
ztype(::Types.GT_T, ::DataType) = Bool
23+
ztype(::Types.LT_T, ::DataType) = Bool
24+
ztype(::Types.GE_T, ::DataType) = Bool
25+
ztype(::Types.LE_T, ::DataType) = Bool
26+
ztype(::Types.CMPLX_T, intype::Type{T}) where {T <: AbstractFloat} = Complex{T}
27+
28+
ztype(::Types.FIRSTI_T, ::DataType) = Int64
29+
ztype(::Types.FIRSTI1_T, ::DataType) = Int64
30+
ztype(::Types.FIRSTJ_T, ::DataType) = Int64
31+
ztype(::Types.FIRSTJ1_T, ::DataType) = Int64
32+
ztype(::Types.SECONDI_T, ::DataType) = Int64
33+
ztype(::Types.SECONDI1_T, ::DataType) = Int64
34+
ztype(::Types.SECONDJ_T, ::DataType) = Int64
35+
ztype(::Types.SECONDJ1_T, ::DataType) = Int64
36+
37+
#Semirings:
38+
ztype(::Types.LAND_EQ_T, ::DataType) = Bool
39+
ztype(::Types.LOR_EQ_T, ::DataType) = Bool
40+
ztype(::Types.LXOR_EQ_T, ::DataType) = Bool
41+
ztype(::Types.EQ_EQ_T, ::DataType) = Bool
42+
ztype(::Types.ANY_EQ_T, ::DataType) = Bool
43+
ztype(::Types.LAND_NE_T, ::DataType) = Bool
44+
ztype(::Types.LOR_NE_T, ::DataType) = Bool
45+
ztype(::Types.LXOR_NE_T, ::DataType) = Bool
46+
ztype(::Types.EQ_NE_T, ::DataType) = Bool
47+
ztype(::Types.ANY_NE_T, ::DataType) = Bool
48+
ztype(::Types.LAND_GT_T, ::DataType) = Bool
49+
ztype(::Types.LOR_GT_T, ::DataType) = Bool
50+
ztype(::Types.LXOR_GT_T, ::DataType) = Bool
51+
ztype(::Types.EQ_GT_T, ::DataType) = Bool
52+
ztype(::Types.ANY_GT_T, ::DataType) = Bool
53+
ztype(::Types.LAND_LT_T, ::DataType) = Bool
54+
ztype(::Types.LOR_LT_T, ::DataType) = Bool
55+
ztype(::Types.LXOR_LT_T, ::DataType) = Bool
56+
ztype(::Types.EQ_LT_T, ::DataType) = Bool
57+
ztype(::Types.ANY_LT_T, ::DataType) = Bool
58+
ztype(::Types.LAND_GE_T, ::DataType) = Bool
59+
ztype(::Types.LOR_GE_T, ::DataType) = Bool
60+
ztype(::Types.LXOR_GE_T, ::DataType) = Bool
61+
ztype(::Types.EQ_GE_T, ::DataType) = Bool
62+
ztype(::Types.ANY_GE_T, ::DataType) = Bool
63+
ztype(::Types.LAND_LE_T, ::DataType) = Bool
64+
ztype(::Types.LOR_LE_T, ::DataType) = Bool
65+
ztype(::Types.LXOR_LE_T, ::DataType) = Bool
66+
ztype(::Types.EQ_LE_T, ::DataType) = Bool
67+
ztype(::Types.ANY_LE_T, ::DataType) = Bool
68+
69+
70+
ztype(::Types.MIN_FIRSTI_T, ::DataType) = Int64
71+
ztype(::Types.MAX_FIRSTI_T, ::DataType) = Int64
72+
ztype(::Types.PLUS_FIRSTI_T, ::DataType) = Int64
73+
ztype(::Types.TIMES_FIRSTI_T, ::DataType) = Int64
74+
ztype(::Types.ANY_FIRSTI_T, ::DataType) = Int64
75+
ztype(::Types.MIN_FIRSTI1_T, ::DataType) = Int64
76+
ztype(::Types.MAX_FIRSTI1_T, ::DataType) = Int64
77+
ztype(::Types.PLUS_FIRSTI1_T, ::DataType) = Int64
78+
ztype(::Types.TIMES_FIRSTI1_T, ::DataType) = Int64
79+
ztype(::Types.ANY_FIRSTI1_T, ::DataType) = Int64
80+
ztype(::Types.MIN_FIRSTJ_T, ::DataType) = Int64
81+
ztype(::Types.MAX_FIRSTJ_T, ::DataType) = Int64
82+
ztype(::Types.PLUS_FIRSTJ_T, ::DataType) = Int64
83+
ztype(::Types.TIMES_FIRSTJ_T, ::DataType) = Int64
84+
ztype(::Types.ANY_FIRSTJ_T, ::DataType) = Int64
85+
ztype(::Types.MIN_FIRSTJ1_T, ::DataType) = Int64
86+
ztype(::Types.MAX_FIRSTJ1_T, ::DataType) = Int64
87+
ztype(::Types.PLUS_FIRSTJ1_T, ::DataType) = Int64
88+
ztype(::Types.TIMES_FIRSTJ1_T, ::DataType) = Int64
89+
ztype(::Types.ANY_FIRSTJ1_T, ::DataType) = Int64
90+
ztype(::Types.MIN_SECONDI_T, ::DataType) = Int64
91+
ztype(::Types.MAX_SECONDI_T, ::DataType) = Int64
92+
ztype(::Types.PLUS_SECONDI_T, ::DataType) = Int64
93+
ztype(::Types.TIMES_SECONDI_T, ::DataType) = Int64
94+
ztype(::Types.ANY_SECONDI_T, ::DataType) = Int64
95+
ztype(::Types.MIN_SECONDI1_T, ::DataType) = Int64
96+
ztype(::Types.MAX_SECONDI1_T, ::DataType) = Int64
97+
ztype(::Types.PLUS_SECONDI1_T, ::DataType) = Int64
98+
ztype(::Types.TIMES_SECONDI1_T, ::DataType) = Int64
99+
ztype(::Types.ANY_SECONDI1_T, ::DataType) = Int64
100+
ztype(::Types.MIN_SECONDJ_T, ::DataType) = Int64
101+
ztype(::Types.MAX_SECONDJ_T, ::DataType) = Int64
102+
ztype(::Types.PLUS_SECONDJ_T, ::DataType) = Int64
103+
ztype(::Types.TIMES_SECONDJ_T, ::DataType) = Int64
104+
ztype(::Types.ANY_SECONDJ_T, ::DataType) = Int64
105+
ztype(::Types.MIN_SECONDJ1_T, ::DataType) = Int64
106+
ztype(::Types.MAX_SECONDJ1_T, ::DataType) = Int64
107+
ztype(::Types.PLUS_SECONDJ1_T, ::DataType) = Int64
108+
ztype(::Types.TIMES_SECONDJ1_T, ::DataType) = Int64
109+
ztype(::Types.ANY_SECONDJ1_T, ::DataType) = Int64

test/operations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
@test emul(m, n, BinaryOps.POW)[3, 2] == m[3,2] ^ n[3,2]
1212
#check that the (*) op is being picked up from the semiring
1313
@test emul(m, n, Semirings.MAX_PLUS) == emul(m, n, BinaryOps.PLUS)
14+
@test eltype(m .== n) == Bool
1415
end
1516
@testset "kron" begin
1617
m1 = GBMatrix(UInt64[1, 2, 3, 5], UInt64[1, 3, 1, 2], Int8[1, 2, 3, 5])

0 commit comments

Comments
 (0)