Skip to content

Commit 63de965

Browse files
authored
Merge pull request #2673 from JuliaGPU/ksh/sparse_types
Remove some invalid conversions and test more
2 parents 8b93480 + fade351 commit 63de965

File tree

2 files changed

+71
-32
lines changed

2 files changed

+71
-32
lines changed

lib/cusparse/types.jl

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
## index type
44

55
function Base.convert(::Type{cusparseIndexType_t}, T::DataType)
6-
if T == Int16
7-
return CUSPARSE_INDEX_16U
8-
elseif T == Int32
6+
if T == Int32
97
return CUSPARSE_INDEX_32I
108
elseif T == Int64
119
return CUSPARSE_INDEX_64I
@@ -15,9 +13,7 @@ function Base.convert(::Type{cusparseIndexType_t}, T::DataType)
1513
end
1614

1715
function Base.convert(::Type{Type}, T::cusparseIndexType_t)
18-
if T == CUSPARSE_INDEX_16U
19-
return Int16
20-
elseif T == CUSPARSE_INDEX_32I
16+
if T == CUSPARSE_INDEX_32I
2117
return Int32
2218
elseif T == CUSPARSE_INDEX_64I
2319
return Int64
@@ -35,7 +31,7 @@ function Base.convert(::Type{cusparseIndexBase_t}, base::Integer)
3531
elseif base == 1
3632
return CUSPARSE_INDEX_BASE_ONE
3733
else
38-
throw(ArgumentError("CUSPARSE does not support index base $base!"))
34+
throw(ArgumentError("CUSPARSE does not support index base $(base)!"))
3935
end
4036
end
4137

@@ -45,7 +41,7 @@ function Base.convert(T::Type{<:Integer}, base::cusparseIndexBase_t)
4541
elseif base == CUSPARSE_INDEX_BASE_ONE
4642
return T(1)
4743
else
48-
throw(ArgumentError("Unknown index base $base!"))
44+
throw(ArgumentError("Unknown index base $(base)!"))
4945
end
5046
end
5147

@@ -54,105 +50,105 @@ end
5450

5551
function Base.convert(::Type{cusparseOperation_t}, trans::SparseChar)
5652
if trans == 'N'
57-
CUSPARSE_OPERATION_NON_TRANSPOSE
53+
return CUSPARSE_OPERATION_NON_TRANSPOSE
5854
elseif trans == 'T'
59-
CUSPARSE_OPERATION_TRANSPOSE
55+
return CUSPARSE_OPERATION_TRANSPOSE
6056
elseif trans == 'C'
61-
CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
57+
return CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
6258
else
6359
throw(ArgumentError("Unknown operation $trans"))
6460
end
6561
end
6662

6763
function Base.convert(::Type{cusparseMatrixType_t}, mattype::SparseChar)
6864
if mattype == 'G'
69-
CUSPARSE_MATRIX_TYPE_GENERAL
65+
return CUSPARSE_MATRIX_TYPE_GENERAL
7066
elseif mattype == 'T'
71-
CUSPARSE_MATRIX_TYPE_TRIANGULAR
67+
return CUSPARSE_MATRIX_TYPE_TRIANGULAR
7268
elseif mattype == 'S'
73-
CUSPARSE_MATRIX_TYPE_SYMMETRIC
69+
return CUSPARSE_MATRIX_TYPE_SYMMETRIC
7470
elseif mattype == 'H'
75-
CUSPARSE_MATRIX_TYPE_HERMITIAN
71+
return CUSPARSE_MATRIX_TYPE_HERMITIAN
7672
else
7773
throw(ArgumentError("Unknown matrix type $mattype"))
7874
end
7975
end
8076

8177
function Base.convert(::Type{cusparseSpMatAttribute_t}, attribute::SparseChar)
8278
if attribute == 'F'
83-
CUSPARSE_SPMAT_FILL_MODE
79+
return CUSPARSE_SPMAT_FILL_MODE
8480
elseif attribute == 'D'
85-
CUSPARSE_SPMAT_DIAG_TYPE
81+
return CUSPARSE_SPMAT_DIAG_TYPE
8682
else
8783
throw(ArgumentError("Unknown attribute $attribute"))
8884
end
8985
end
9086

9187
function Base.convert(::Type{cusparseFillMode_t}, uplo::SparseChar)
9288
if uplo == 'U'
93-
CUSPARSE_FILL_MODE_UPPER
89+
return CUSPARSE_FILL_MODE_UPPER
9490
elseif uplo == 'L'
95-
CUSPARSE_FILL_MODE_LOWER
91+
return CUSPARSE_FILL_MODE_LOWER
9692
else
9793
throw(ArgumentError("Unknown fill mode $uplo"))
9894
end
9995
end
10096

