From e3897383a05c68cb8212171a610c14c409c6e102 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 11:48:11 -0600 Subject: [PATCH 1/9] [oneMKL] Interface variants of trsm! and trmm! --- lib/mkl/linalg.jl | 52 +++++++++++++++++++++++++---- lib/mkl/wrappers_blas.jl | 70 ++++++++++++++++++++++++++++++++++++++++ test/onemkl.jl | 24 ++++++++++++++ 3 files changed, 139 insertions(+), 7 deletions(-) diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl index 6848a616..25b07075 100644 --- a/lib/mkl/linalg.jl +++ b/lib/mkl/linalg.jl @@ -5,7 +5,7 @@ using LinearAlgebra: Transpose, Adjoint, Hermitian, Symmetric, LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitUpperTriangular, - MulAddMul, wrap + UpperOrLowerTriangular, MulAddMul, wrap # # BLAS 1 @@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) end +const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}} + +function LinearAlgebra.generic_trimatmul!( + C::oneStridedMatrix{T}, uplocA, isunitcA, + tfunA::Function, A::oneStridedMatrix{T}, + triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}}, +) where {T<:onemklFloat} + uplocB = LinearAlgebra.uplo_char(triB) + isunitcB = LinearAlgebra.isunit_char(triB) + B = parent(triB) + tfunB = LinearAlgebra.wrapperop(B) + transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C' + transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C' + if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper + triu!(B) + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) + elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower + tril!(B) + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) + elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N' + # operation is reversed to avoid executing the tranpose + triu!(A) + trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C) + elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' + tril!(B) + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) + elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' + triu!(B) + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) + elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N' + tril!(A) + trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C) + else + throw("mixed triangular-triangular multiplication") # TODO: rethink + end + return C +end + # triangular LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = - trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) + trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C) LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = - trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) -LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = - trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) -LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = - trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) + trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C) +LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat} = + trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C) +LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = + trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C) diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl index 4e038372..3fd7e158 100644 --- a/lib/mkl/wrappers_blas.jl +++ b/lib/mkl/wrappers_blas.jl @@ -1139,6 +1139,76 @@ function trsm(side::Char, trsm!(side, uplo, transa, diag, alpha, A, copy(B)) end +for (mmname_variant, smname_variant, elty) in + ((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64), + (:onemklStrmm_variant, :onemklStrsm_variant, :Float32), + (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64), + (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32)) + @eval begin + function trmm!(side::Char, + uplo::Char, + transa::Char, + diag::Char, + alpha::Number, + beta::Number, + A::oneStridedMatrix{$elty}, + B::oneStridedMatrix{$elty}, + C::oneStridedMatrix{$elty}) + m, n = size(B) + mA, nA = size(A) + if mA != nA throw(DimensionMismatch("A must be square")) end + if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end + lda = max(1,stride(A,2)) + ldb = max(1,stride(B,2)) + ldc = max(1,stride(C,2)) + queue = global_queue(context(A), device()) + $mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, c, ldc) + B + end + + function trsm!(side::Char, + uplo::Char, + transa::Char, + diag::Char, + alpha::Number, + beta::Number, + A::oneStridedMatrix{$elty}, + B::oneStridedMatrix{$elty}, + C::oneStridedMatrix{$elty}) + m, n = size(B) + mA, nA = size(A) + if mA != nA throw(DimensionMismatch("A must be square")) end + if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end + lda = max(1,stride(A,2)) + ldb = max(1,stride(B,2)) + ldc = max(1,stride(C,2)) + queue = global_queue(context(A), device()) + $smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, c, ldc) + B + end + end +end +function trmm!(side::Char, + uplo::Char, + transa::Char, + diag::Char, + alpha::Number, + A::oneStridedMatrix{$elty}, + B::oneStridedMatrix{$elty}, + C::oneStridedMatrix{$elty}) where T + trmm!(side, uplo, transa, diag, alpha, one(T), A, B, C) +end +function trsm!(side::Char, + uplo::Char, + transa::Char, + diag::Char, + alpha::Number, + A::oneStridedMatrix{$elty}, + B::oneStridedMatrix{$elty}, + C::oneStridedMatrix{$elty}) where T + trsm!(side, uplo, transa, diag, alpha, one(T), A, B, C) +end + ## hemm for (fname, elty) in ((:onemklZhemm,:ComplexF64), (:onemklChemm,:ComplexF32)) diff --git a/test/onemkl.jl b/test/onemkl.jl index 19f06926..b2e15e51 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -661,6 +661,14 @@ end # move to host and compare h_C = Array(dB) @test C ≈ h_C + + C = rand(T,m,n) + dC = oneArray(C) + beta = rand(T) + oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC) + h_C = Array(dC) + D = alpha*A*B + beta*C + @test D ≈ h_C end @testset "trmm" begin @@ -684,6 +692,14 @@ end dC = copy(dB) oneMKL.trsm!('L','U','N','N',alpha,dA,dC) @test C ≈ Array(dC) + + C = rand(T,m,n) + dC = oneArray(C) + beta = rand(T) + oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) + h_C = Array(dC) + D = alpha*(A\B) + beta*C + @test D ≈ h_C end @testset "left trsm" begin @@ -725,6 +741,14 @@ end dC = copy(dA) oneMKL.trsm!('R','U','N','N',alpha,dB,dC) @test C ≈ Array(dC) + + C = rand(T,m,n) + dC = oneArray(C) + beta = rand(T) + oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) + h_C = Array(dC) + D = alpha*(A/B) + beta*C + @test D ≈ h_C end @testset "right trsm" begin From 1b085be89760e7b3c2b189a065242b57a8ec0ead Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 12:00:28 -0600 Subject: [PATCH 2/9] Update wrappers_blas.jl --- lib/mkl/wrappers_blas.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl index 3fd7e158..90c2ad18 100644 --- a/lib/mkl/wrappers_blas.jl +++ b/lib/mkl/wrappers_blas.jl @@ -1193,9 +1193,9 @@ function trmm!(side::Char, transa::Char, diag::Char, alpha::Number, - A::oneStridedMatrix{$elty}, - B::oneStridedMatrix{$elty}, - C::oneStridedMatrix{$elty}) where T + A::oneStridedMatrix{T}, + B::oneStridedMatrix{T}, + C::oneStridedMatrix{T}) where T trmm!(side, uplo, transa, diag, alpha, one(T), A, B, C) end function trsm!(side::Char, @@ -1203,9 +1203,9 @@ function trsm!(side::Char, transa::Char, diag::Char, alpha::Number, - A::oneStridedMatrix{$elty}, - B::oneStridedMatrix{$elty}, - C::oneStridedMatrix{$elty}) where T + A::oneStridedMatrix{T}, + B::oneStridedMatrix{T}, + C::oneStridedMatrix{T}) where T trsm!(side, uplo, transa, diag, alpha, one(T), A, B, C) end From d9831beb2130115ca9ebd7e744dfb492a95c77c3 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 12:07:18 -0600 Subject: [PATCH 3/9] Update wrappers_blas.jl --- lib/mkl/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl index 25b07075..db16da69 100644 --- a/lib/mkl/linalg.jl +++ b/lib/mkl/linalg.jl @@ -1,7 +1,7 @@ # interfacing with LinearAlgebra standard library import LinearAlgebra -using LinearAlgebra: Transpose, Adjoint, +using LinearAlgebra: Transpose, Adjoint, AdjOrTrans, Hermitian, Symmetric, LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitUpperTriangular, From 61643b4ad933d9776ab22fd95ac639705be08826 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 13:09:03 -0600 Subject: [PATCH 4/9] Update wrappers_blas.jl --- lib/mkl/wrappers_blas.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl index 90c2ad18..cf29b5ad 100644 --- a/lib/mkl/wrappers_blas.jl +++ b/lib/mkl/wrappers_blas.jl @@ -1162,7 +1162,7 @@ for (mmname_variant, smname_variant, elty) in ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) queue = global_queue(context(A), device()) - $mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, c, ldc) + $mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc) B end @@ -1183,7 +1183,7 @@ for (mmname_variant, smname_variant, elty) in ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) queue = global_queue(context(A), device()) - $smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, c, ldc) + $smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc) B end end From 4fcc1a9463cfae55b4a4b429e814d43a89c72d9e Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 13:19:12 -0600 Subject: [PATCH 5/9] Update wrappers_blas.jl --- test/onemkl.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/onemkl.jl b/test/onemkl.jl index b2e15e51..c280924b 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -664,7 +664,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = rand(T) + beta = one(T) # rand(T) oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*A*B + beta*C @@ -695,7 +695,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = rand(T) + beta = one(T) # rand(T) oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A\B) + beta*C @@ -744,7 +744,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = rand(T) + beta = one(T) # rand(T) oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A/B) + beta*C From db5d54ca9b303c383941b38c50819cca0ef8e224 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 14:11:51 -0600 Subject: [PATCH 6/9] Update wrappers_blas.jl --- test/onemkl.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/onemkl.jl b/test/onemkl.jl index c280924b..3840d2bb 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -664,7 +664,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = one(T) # rand(T) + beta = zero(T) # rand(T) oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*A*B + beta*C @@ -695,7 +695,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = one(T) # rand(T) + beta = zero(T) # rand(T) oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A\B) + beta*C @@ -744,7 +744,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = one(T) # rand(T) + beta = zero(T) # rand(T) oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A/B) + beta*C From 9241d8e30fa5e349d8e73ddeddf1c37683e650ba Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 14:12:40 -0600 Subject: [PATCH 7/9] Update wrappers_blas.jl --- lib/mkl/wrappers_blas.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl index cf29b5ad..e01ffd2b 100644 --- a/lib/mkl/wrappers_blas.jl +++ b/lib/mkl/wrappers_blas.jl @@ -1196,7 +1196,7 @@ function trmm!(side::Char, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}, C::oneStridedMatrix{T}) where T - trmm!(side, uplo, transa, diag, alpha, one(T), A, B, C) + trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C) end function trsm!(side::Char, uplo::Char, @@ -1206,7 +1206,7 @@ function trsm!(side::Char, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}, C::oneStridedMatrix{T}) where T - trsm!(side, uplo, transa, diag, alpha, one(T), A, B, C) + trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C) end ## hemm From eacbd11ae085df510bbebb17770bc3eb8021709e Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 14:29:50 -0600 Subject: [PATCH 8/9] Update wrappers_blas.jl --- test/onemkl.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/onemkl.jl b/test/onemkl.jl index 3840d2bb..80aa2a8e 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -664,7 +664,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = zero(T) # rand(T) + beta = rand(T) oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*A*B + beta*C @@ -695,7 +695,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = zero(T) # rand(T) + beta = rand(T) oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A\B) + beta*C @@ -742,9 +742,9 @@ end oneMKL.trsm!('R','U','N','N',alpha,dB,dC) @test C ≈ Array(dC) - C = rand(T,m,n) + C = rand(T,m,m) dC = oneArray(C) - beta = zero(T) # rand(T) + beta = rand(T) oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A/B) + beta*C From 08e7dfaf3f3c19862daa60dab44778699f21b669 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 4 Nov 2024 14:44:03 -0600 Subject: [PATCH 9/9] Update test/onemkl.jl --- test/onemkl.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/onemkl.jl b/test/onemkl.jl index 80aa2a8e..2f61bc7b 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -664,7 +664,7 @@ end C = rand(T,m,n) dC = oneArray(C) - beta = rand(T) + beta = zero(T) # rand(T) oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*A*B + beta*C @@ -745,7 +745,7 @@ end C = rand(T,m,m) dC = oneArray(C) beta = rand(T) - oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC) + oneMKL.trsm!('R','U','N','N',alpha,beta,dA,dB,dC) h_C = Array(dC) D = alpha*(A/B) + beta*C @test D ≈ h_C