Skip to content

Commit e389738

Browse files
amontoison and maleadt
authored and committed
[oneMKL] Interface variants of trsm! and trmm!
1 parent fc225b0 commit e389738

File tree

3 files changed

+139
-7
lines changed

3 files changed

+139
-7
lines changed

lib/mkl/linalg.jl

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LinearAlgebra: Transpose, Adjoint,
55
Hermitian, Symmetric,
66
LowerTriangular, UnitLowerTriangular,
77
UpperTriangular, UnitUpperTriangular,
8-
MulAddMul, wrap
8+
UpperOrLowerTriangular, MulAddMul, wrap
99

1010
#
1111
# BLAS 1
@@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
163163
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
164164
end
165165

166+
# Either a plain strided oneAPI matrix with element type `T`, or an `AdjOrTrans`
# wrapper (presumably LinearAlgebra's lazy adjoint/transpose) around one.
const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}}
167+
168+
# Triangular * triangular product written into `C`, built on the out-of-place
# `trmm!`.
#
# `A` arrives already unwrapped: `uplocA` / `isunitcA` are its triangle and
# unit-diagonal flags as characters and `tfunA` is the operation applied to it
# (`identity`, `transpose`, anything else is passed to `trmm!` as 'C').
# `triB` is still wrapped, so its flags and wrapper operation are read here.
#
# `trmm!` treats only one operand as triangular, so the other operand is first
# reduced to its stored triangle with `triu!` / `tril!`.
# NOTE: that masking happens in place, i.e. it overwrites the parent storage of
# `triB` (first case) or `A` (second case).
#
# Two families of combinations are supported, everything else throws:
#   * `B` is not adjoint/transpose-wrapped and has a non-unit diagonal:
#     `B` is masked and `A` is the triangular factor applied from the left.
#   * `B` is adjoint/transpose-wrapped, `A` is used as-is with a non-unit
#     diagonal and both share the same triangle: the roles are swapped
#     (`A` is masked, `parent(B)` is applied from the right) so the transpose
#     never has to be materialised.
function LinearAlgebra.generic_trimatmul!(
        C::oneStridedMatrix{T}, uplocA, isunitcA,
        tfunA::Function, A::oneStridedMatrix{T},
        triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}},
    ) where {T<:onemklFloat}
    B = parent(triB)
    uplocB = LinearAlgebra.uplo_char(triB)
    isunitcB = LinearAlgebra.isunit_char(triB)
    tfunB = LinearAlgebra.wrapperop(B)

    plainA = tfunA === identity
    plainB = tfunB === identity
    knownA = uplocA == 'L' || uplocA == 'U'
    knownB = uplocB == 'L' || uplocB == 'U'
    sametri = uplocA == uplocB

    if knownA && knownB && plainB && isunitcB == 'N' && (plainA ? !sametri : sametri)
        # lower*upper, upper*lower, lower'*lower or upper'*upper
        transa = plainA ? 'N' : tfunA === transpose ? 'T' : 'C'
        if uplocB == 'U'
            triu!(B)
        else
            tril!(B)
        end
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif knownA && knownB && !plainB && plainA && sametri && isunitcA == 'N'
        # upper*upper' or lower*lower': the operation is reversed to avoid
        # executing the transpose
        transb = tfunB === transpose ? 'T' : 'C'
        if uplocA == 'U'
            triu!(A)
        else
            tril!(A)
        end
        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
    else
        throw("mixed triangular-triangular multiplication") # TODO: rethink
    end
    return C
end
203+
166204
# triangular
167205
# Triangular * dense product written into `C`.
#
# `A` is the parent of the triangular operand; `uploc` / `isunitc` are its
# triangle and unit-diagonal flags as characters and `tfun` is the operation
# applied to it (`identity` -> 'N', `transpose` -> 'T', anything else -> 'C').
# The work is delegated to the out-of-place `trmm!`, which takes `B` as input
# and `C` as output, so no `copyto!(C, B)` is needed beforehand.
#
# Returns `C` explicitly rather than relying on what `trmm!` returns, matching
# the triangular * triangular method of this function.
function LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
    trmm!('L', uploc, trans, isunitc, one(T), A, B, C)
    return C
end
169207
# The three methods below share one convention: `uploc` / `isunitc` describe the
# triangular operand (triangle and unit-diagonal flags as characters), `tfun` is
# the operation applied to it (`identity` -> 'N', `transpose` -> 'T', anything
# else -> 'C'), and the result is written into `C` through the out-of-place
# `trmm!` / `trsm!`. Each returns `C` explicitly rather than relying on the
# return value of the wrapper.

# Dense * triangular product: `B` is the parent of the triangular operand and is
# applied from the right to the dense matrix `A`.
function LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
    trmm!('R', uploc, trans, isunitc, one(T), B, A, C)
    return C
end

# Left triangular solve: `A` is the parent of the triangular operand, `B` the
# right-hand side.
# NOTE(review): `B` is accepted as any `AbstractMatrix{T}` here; whether a
# matching `trsm!` method exists for non-oneStridedMatrix inputs is not visible
# from this file section.
function LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
    trsm!('L', uploc, trans, isunitc, one(T), A, B, C)
    return C
end

# Right triangular solve: `B` is the parent of the triangular operand, `A` the
# left-hand side (same NOTE as above applies to `A::AbstractMatrix{T}`).
function LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
    trsm!('R', uploc, trans, isunitc, one(T), B, A, C)
    return C
end

