Skip to content

Commit ab37c59

Browse files
authored
[BLAS] Add the gemmt routine (#51701)
We can finally add this routine, which is supported by the latest version of OpenBLAS (3.24) and by Intel MKL.
1 parent 4c9c37f commit ab37c59

File tree

3 files changed

+122
-1
lines changed

3 files changed

+122
-1
lines changed

docs/src/index.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,9 @@ and define matrix-matrix operations.
730730
[Dongarra-1990]: https://dl.acm.org/doi/10.1145/77626.79170
731731

732732
```@docs
733+
LinearAlgebra.BLAS.gemmt!
734+
LinearAlgebra.BLAS.gemmt(::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
735+
LinearAlgebra.BLAS.gemmt(::Any, ::Any, ::Any, ::Any, ::Any)
733736
LinearAlgebra.BLAS.gemm!
734737
LinearAlgebra.BLAS.gemm(::Any, ::Any, ::Any, ::Any, ::Any)
735738
LinearAlgebra.BLAS.gemm(::Any, ::Any, ::Any, ::Any)

src/blas.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ export
6363
# xSYR2
6464
# xSPR2
6565
# Level 3
66+
gemmt!,
67+
gemmt,
6668
gemm!,
6769
gemm,
6870
symm!,
@@ -1481,6 +1483,88 @@ end
14811483
# Level 3
14821484
## (GE) general matrix-matrix multiplication
14831485

1486+
"""
1487+
gemmt!(uplo, tA, tB, alpha, A, B, beta, C)
1488+
1489+
Update the lower or upper triangular part specified by [`uplo`](@ref stdlib-blas-uplo) of `C` as
1490+
`alpha*A*B + beta*C` or the other variants according to [`tA`](@ref stdlib-blas-trans) and `tB`.
1491+
Return the updated `C`.
1492+
1493+
!!! compat "Julia 1.11"
1494+
`gemmt!` requires at least Julia 1.11.
1495+
"""
1496+
function gemmt! end
1497+
1498+
for (gemmt, elty) in
1499+
((:dgemmt_,:Float64),
1500+
(:sgemmt_,:Float32),
1501+
(:zgemmt_,:ComplexF64),
1502+
(:cgemmt_,:ComplexF32))
1503+
@eval begin
1504+
# SUBROUTINE DGEMMT(UPLO,TRANSA,TRANSB,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
1505+
# * .. Scalar Arguments ..
1506+
# DOUBLE PRECISION ALPHA,BETA
1507+
# INTEGER K,LDA,LDB,LDC,N
1508+
# CHARACTER UPLO,TRANSA,TRANSB
1509+
# * .. Array Arguments ..
1510+
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
1511+
function gemmt!(uplo::AbstractChar, transA::AbstractChar, transB::AbstractChar,
1512+
alpha::Union{($elty), Bool},
1513+
A::AbstractVecOrMat{$elty}, B::AbstractVecOrMat{$elty},
1514+
beta::Union{($elty), Bool},
1515+
C::AbstractVecOrMat{$elty})
1516+
chkuplo(uplo)
1517+
require_one_based_indexing(A, B, C)
1518+
m = size(A, transA == 'N' ? 1 : 2)
1519+
ka = size(A, transA == 'N' ? 2 : 1)
1520+
kb = size(B, transB == 'N' ? 1 : 2)
1521+
n = size(B, transB == 'N' ? 2 : 1)
1522+
if ka != kb || m != n || m != size(C,1) || n != size(C,2)
1523+
throw(DimensionMismatch(lazy"A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
1524+
end
1525+
chkstride1(A)
1526+
chkstride1(B)
1527+
chkstride1(C)
1528+
ccall((@blasfunc($gemmt), libblastrampoline), Cvoid,
1529+
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
1530+
Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
1531+
Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
1532+
Ref{BlasInt}, Clong, Clong, Clong),
1533+
uplo, transA, transB, n,
1534+
ka, alpha, A, max(1,stride(A,2)),
1535+
B, max(1,stride(B,2)), beta, C,
1536+
max(1,stride(C,2)), 1, 1, 1)
1537+
C
1538+
end
1539+
function gemmt(uplo::AbstractChar, transA::AbstractChar, transB::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty})
1540+
gemmt!(uplo, transA, transB, alpha, A, B, zero($elty), similar(B, $elty, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1))))
1541+
end
1542+
function gemmt(uplo::AbstractChar, transA::AbstractChar, transB::AbstractChar, A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty})
1543+
gemmt(uplo, transA, transB, one($elty), A, B)
1544+
end
1545+
end
1546+
end
1547+
1548+
"""
1549+
gemmt(uplo, tA, tB, alpha, A, B)
1550+
1551+
Return the lower or upper triangular part specified by [`uplo`](@ref stdlib-blas-uplo) of `alpha*A*B` or the other three variants according to [`tA`](@ref stdlib-blas-trans) and `tB`.
1552+
1553+
!!! compat "Julia 1.11"
1554+
`gemmt` requires at least Julia 1.11.
1555+
"""
1556+
gemmt(uplo, tA, tB, alpha, A, B)
1557+
1558+
"""
1559+
gemmt(uplo, tA, tB, A, B)
1560+
1561+
Return the lower or upper triangular part specified by [`uplo`](@ref stdlib-blas-uplo) of `A*B` or the other three variants according to [`tA`](@ref stdlib-blas-trans) and `tB`.
1562+
1563+
!!! compat "Julia 1.11"
1564+
`gemmt` requires at least Julia 1.11.
1565+
"""
1566+
gemmt(uplo, tA, tB, A, B)
1567+
14841568
"""
14851569
gemm!(tA, tB, alpha, A, B, beta, C)
14861570

test/blas.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,40 @@ Random.seed!(100)
447447
end
448448
end
449449
end
450+
@testset "gemmt" begin
451+
for (wrapper, uplo) in ((LowerTriangular, 'L'), (UpperTriangular, 'U'))
452+
@test wrapper(BLAS.gemmt(uplo, 'N', 'N', I4, I4)) ≈ wrapper(I4)
453+
@test wrapper(BLAS.gemmt(uplo, 'N', 'T', I4, I4)) ≈ wrapper(I4)
454+
@test wrapper(BLAS.gemmt(uplo, 'T', 'N', I4, I4)) ≈ wrapper(I4)
455+
@test wrapper(BLAS.gemmt(uplo, 'T', 'T', I4, I4)) ≈ wrapper(I4)
456+
@test wrapper(BLAS.gemmt(uplo, 'N', 'N', el2, I4, I4)) ≈ wrapper(el2 * I4)
457+
@test wrapper(BLAS.gemmt(uplo, 'N', 'T', el2, I4, I4)) ≈ wrapper(el2 * I4)
458+
@test wrapper(BLAS.gemmt(uplo, 'T', 'N', el2, I4, I4)) ≈ wrapper(el2 * I4)
459+
@test wrapper(BLAS.gemmt(uplo, 'T', 'T', el2, I4, I4)) ≈ wrapper(el2 * I4)
460+
I4cp = copy(I4)
461+
@test wrapper(BLAS.gemmt!(uplo, 'N', 'N', one(elty), I4, I4, elm1, I4cp)) ≈ wrapper(Z4)
462+
@test I4cp ≈ Z4
463+
I4cp[:] = I4
464+
@test wrapper(BLAS.gemmt!(uplo, 'N', 'T', one(elty), I4, I4, elm1, I4cp)) ≈ wrapper(Z4)
465+
@test I4cp ≈ Z4
466+
I4cp[:] = I4
467+
@test wrapper(BLAS.gemmt!(uplo, 'T', 'N', one(elty), I4, I4, elm1, I4cp)) ≈ wrapper(Z4)
468+
@test I4cp ≈ Z4
469+
I4cp[:] = I4
470+
@test wrapper(BLAS.gemmt!(uplo, 'T', 'T', one(elty), I4, I4, elm1, I4cp)) ≈ wrapper(Z4)
471+
@test I4cp ≈ Z4
472+
M1 = uplo == 'U' ? U4 : I4
473+
@test wrapper(BLAS.gemmt(uplo, 'N', 'N', I4, U4)) ≈ wrapper(M1)
474+
M2 = uplo == 'U' ? I4 : U4'
475+
@test wrapper(BLAS.gemmt(uplo, 'N', 'T', I4, U4)) ≈ wrapper(M2)
476+
@test_throws DimensionMismatch BLAS.gemmt!(uplo, 'N', 'N', one(elty), I43, I4, elm1, I43)
477+
@test_throws DimensionMismatch BLAS.gemmt!(uplo, 'N', 'N', one(elty), I4, I4, elm1, Matrix{elty}(I, 5, 5))
478+
@test_throws DimensionMismatch BLAS.gemmt!(uplo, 'N', 'N', one(elty), I43, I4, elm1, I4)
479+
@test_throws DimensionMismatch BLAS.gemmt!(uplo, 'T', 'N', one(elty), I4, I43, elm1, I43)
480+
@test_throws DimensionMismatch BLAS.gemmt!(uplo, 'N', 'T', one(elty), I43, I43, elm1, I43)
481+
@test_throws DimensionMismatch BLAS.gemmt!(uplo, 'T', 'T', one(elty), I43, I43, elm1, Matrix{elty}(I, 3, 4))
482+
end
483+
end
450484
@testset "gemm" begin
451485
@test all(BLAS.gemm('N', 'N', I4, I4) .== I4)
452486
@test all(BLAS.gemm('N', 'T', I4, I4) .== I4)
@@ -455,7 +489,7 @@ Random.seed!(100)
455489
@test all(BLAS.gemm('N', 'N', el2, I4, I4) .== el2 * I4)
456490
@test all(BLAS.gemm('N', 'T', el2, I4, I4) .== el2 * I4)
457491
@test all(BLAS.gemm('T', 'N', el2, I4, I4) .== el2 * I4)
458-
@test all(LinearAlgebra.BLAS.gemm('T', 'T', el2, I4, I4) .== el2 * I4)
492+
@test all(BLAS.gemm('T', 'T', el2, I4, I4) .== el2 * I4)
459493
I4cp = copy(I4)
460494
@test all(BLAS.gemm!('N', 'N', one(elty), I4, I4, elm1, I4cp) .== Z4)
461495
@test all(I4cp .== Z4)

0 commit comments

Comments
 (0)