
fixes the kron implementation for sparse + diagonal matrix #2804


Open · tam724 wants to merge 6 commits into master

Conversation

@tam724 commented Jun 27, 2025

This generalizes the implementation of kron for the combination of a (CUDA) sparse matrix and a diagonal matrix.

Currently the Diagonal type is treated as I (UniformScaling) with a certain dimension (n × n). This is not the intended use of Diagonal (see the Julia docs), and the following code returns a wrong result when using CUDA:

julia> using CUDA, SparseArrays, LinearAlgebra
julia> A = sparse(ones(1, 1))
1×1 SparseMatrixCSC{Float64, Int64} with 1 stored entry:
 1.0
julia> B = Diagonal(rand(1))
1×1 Diagonal{Float64, Vector{Float64}}:
 0.4618063241112559
julia> kron(A, B)
1×1 SparseMatrixCSC{Float64, Int64} with 1 stored entry:
 0.4618063241112559
julia> kron(A |> cu, B |> cu)
1×1 CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32} with 1 stored entry:
 1.0

This implementation keeps the old behaviour: multiplication with I(3)::Diagonal{Bool, Vector{Bool}} is still interpreted as multiplication with the identity.
Multiplication with a general Diagonal matrix should now be fixed. I also added some tests.
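To make the fix concrete, here is the Kronecker index math for sparse ⊗ Diagonal as a minimal CPU sketch (kron_coo_diag is an illustrative name, not part of this PR): each stored entry A[i, j] = v pairs with each diagonal entry B.diag[k], landing at row (i-1)*mB + k and column (j-1)*nB + k.

using SparseArrays, LinearAlgebra

function kron_coo_diag(A::SparseMatrixCSC, B::Diagonal)
    mA, nA = size(A)
    mB, nB = size(B)
    rowA, colA, valA = findnz(A)
    out_row, out_col = Int[], Int[]
    out_val = typeof(oneunit(eltype(A)) * oneunit(eltype(B)))[]
    for (i, j, v) in zip(rowA, colA, valA), k in 1:nB
        push!(out_row, (i - 1) * mB + k)  # block row offset plus diagonal index
        push!(out_col, (j - 1) * nB + k)
        push!(out_val, v * B.diag[k])     # use the actual diagonal value, not 1
    end
    return sparse(out_row, out_col, out_val, mA * mB, nA * nB)
end

With this, kron_coo_diag(sparse(ones(1, 1)), Diagonal(rand(1))) matches Base's kron for the example above.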

github-actions bot (Contributor) commented Jun 27, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl
index 55b9cd962..6151ab238 100644
--- a/lib/cublas/linalg.jl
+++ b/lib/cublas/linalg.jl
@@ -366,7 +366,7 @@ function LinearAlgebra.inv(D::Diagonal{T, <:CuArray{T}}) where {T}
     Diagonal(Di)
 end
 
-LinearAlgebra.adjoint(D::Diagonal{T, <:CuVector{T}}) where T <: Complex = Diagonal(map(adjoint, D.diag))
+LinearAlgebra.adjoint(D::Diagonal{T, <:CuVector{T}}) where {T <: Complex} = Diagonal(map(adjoint, D.diag))
 
 LinearAlgebra.rdiv!(A::CuArray, D::Diagonal) =  _rdiv!(A, A, D)
 
diff --git a/lib/cusparse/linalg.jl b/lib/cusparse/linalg.jl
index 803cfeed8..68f2465b1 100644
--- a/lib/cusparse/linalg.jl
+++ b/lib/cusparse/linalg.jl
@@ -66,13 +66,13 @@ _kron_CuSparseMatrixCOO_components(At::Transpose{<:Number, <:CuSparseMatrixCOO})
 _kron_CuSparseMatrixCOO_components(Ah::Adjoint{<:Number, <:CuSparseMatrixCOO}) = parent(Ah).colInd, parent(Ah).rowInd, parent(Ah).nzVal, adjoint, Int(parent(Ah).nnz)
 
 function LinearAlgebra.kron(
-    A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
-    B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
+        A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
+        B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
     ) where {TvA, TiA, TvB, TiB}
     mA, nA = size(A)
     mB, nB = size(B)
     Ti = promote_type(TiA, TiB)
-    Tv = typeof(oneunit(TvA)*oneunit(TvB))
+    Tv = typeof(oneunit(TvA) * oneunit(TvB))
 
     A_rowInd, A_colInd, A_nzVal, A_nzOp, A_nnz = _kron_CuSparseMatrixCOO_components(A)
     B_rowInd, B_colInd, B_nzVal, B_nzOp, B_nnz = _kron_CuSparseMatrixCOO_components(B)
@@ -91,13 +91,13 @@ function LinearAlgebra.kron(
 end
 
 function LinearAlgebra.kron(
-    A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
-    B::Diagonal{TvB, <:CuVector{TvB}}
+        A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
+        B::Diagonal{TvB, <:CuVector{TvB}}
     ) where {TvA, TiA, TvB}
     mA, nA = size(A)
     mB, nB = size(B)
     Ti = TiA
-    Tv = typeof(oneunit(TvA)*oneunit(TvB))
+    Tv = typeof(oneunit(TvA) * oneunit(TvB))
 
     A_rowInd, A_colInd, A_nzVal, A_nzOp, A_nnz = _kron_CuSparseMatrixCOO_components(A)
     B_rowInd, B_colInd, B_nzVal, B_nzOp, B_nnz = one(Ti):Ti(nB), one(Ti):Ti(nB), B.diag, identity, Int(nB)
@@ -116,13 +116,13 @@ function LinearAlgebra.kron(
 end
 
 function LinearAlgebra.kron(
-    A::Diagonal{TvA, <:CuVector{TvA}},
-    B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
+        A::Diagonal{TvA, <:CuVector{TvA}},
+        B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
     ) where {TvA, TvB, TiB}
     mA, nA = size(A)
     mB, nB = size(B)
     Ti = TiB
-    Tv = typeof(oneunit(TvA)*oneunit(TvB))
+    Tv = typeof(oneunit(TvA) * oneunit(TvB))
 
     A_rowInd, A_colInd, A_nzVal, A_nzOp, A_nnz = one(Ti):Ti(nA), one(Ti):Ti(nA), A.diag, identity, Int(nA)
     B_rowInd, B_colInd, B_nzVal, B_nzOp, B_nnz = _kron_CuSparseMatrixCOO_components(B)
diff --git a/test/libraries/cublas/extensions.jl b/test/libraries/cublas/extensions.jl
index 027b9678e..04a58e74d 100644
--- a/test/libraries/cublas/extensions.jl
+++ b/test/libraries/cublas/extensions.jl
@@ -538,7 +538,7 @@ k = 13
             mul!(d_XA, d_X, d_A)
             Array(d_XA) ≈ Diagonal(x) * A
 
-            XA = rand(elty,m,n)
+            XA = rand(elty, m, n)
             d_XA = CuArray(XA)
             d_X = Diagonal(d_x)
             mul!(d_XA, d_X', d_A)
@@ -549,8 +549,8 @@ k = 13
             d_Y = Diagonal(d_y)
             mul!(d_AY, d_A, d_Y)
             Array(d_AY) ≈ A * Diagonal(y)
-            
-            AY = rand(elty,m,n)
+
+            AY = rand(elty, m, n)
             d_AY = CuArray(AY)
             d_Y = Diagonal(d_y)
             mul!(d_AY, d_A, d_Y')
diff --git a/test/libraries/cusparse/linalg.jl b/test/libraries/cusparse/linalg.jl
index 9af0966fa..bb6b9c27a 100644
--- a/test/libraries/cusparse/linalg.jl
+++ b/test/libraries/cusparse/linalg.jl
@@ -44,7 +44,7 @@ m = 10
             @test collect(kron(opa(dZA), C)) ≈ kron(opa(ZA), C)
             @test collect(kron(C, opa(dZA))) ≈ kron(C, opa(ZA))
         end
-        @testset "kronecker product with Diagonal opa = $opa" for opa in (identity, transpose, adjoint) 
+        @testset "kronecker product with Diagonal opa = $opa" for opa in (identity, transpose, adjoint)
             @test collect(kron(opa(dA), dD)) ≈ kron(opa(A), D)
             @test collect(kron(dD, opa(dA))) ≈ kron(D, opa(A))
             @test collect(kron(opa(dZA), dD)) ≈ kron(opa(ZA), D)
@@ -57,14 +57,14 @@ end
     mat_sizes = [(2, 3), (2, 0)]
     @testset "size(A) = ($(mA), $(nA)), size(B) = ($(mB), $(nB))" for (mA, nA) in mat_sizes, (mB, nB) in mat_sizes
         A = sprand(T, mA, nA, 0.5)
-        B  = sprand(T, mB, nB, 0.5)
+        B = sprand(T, mB, nB, 0.5)
 
         dA = CuSparseMatrixCOO{T}(A)
         dB = CuSparseMatrixCOO{T}(B)
 
         @testset "kronecker (COO ⊗ COO) opa = $opa, opb = $opb" for opa in (identity, transpose, adjoint), opb in (identity, transpose, adjoint)
             dC = kron(opa(dA), opb(dB))
-            @test collect(dC)  ≈ kron(opa(A), opb(B))
+            @test collect(dC) ≈ kron(opa(A), opb(B))
             @test eltype(dC) == typeof(oneunit(T) * oneunit(T))
             @test dC isa CuSparseMatrixCOO
         end
@@ -73,20 +73,20 @@ end
 
 @testset "TA = $TA, TvB = $TvB" for TvB in [Float32, Float64, ComplexF32, ComplexF64], TA in [Bool, TvB]
     A = Diagonal(rand(TA, 2))
-    B  = sprand(TvB, 3, 4, 0.5)
+    B = sprand(TvB, 3, 4, 0.5)
     dA = adapt(CuArray, A)
     dB = CuSparseMatrixCOO{TvB}(B)
 
-    @testset "kronecker (diagonal ⊗ COO) opa = $opa, opb = $opb" for opa in (adjoint, ), opb in (identity, transpose, adjoint)
+    @testset "kronecker (diagonal ⊗ COO) opa = $opa, opb = $opb" for opa in (adjoint,), opb in (identity, transpose, adjoint)
         dC = kron(opa(dA), opb(dB))
-        @test collect(dC)  ≈ kron(opa(A), opb(B))
+        @test collect(dC) ≈ kron(opa(A), opb(B))
         @test eltype(dC) == typeof(oneunit(TA) * oneunit(TvB))
         @test dC isa CuSparseMatrixCOO
     end
 
-    @testset "kronecker (COO ⊗ diagonal) opa = $opa, opb = $opb" for opa in (identity, transpose, adjoint), opb in (adjoint, )
+    @testset "kronecker (COO ⊗ diagonal) opa = $opa, opb = $opb" for opa in (identity, transpose, adjoint), opb in (adjoint,)
         dC = kron(opb(dB), opa(dA))
-        @test collect(dC)  ≈ kron(opb(B), opa(A))
+        @test collect(dC) ≈ kron(opb(B), opa(A))
         @test eltype(dC) == typeof(oneunit(TvB) * oneunit(TA))
         @test dC isa CuSparseMatrixCOO
     end

@tam724 (Author) commented Jun 27, 2025

For reference: the Diagonal{Bool} = identity-matrix behaviour is mentioned here: JuliaGPU/GPUArrays.jl#585

However, this implementation of kron is not correct even for all instances of Diagonal{Bool}, because the diagonal can contain false elements:

julia> Diagonal([true, false, true])
3×3 Diagonal{Bool, Vector{Bool}}:
 1  ⋅  ⋅
 ⋅  0  ⋅
 ⋅  ⋅  1
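A small CPU illustration of this point (not from the PR): Base's kron uses the stored values, so a false entry zeroes out its rows and columns instead of acting like I:

julia> using SparseArrays, LinearAlgebra

julia> Array(kron(sparse(ones(2, 2)), Diagonal([true, false])))
4×4 Matrix{Float64}:
 1.0  0.0  1.0  0.0
 0.0  0.0  0.0  0.0
 1.0  0.0  1.0  0.0
 0.0  0.0  0.0  0.0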

github-actions bot (Contributor) left a comment:

CUDA.jl Benchmarks

Benchmark suite Current: be66023 Previous: 740e888 Ratio
latency/precompile 43253783558 ns 43145037296.5 ns 1.00
latency/ttfp 7174476859 ns 7113018957 ns 1.01
latency/import 3453137369 ns 3449154656 ns 1.00
integration/volumerhs 9617230.5 ns 9623415.5 ns 1.00
integration/byval/slices=1 147230 ns 147533 ns 1.00
integration/byval/slices=3 425809 ns 426571.5 ns 1.00
integration/byval/reference 145237 ns 145493 ns 1.00
integration/byval/slices=2 286727 ns 286772 ns 1.00
integration/cudadevrt 103723 ns 103833 ns 1.00
kernel/indexing 14406 ns 14624 ns 0.99
kernel/indexing_checked 15140 ns 15244 ns 0.99
kernel/occupancy 724.4475524475524 ns 721.1958041958042 ns 1.00
kernel/launch 2395.1111111111113 ns 2555.166666666667 ns 0.94
kernel/rand 18537 ns 15614 ns 1.19
array/reverse/1d 19815.5 ns 20478 ns 0.97
array/reverse/2d 24843 ns 25648 ns 0.97
array/reverse/1d_inplace 11480 ns 11094 ns 1.03
array/reverse/2d_inplace 13291 ns 12835 ns 1.04
array/copy 20904 ns 21710 ns 0.96
array/iteration/findall/int 157537 ns 161312.5 ns 0.98
array/iteration/findall/bool 138545 ns 141552 ns 0.98
array/iteration/findfirst/int 162036 ns 163967.5 ns 0.99
array/iteration/findfirst/bool 164125 ns 165082.5 ns 0.99
array/iteration/scalar 71579 ns 76072 ns 0.94
array/iteration/logical 216126 ns 224476 ns 0.96
array/iteration/findmin/1d 46196 ns 48162 ns 0.96
array/iteration/findmin/2d 96597 ns 98065 ns 0.99
array/reductions/reduce/1d 35779 ns 43404.5 ns 0.82
array/reductions/reduce/2d 41877 ns 49318 ns 0.85
array/reductions/mapreduce/1d 34004 ns 40598 ns 0.84
array/reductions/mapreduce/2d 41217.5 ns 52219.5 ns 0.79
array/broadcast 21061 ns 21548 ns 0.98
array/copyto!/gpu_to_gpu 12561 ns 13067 ns 0.96
array/copyto!/cpu_to_gpu 215828 ns 218217 ns 0.99
array/copyto!/gpu_to_cpu 283456 ns 286538 ns 0.99
array/accumulate/1d 108764 ns 110985 ns 0.98
array/accumulate/2d 80472 ns 81897 ns 0.98
array/construct 1257.6 ns 1280.15 ns 0.98
array/random/randn/Float32 45513 ns 49468 ns 0.92
array/random/randn!/Float32 24995 ns 25315 ns 0.99
array/random/rand!/Int64 27176 ns 27290 ns 1.00
array/random/rand!/Float32 8715.666666666666 ns 8982.333333333334 ns 0.97
array/random/rand/Int64 30014 ns 30478 ns 0.98
array/random/rand/Float32 13055.5 ns 13445 ns 0.97
array/permutedims/4d 61428 ns 61800 ns 0.99
array/permutedims/2d 55137 ns 56060 ns 0.98
array/permutedims/3d 55868 ns 56903.5 ns 0.98
array/sorting/1d 2777015.5 ns 2768176.5 ns 1.00
array/sorting/by 3367995 ns 3356819.5 ns 1.00
array/sorting/2d 1086288 ns 1083702 ns 1.00
cuda/synchronization/stream/auto 1040.9 ns 1025.9 ns 1.01
cuda/synchronization/stream/nonblocking 8009.700000000001 ns 8263.8 ns 0.97
cuda/synchronization/stream/blocking 806.1530612244898 ns 801.5851063829788 ns 1.01
cuda/synchronization/context/auto 1178.2 ns 1171.1 ns 1.01
cuda/synchronization/context/nonblocking 7166.2 ns 8327.9 ns 0.86
cuda/synchronization/context/blocking 922.5106382978723 ns 913.5490196078431 ns 1.01

This comment was automatically generated by a workflow using github-action-benchmark.

@maleadt (Member) left a comment:

Thanks! Just a single nit.

@testset "type = $typ" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
    dA = typ(A)
    dB = typ(B)
    dZA = typ(ZA)
    dD = Diagonal(CuArray(D.diag))
@maleadt (Member): adapt(CuArray, D) is more idiomatic.

@tam724 (Author): Sure.
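For reference, a minimal sketch of the suggested idiom (Adapt.jl has a rule for the Diagonal wrapper, so only the stored diagonal vector is converted):

julia> using Adapt, CUDA, LinearAlgebra

julia> D = Diagonal(rand(Float32, 3));

julia> adapt(CuArray, D) isa Diagonal{Float32, <:CuVector{Float32}}  # wrapper preserved
true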

@@ -122,7 +123,8 @@ function LinearAlgebra.kron(A::Diagonal, B::CuSparseMatrixCOO{T, Ti}) where {Ti,
    row = CuVector(repeat(row, inner = Bnnz))
    col = (0:nA-1) .* nB
    col = CuVector(repeat(col, inner = Bnnz))
    data = repeat(CUDA.ones(T, nA), inner = Bnnz)
    Adiag = (TA == Bool) ? CUDA.ones(T, nA) : A.diag
@tam724 (Author) commented:

I would also need an opinion on these lines. This still leads to unexpected behaviour when the diagonal contains false elements, but I did not want to break the previous behaviour of the function.
My idea for a clean solution would be to only allow Diagonal{TD, <:CuVector{TD}}, but this would break code that depends on the previous behaviour.

@maleadt (Member) replied:
Why can't we implement the Bool case as multiplication by diag? Relying on false elements still performing multiplication as if by I seems like something we shouldn't support. The behavior here should simply match Base, so I guess we can remove the Bool/I special casing altogether?
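(A tiny illustration of that point, not from the thread: Bool already multiplies correctly in Julia, so using diag directly would handle both true and false entries:)

julia> true * 2.5f0, false * 2.5f0
(2.5f0, 0.0f0)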

@tam724 (Author) commented Jun 30, 2025:
Currently somebody might expect kron(cu(sprand(2, 2, 1.0)), I(2)) to work. After this change, that would error.

@maleadt (Member) replied:
That works with SparseArrays.jl, so would have to work here too.

@tam724 (Author) replied:
Should it?

julia> typeof(I(2))
Diagonal{Bool, Vector{Bool}}

I would expect the CUDA kron to work with the CUDA Diagonal

julia> typeof(cu(I(2)))
Diagonal{Bool, CuArray{Bool, 1, CUDA.DeviceMemory}}

@maleadt (Member) replied:
I'm not following. I'm saying that because kron(sprand(2, 2, 1.0), I(2)) works on the CPU, with Array, it should work on the GPU with CuArray.

@tam724 (Author) commented Jul 1, 2025:

Sorry for the confusion, but I think we are on the same page.
Before this PR the method allowed kron for the combination of a CuArray (sparse) and a CPU Diagonal (by ignoring the diag field of Diagonal and incorrectly assuming that Diagonal represents a sized identity matrix).

Handling the Diagonal correctly would disallow this combination and thus break the previous behaviour. If we are okay with that, CUDA.jl should implement

LinearAlgebra.kron(A::CuSparseMatrixCOO{<:Number, <:Integer}, B::Diagonal{<:Number, <:CuVector{<:Number}})

That would break user code that "worked" before this PR, like:

julia> A = cu(sprand(2, 2, 1.0))
2×2 CuSparseMatrixCSC{Float32, Int32} with 4 stored entries:
 0.98641366  0.92539436
 0.80849653  0.48799357

julia> kron(A, I(2)) # after this PR this would error with ERROR: Scalar indexing is disallowed.
4×4 CuSparseMatrixCSC{Float32, Int32} with 8 stored entries:
 0.98641366   ⋅          0.92539436   ⋅
  ⋅          0.98641366   ⋅          0.92539436
 0.80849653   ⋅          0.48799357   ⋅
  ⋅          0.80849653   ⋅          0.48799357

julia> kron(A, cu(I(2))) # resolve by moving the Diagonal to the GPU first

I'm hesitant because I don't want to break others' code, and because this was also tested behaviour:

@testset "kronecker product with I opa = $opa" for opa in (identity, transpose, adjoint)
    @test collect(kron(opa(dA), C)) ≈ kron(opa(A), C)
    @test collect(kron(C, opa(dA))) ≈ kron(C, opa(A))
    @test collect(kron(opa(dZA), C)) ≈ kron(opa(ZA), C)
    @test collect(kron(C, opa(dZA))) ≈ kron(C, opa(ZA))
end

@maleadt (Member) replied:
I don't have the time right now to look at this myself, so sorry for the naive questions, but why isn't it possible to support both I inputs as well as Diagonal ones, with the latter respecting the actual diagonal values, by correctly determining the value to broadcast instead of hard-coding CUDA.ones? Or, worst case, by using two different methods?
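One possible reading of that suggestion (a hypothetical sketch; diag_values and its signature are illustrative, not an API from this PR): dispatch on the diagonal's storage to pick the values to broadcast, keeping a CPU Diagonal{Bool} fast path for I(n) inputs while respecting GPU-resident diagonals.

using CUDA, LinearAlgebra

# CPU Diagonal{Bool}, i.e. I(n): keep the legacy identity interpretation.
diag_values(::Diagonal{Bool, Vector{Bool}}, Tv, n) = CUDA.ones(Tv, n)

# GPU-resident Diagonal: respect the stored values, promoted to the output eltype.
diag_values(D::Diagonal{T, <:CuVector{T}}, Tv, n) where {T} = Tv.(D.diag)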
