Skip to content

Commit 5efcee6

Browse files
authored
Add specializations for some triangular-triangular multiplications (#1538)
1 parent 29afec2 commit 5efcee6

File tree

2 files changed

+101
-1
lines changed

2 files changed

+101
-1
lines changed

lib/cublas/linalg.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,68 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'U', 'N'),
443443
end
444444
end
445445

446+
function LinearAlgebra.mul!(X::DenseCuMatrix{T},
447+
A::LowerTriangular{T,<:DenseCuMatrix},
448+
B::UpperTriangular{T,<:DenseCuMatrix},
449+
) where {T<:CublasFloat}
450+
triu!(parent(B))
451+
trmm!('L', 'L', 'N', 'N', one(T), parent(A), parent(B), parent(X))
452+
X
453+
end
454+
455+
function LinearAlgebra.mul!(X::DenseCuMatrix{T},
456+
A::UpperTriangular{T,<:DenseCuMatrix},
457+
B::LowerTriangular{T,<:DenseCuMatrix},
458+
) where {T<:CublasFloat}
459+
tril!(parent(B))
460+
trmm!('L', 'U', 'N', 'N', one(T), parent(A), parent(B), parent(X))
461+
X
462+
end
463+
464+
for (trtype, valtype) in ((:Transpose, :CublasFloat),
465+
(:Adjoint, :CublasReal),
466+
(:Adjoint, :CublasComplex))
467+
@eval begin
468+
function LinearAlgebra.mul!(X::DenseCuMatrix{T},
469+
A::UpperTriangular{T,<:DenseCuMatrix},
470+
B::LowerTriangular{<:Any,<:$trtype{T,<:DenseCuMatrix}},
471+
) where {T<:$valtype}
472+
# operation is reversed to avoid executing the tranpose
473+
triu!(parent(A))
474+
CUBLAS.trmm!('R', 'U', 'T', 'N', one(T), parent(parent(B)), parent(A), parent(X))
475+
X
476+
end
477+
478+
function LinearAlgebra.mul!(X::DenseCuMatrix{T},
479+
A::UpperTriangular{<:Any,<:$trtype{T,<:DenseCuMatrix}},
480+
B::LowerTriangular{T,<:DenseCuMatrix},
481+
) where {T<:$valtype}
482+
tril!(parent(B))
483+
CUBLAS.trmm!('L', 'L', 'T', 'N', one(T), parent(parent(A)), parent(B), parent(X))
484+
X
485+
end
486+
487+
function LinearAlgebra.mul!(X::DenseCuMatrix{T},
488+
A::LowerTriangular{<:Any,<:$trtype{T,<:DenseCuMatrix}},
489+
B::UpperTriangular{T,<:DenseCuMatrix},
490+
) where {T<:$valtype}
491+
triu!(parent(B))
492+
CUBLAS.trmm!('L', 'U', 'T', 'N', one(T), parent(parent(A)), parent(B), parent(X))
493+
X
494+
end
495+
496+
function LinearAlgebra.mul!(X::DenseCuMatrix{T},
497+
A::LowerTriangular{T,<:DenseCuMatrix},
498+
B::UpperTriangular{<:Any,<:$trtype{T,<:DenseCuMatrix}},
499+
) where {T<:$valtype}
500+
# operation is reversed to avoid executing the tranpose
501+
tril!(parent(A))
502+
CUBLAS.trmm!('R', 'L', 'T', 'N', one(T), parent(parent(B)), parent(A), parent(X))
503+
X
504+
end
505+
end
506+
end
507+
446508
# symmetric mul!
447509
# level 2
448510
@inline function LinearAlgebra.mul!(y::CuVector{T}, A::Hermitian{T,<:CuMatrix}, x::CuVector{T},

test/cublas.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ end
10191019
end
10201020
end
10211021

1022-
@testset "triangular mul!" begin
1022+
@testset "triangular-dense mul!" begin
10231023
A = triu(rand(elty, m, m))
10241024
B = rand(elty,m,n)
10251025
C = zeros(elty,m,n)
@@ -1052,6 +1052,44 @@ end
10521052
end
10531053
end
10541054

1055+
@testset "triangular-triangular mul!" begin
1056+
A = triu(rand(elty, m, m))
1057+
B = triu(rand(elty, m, m))
1058+
C0 = zeros(elty,m,m)
1059+
1060+
sA = rand(elty,m,m)
1061+
sA = sA + transpose(sA)
1062+
sB = rand(elty,m,m)
1063+
sB = sB + transpose(sB)
1064+
1065+
for (TRa, ta, TRb, tb, TRc) in (
1066+
(UpperTriangular, identity, LowerTriangular, identity, Matrix),
1067+
(LowerTriangular, identity, UpperTriangular, identity, Matrix),
1068+
(UpperTriangular, identity, UpperTriangular, transpose, Matrix),
1069+
(UpperTriangular, transpose, UpperTriangular, identity, Matrix),
1070+
(LowerTriangular, identity, LowerTriangular, transpose, Matrix),
1071+
(LowerTriangular, transpose, LowerTriangular, identity, Matrix),
1072+
)
1073+
1074+
A = copy(sA) |> TRa
1075+
B = copy(sB) |> TRb
1076+
C = copy(C0) |> TRc
1077+
dA = CuArray(parent(sA)) |> TRa
1078+
dB = CuArray(parent(sB)) |> TRb
1079+
dC = if TRc == Matrix
1080+
CuArray(C0) |> DenseCuMatrix
1081+
else
1082+
CuArray(C0) |> TRc
1083+
end
1084+
1085+
D = mul!(C, ta(A), tb(B))
1086+
dD = mul!(dC, ta(dA), tb(dB))
1087+
1088+
@test C Array(dC)
1089+
@test D Array(dD)
1090+
end
1091+
end
1092+
10551093
B = rand(elty,m,n)
10561094
C = rand(elty,m,n)
10571095
d_B = CuArray(B)

0 commit comments

Comments
 (0)