10197
function Base.convert(::Type{cusparseDiagType_t}, diag::SparseChar)
10298
if diag == 'U'
103-
CUSPARSE_DIAG_TYPE_UNIT
99+
return CUSPARSE_DIAG_TYPE_UNIT
104100
elseif diag == 'N'
105-
CUSPARSE_DIAG_TYPE_NON_UNIT
101+
return CUSPARSE_DIAG_TYPE_NON_UNIT
106102
else
107103
throw(ArgumentError("Unknown diag type $diag"))
108104
end
109105
end
110106

111107
function Base.convert(::Type{cusparseIndexBase_t}, index::SparseChar)
112108
if index == 'Z'
113-
CUSPARSE_INDEX_BASE_ZERO
109+
return CUSPARSE_INDEX_BASE_ZERO
114110
elseif index == 'O'
115-
CUSPARSE_INDEX_BASE_ONE
111+
return CUSPARSE_INDEX_BASE_ONE
116112
else
117-
throw(ArgumentError("Unknown index base"))
113+
throw(ArgumentError("Unknown index base $index"))
118114
end
119115
end
120116

121117
function Base.convert(::Type{cusparseDirection_t}, dir::SparseChar)
122118
if dir == 'R'
123-
CUSPARSE_DIRECTION_ROW
119+
return CUSPARSE_DIRECTION_ROW
124120
elseif dir == 'C'
125-
CUSPARSE_DIRECTION_COLUMN
121+
return CUSPARSE_DIRECTION_COLUMN
126122
else
127123
throw(ArgumentError("Unknown direction $dir"))
128124
end
129125
end
130126

131127
function Base.convert(::Type{cusparseOrder_t}, order::SparseChar)
132128
if order == 'R'
133-
CUSPARSE_ORDER_ROW
129+
return CUSPARSE_ORDER_ROW
134130
elseif order == 'C'
135-
CUSPARSE_ORDER_COL
131+
return CUSPARSE_ORDER_COL
136132
else
137133
throw(ArgumentError("Unknown order $order"))
138134
end
139135
end
140136

141137
function Base.convert(::Type{cusparseSpSVUpdate_t}, update::SparseChar)
142138
if update == 'G'
143-
CUSPARSE_SPSV_UPDATE_GENERAL
139+
return CUSPARSE_SPSV_UPDATE_GENERAL
144140
elseif update == 'D'
145-
CUSPARSE_SPSV_UPDATE_DIAGONAL
141+
return CUSPARSE_SPSV_UPDATE_DIAGONAL
146142
else
147143
throw(ArgumentError("Unknown update $update"))
148144
end
149145
end
150146

151147
function Base.convert(::Type{cusparseSpSMUpdate_t}, update::SparseChar)
152148
if update == 'G'
153-
CUSPARSE_SPSM_UPDATE_GENERAL
149+
return CUSPARSE_SPSM_UPDATE_GENERAL
154150
elseif update == 'D'
155-
CUSPARSE_SPSM_UPDATE_DIAGONAL
151+
return CUSPARSE_SPSM_UPDATE_DIAGONAL
156152
else
157153
throw(ArgumentError("Unknown update $update"))
158154
end