lib/mkl/wrappers_blas.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,76 @@ function trsm(side::Char,
11391139
trsm!(side, uplo, transa, diag, alpha, A, copy(B))
11401140
end
11411141

1142+
# Out-of-place ("variant") triangular multiply and triangular solve, generated
# for each supported element type.
#
# Going by the accompanying tests, with side == 'L' the kernels compute
#     trmm!:  C = alpha * op(A) * B     + beta * C
#     trsm!:  C = alpha * inv(op(A)) * B + beta * C
# (presumably with the triangular factor on the right for side == 'R').
#
# Arguments: `side`, `uplo`, `transa`, `diag` are the usual BLAS flag
# characters and are forwarded unchanged; `A` is the square triangular factor,
# `B` the dense input, `C` the output, which must have the same size as `B`.
# Throws `DimensionMismatch` on inconsistent sizes. Returns `C`.
for (mmname_variant, smname_variant, elty) in
        ((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
         (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
         (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
         (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32))
    @eval begin
        function trmm!(side::Char,
                       uplo::Char,
                       transa::Char,
                       diag::Char,
                       alpha::Number,
                       beta::Number,
                       A::oneStridedMatrix{$elty},
                       B::oneStridedMatrix{$elty},
                       C::oneStridedMatrix{$elty})
            m, n = size(B)
            mA, nA = size(A)
            if mA != nA throw(DimensionMismatch("A must be square")) end
            if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
            # the kernel writes an m-by-n result into C
            if size(C) != (m, n) throw(DimensionMismatch("C must have the same size as B")) end
            lda = max(1,stride(A,2))
            ldb = max(1,stride(B,2))
            ldc = max(1,stride(C,2))
            queue = global_queue(context(A), device())
            $mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
            # C is the mutated argument, so it is what gets returned
            return C
        end

        function trsm!(side::Char,
                       uplo::Char,
                       transa::Char,
                       diag::Char,
                       alpha::Number,
                       beta::Number,
                       A::oneStridedMatrix{$elty},
                       B::oneStridedMatrix{$elty},
                       C::oneStridedMatrix{$elty})
            m, n = size(B)
            mA, nA = size(A)
            if mA != nA throw(DimensionMismatch("A must be square")) end
            if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
            # the kernel writes an m-by-n result into C
            if size(C) != (m, n) throw(DimensionMismatch("C must have the same size as B")) end
            lda = max(1,stride(A,2))
            ldb = max(1,stride(B,2))
            ldc = max(1,stride(C,2))
            queue = global_queue(context(A), device())
            $smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
            # C is the mutated argument, so it is what gets returned
            return C
        end
    end
end
1191+
# Out-of-place triangular multiply without an explicit `beta`.
#
# Forwards to the nine-argument `trmm!` with `beta = zero(T)`, so `C` is
# overwritten with the product instead of being accumulated into; callers
# typically hand in a freshly allocated (uninitialised) `C`.
# This method sits outside the `@eval` loop, so the element type is a regular
# type parameter `T` rather than an interpolated `$elty`.
# Returns `C`.
function trmm!(side::Char,
               uplo::Char,
               transa::Char,
               diag::Char,
               alpha::Number,
               A::oneStridedMatrix{T},
               B::oneStridedMatrix{T},
               C::oneStridedMatrix{T}) where T
    trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
    return C
end
1201+
# Out-of-place triangular solve without an explicit `beta`.
#
# Forwards to the nine-argument `trsm!` with `beta = zero(T)`, so `C` is
# overwritten with the solution instead of being accumulated into; callers
# typically hand in a freshly allocated (uninitialised) `C`.
# This method sits outside the `@eval` loop, so the element type is a regular
# type parameter `T` rather than an interpolated `$elty`.
# Returns `C`.
function trsm!(side::Char,
               uplo::Char,
               transa::Char,
               diag::Char,
               alpha::Number,
               A::oneStridedMatrix{T},
               B::oneStridedMatrix{T},
               C::oneStridedMatrix{T}) where T
    trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
    return C
end
1211+
11421212
## hemm
11431213
for (fname, elty) in ((:onemklZhemm,:ComplexF64),
11441214
(:onemklChemm,:ComplexF32))

test/onemkl.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,14 @@ end
661661
# move to host and compare
662662
h_C = Array(dB)
663663
@test C ≈ h_C

# out-of-place variant: dC = alpha*A*B + beta*dC
# dB was read back above as the result of the in-place call, so upload the
# original B again before using it as an input
dB = oneArray(B)
C = rand(T,m,n)
dC = oneArray(C)
beta = rand(T)
oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*A*B + beta*C
@test D ≈ h_C
664672
end
665673

666674
@testset "trmm" begin
@@ -684,6 +692,14 @@ end
684692
dC = copy(dB)
685693
oneMKL.trsm!('L','U','N','N',alpha,dA,dC)
686694
@test C ≈ Array(dC)

# out-of-place variant: dC = alpha*(A\B) + beta*dC
C = rand(T,m,n)
dC = oneArray(C)
beta = rand(T)
oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*(A\B) + beta*C
@test D ≈ h_C
687703
end
688704

689705
@testset "left trsm" begin
@@ -725,6 +741,14 @@ end
725741
dC = copy(dA)
726742
oneMKL.trsm!('R','U','N','N',alpha,dB,dC)
727743
@test C ≈ Array(dC)

# out-of-place variant of the right solve: dC = alpha*(A/B) + beta*dC
# as in the in-place call above, dB is the triangular factor and dA the
# left-hand side, so the result has the shape of A
C = rand(T, size(A)...)
dC = oneArray(C)
beta = rand(T)
oneMKL.trsm!('R','U','N','N',alpha,beta,dB,dA,dC)
h_C = Array(dC)
D = alpha*(A/B) + beta*C
@test D ≈ h_C
728752
end
729753

730754
@testset "right trsm" begin

0 commit comments

Comments
 (0)