Skip to content

Commit 554dcc4

Browse files
authored
Support sparse opnorm (#1466)
1 parent 2987086 commit 554dcc4

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

lib/cusparse/CUSPARSE.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include("libcusparse_deprecated.jl")
2727
include("array.jl")
2828
include("util.jl")
2929
include("types.jl")
30+
include("linalg.jl")
3031

3132
# low-level wrappers
3233
include("helpers.jl")

lib/cusparse/linalg.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using LinearAlgebra
2+
3+
"""
    sum_dim1(A::CuSparseMatrixCSR)

Return a `CuArray` of length `size(A, 1)` whose `i`-th entry is the sum of the absolute
values of the stored entries in row `i` of `A`.

The element type of the result is `promote_type(Float64, Tnorm)`, where `Tnorm` is the
real floating-point type derived from `eltype(A)`.
"""
function sum_dim1(A::CuSparseMatrixCSR)
    # One thread per row: thread `idx` walks the stored entries of row `idx`.
    function kernel(Tnorm, out, dA)
        idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x
        # Guarantees that both `idx` and `idx + 1` are valid indices into `rowPtr`.
        idx < length(dA.rowPtr) || return
        s = zero(Tnorm)
        for k in dA.rowPtr[idx]:(dA.rowPtr[idx + 1] - 1)
            s += abs(dA.nzVal[k])
        end
        out[idx] = s
        return
    end

    m = size(A, 1)
    Tnorm = typeof(float(real(zero(eltype(A)))))
    Tsum = promote_type(Float64, Tnorm)
    rowsum = CUDA.CuArray{Tsum}(undef, m)

    # Nothing to compute without rows; this also avoids a zero-thread launch
    # (and the `cld(0, 0)` division error that would come with it).
    m == 0 && return rowsum

    kernel_f = @cuda launch=false kernel(Tnorm, rowsum, A)

    config = launch_configuration(kernel_f.fun)
    # The launch must cover every ROW (the kernel writes one output per row), so it
    # is sized by `m`; sizing it by the column count leaves rows of a tall matrix
    # unprocessed and their `undef` output entries uninitialized.
    threads = min(m, config.threads)
    blocks = cld(m, threads)
    kernel_f(Tnorm, rowsum, A; threads, blocks)
    return rowsum
end
27+
28+
"""
    opnorm(A::CuSparseMatrixCSR, p::Real=2)

Compute the operator `p`-norm of the sparse matrix `A`.

Only `p == Inf` is implemented: the result is the maximum of `sum_dim1(A)`, i.e. the
largest per-row sum of absolute values. Every other `p`, including the default
`p = 2`, raises an error.
"""
function LinearAlgebra.opnorm(A::CuSparseMatrixCSR, p::Real=2)
    # Reject unsupported norms up front.
    p == Inf || error("p=$p is not supported")
    return maximum(sum_dim1(A))
end
35+
36+
"""
    opnorm(A::CuSparseMatrixCSC, p::Real=2)

Compute the operator `p`-norm of a CSC matrix by converting it to `CuSparseMatrixCSR`
and forwarding to the `opnorm` method for that type.
"""
function LinearAlgebra.opnorm(A::CuSparseMatrixCSC, p::Real=2)
    csr = CuSparseMatrixCSR(A)
    return opnorm(csr, p)
end

test/cusparse/linalg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using CUDA.CUSPARSE
2+
using LinearAlgebra, SparseArrays
3+
4+
# The sparse GPU `opnorm(·, Inf)` must agree with the CPU result from SparseArrays
# for both storage formats and for real and complex element types.
@testset "opnorm" for T in [Float32, Float64, ComplexF32, ComplexF64]
    S = sprand(T, 10, 10, 0.1)
    dS_csc = CuSparseMatrixCSC(S)
    dS_csr = CuSparseMatrixCSR(S)
    # Floating-point results are compared approximately, not with `==`.
    @test opnorm(S, Inf) ≈ opnorm(dS_csc, Inf)
    @test opnorm(S, Inf) ≈ opnorm(dS_csr, Inf)
end

0 commit comments

Comments
 (0)