test/libraries/cusparse.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,3 +1189,46 @@ end
11891189
@test Array(coo.nzVal) == [12f0, 1f0]
11901190
end
11911191
end
1192+
1193+
@testset "Utility type conversions" begin
1194+
@test convert(CUSPARSE.cusparseIndexType_t, Int32) == CUSPARSE.CUSPARSE_INDEX_32I
1195+
@test convert(CUSPARSE.cusparseIndexType_t, Int64) == CUSPARSE.CUSPARSE_INDEX_64I
1196+
@test_throws ArgumentError("CUSPARSE type equivalent for index type Int8 does not exist!") convert(CUSPARSE.cusparseIndexType_t, Int8)
1197+
@test convert(Type, CUSPARSE.CUSPARSE_INDEX_32I) == Int32
1198+
@test convert(Type, CUSPARSE.CUSPARSE_INDEX_64I) == Int64
1199+
1200+
@test convert(CUSPARSE.cusparseIndexBase_t, 0) == CUSPARSE.CUSPARSE_INDEX_BASE_ZERO
1201+
@test convert(CUSPARSE.cusparseIndexBase_t, 1) == CUSPARSE.CUSPARSE_INDEX_BASE_ONE
1202+
@test_throws ArgumentError("CUSPARSE does not support index base 2!") convert(CUSPARSE.cusparseIndexBase_t, 2)
1203+
@test convert(Int8, CUSPARSE.CUSPARSE_INDEX_BASE_ZERO) == zero(Int8)
1204+
@test convert(Int8, CUSPARSE.CUSPARSE_INDEX_BASE_ONE) == one(Int8)
1205+
1206+
@test_throws ArgumentError("Unknown operation X") convert(CUSPARSE.cusparseOperation_t, CUSPARSE.SparseChar('X'))
1207+
1208+
@test convert(CUSPARSE.cusparseMatrixType_t, CUSPARSE.SparseChar('G')) == CUSPARSE.CUSPARSE_MATRIX_TYPE_GENERAL
1209+
@test convert(CUSPARSE.cusparseMatrixType_t, CUSPARSE.SparseChar('T')) == CUSPARSE.CUSPARSE_MATRIX_TYPE_TRIANGULAR
1210+
@test convert(CUSPARSE.cusparseMatrixType_t, CUSPARSE.SparseChar('S')) == CUSPARSE.CUSPARSE_MATRIX_TYPE_SYMMETRIC
1211+
@test convert(CUSPARSE.cusparseMatrixType_t, CUSPARSE.SparseChar('H')) == CUSPARSE.CUSPARSE_MATRIX_TYPE_HERMITIAN
1212+
@test_throws ArgumentError("Unknown matrix type X") convert(CUSPARSE.cusparseMatrixType_t, CUSPARSE.SparseChar('X'))
1213+
1214+
@test_throws ArgumentError("Unknown attribute X") convert(CUSPARSE.cusparseSpMatAttribute_t, CUSPARSE.SparseChar('X'))
1215+
@test_throws ArgumentError("Unknown fill mode X") convert(CUSPARSE.cusparseFillMode_t, CUSPARSE.SparseChar('X'))
1216+
@test_throws ArgumentError("Unknown diag type X") convert(CUSPARSE.cusparseDiagType_t, CUSPARSE.SparseChar('X'))
1217+
@test_throws ArgumentError("Unknown index base X") convert(CUSPARSE.cusparseIndexBase_t, CUSPARSE.SparseChar('X'))
1218+
1219+
@test convert(CUSPARSE.cusparseDirection_t, CUSPARSE.SparseChar('R')) == CUSPARSE.CUSPARSE_DIRECTION_ROW
1220+
@test convert(CUSPARSE.cusparseDirection_t, CUSPARSE.SparseChar('C')) == CUSPARSE.CUSPARSE_DIRECTION_COLUMN
1221+
@test_throws ArgumentError("Unknown direction X") convert(CUSPARSE.cusparseDirection_t, CUSPARSE.SparseChar('X'))
1222+
1223+
@test convert(CUSPARSE.cusparseOrder_t, CUSPARSE.SparseChar('R')) == CUSPARSE.CUSPARSE_ORDER_ROW
1224+
@test convert(CUSPARSE.cusparseOrder_t, CUSPARSE.SparseChar('C')) == CUSPARSE.CUSPARSE_ORDER_COL
1225+
@test_throws ArgumentError("Unknown order X") convert(CUSPARSE.cusparseOrder_t, CUSPARSE.SparseChar('X'))
1226+
1227+
@test convert(CUSPARSE.cusparseSpSVUpdate_t, CUSPARSE.SparseChar('G')) == CUSPARSE.CUSPARSE_SPSV_UPDATE_GENERAL
1228+
@test convert(CUSPARSE.cusparseSpSVUpdate_t, CUSPARSE.SparseChar('D')) == CUSPARSE.CUSPARSE_SPSV_UPDATE_DIAGONAL
1229+
@test_throws ArgumentError("Unknown update X") convert(CUSPARSE.cusparseSpSVUpdate_t, CUSPARSE.SparseChar('X'))
1230+
1231+
@test convert(CUSPARSE.cusparseSpSMUpdate_t, CUSPARSE.SparseChar('G')) == CUSPARSE.CUSPARSE_SPSV_UPDATE_GENERAL
1232+
@test convert(CUSPARSE.cusparseSpSMUpdate_t, CUSPARSE.SparseChar('D')) == CUSPARSE.CUSPARSE_SPSV_UPDATE_DIAGONAL
1233+
@test_throws ArgumentError("Unknown update X") convert(CUSPARSE.cusparseSpSMUpdate_t, CUSPARSE.SparseChar('X'))
1234+
end

0 commit comments

Comments
 (0)