Skip to content

Commit af84b51

Browse files
authored
Update BLAS & add more BLAS tests (#602)
1 parent 44e5818 commit af84b51

File tree

4 files changed

+228
-176
lines changed

4 files changed

+228
-176
lines changed

src/blas/highlevel.jl

Lines changed: 162 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,14 @@ end
192192
# BLAS 3
193193
#
194194

195-
########
196-
# GEMM
197-
########
198-
function gemm_wrapper!(
199-
C::ROCVecOrMat{T}, tA::Char, tB::Char,
200-
A::ROCVecOrMat{T}, B::ROCVecOrMat{T},
201-
alpha = one(T), beta = zero(T),
202-
) where T <: ROCBLASFloatWithHalf
203-
mA, nA = rocblas_size(tA, A)
204-
mB, nB = rocblas_size(tB, B)
195+
function LinearAlgebra.generic_matmatmul!(
196+
C::StridedROCVecOrMat, tA, tB, A::StridedROCVecOrMat,
197+
B::StridedROCVecOrMat, _add::MulAddMul,
198+
)
199+
T = eltype(C)
200+
alpha, beta = _add.alpha, _add.beta
201+
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
202+
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)
205203

206204
nA != mB && throw(DimensionMismatch(
207205
"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
@@ -213,129 +211,124 @@ function gemm_wrapper!(
213211
"C has dimensions $(size(C)), should have ($mA,$nB)"))
214212
return LinearAlgebra.rmul!(C, 0)
215213
end
216-
gemm!(tA, tB, alpha, A, B, beta, C)
214+
215+
if alpha isa Union{Bool, T} && beta isa Union{Bool, T}
216+
α, β = T(alpha), T(beta)
217+
218+
if (
219+
all(in(('N', 'T', 'C')), (tA, tB)) && T <: ROCBLASFloat &&
220+
A isa StridedROCArray{T} && B isa StridedROCArray{T}
221+
)
222+
return gemm!(tA, tB, α, A, B, β, C)
223+
elseif (tA == 'S' || tA == 's') && tB == 'N'
224+
return symm!('L', tA == 'S' ? 'U' : 'L', α, A, B, β, C)
225+
elseif (tB == 'S' || tB == 's') && tA == 'N'
226+
return symm!('R', tB == 'S' ? 'U' : 'L', α, B, A, β, C)
227+
elseif (tA == 'H' || tA == 'h') && tB == 'N'
228+
return hemm!('L', tA == 'H' ? 'U' : 'L', α, A, B, β, C)
229+
elseif (tB == 'H' || tB == 'h') && tA == 'N'
230+
return hemm!('R', tB == 'H' ? 'U' : 'L', α, B, A, β, C)
231+
end
232+
end
233+
234+
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
217235
end
218236

219-
# Mutating
220-
LinearAlgebra.mul!(C::ROCMatrix{T}, A::ROCVecOrMat{T}, B::ROCVecOrMat{T}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
221-
gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta)
222-
LinearAlgebra.mul!(C::ROCMatrix{T}, trA::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, B::ROCMatrix{T}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
223-
gemm_wrapper!(C, 'T', 'N', parent(trA), B, alpha, beta)
224-
LinearAlgebra.mul!(C::ROCMatrix{T}, A::ROCMatrix{T}, trB::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
225-
gemm_wrapper!(C, 'N', 'T', A, parent(trB), alpha, beta)
226-
LinearAlgebra.mul!(C::ROCMatrix{T}, trA::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, trB::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
227-
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(trB), alpha, beta)
228-
LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, ROCMatrix{T}}, B::ROCMatrix{T}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASReal =
229-
gemm_wrapper!(C, 'T', 'N', parent(adjA), B, alpha, beta)
230-
LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, <:ROCMatrix{T}}, B::ROCMatrix{T}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
231-
gemm_wrapper!(C, 'C', 'N', parent(adjA), B, alpha, beta)
232-
LinearAlgebra.mul!(C::ROCMatrix{T}, A::ROCMatrix{T}, adjB::LinearAlgebra.Adjoint{<:Any, ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASReal =
233-
gemm_wrapper!(C, 'N', 'T', A, parent(adjB), alpha, beta)
234-
LinearAlgebra.mul!(C::ROCMatrix{T}, A::ROCMatrix{T}, adjB::LinearAlgebra.Adjoint{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
235-
gemm_wrapper!(C, 'N', 'C', A, parent(adjB), alpha, beta)
236-
LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, ROCMatrix{T}}, adjB::LinearAlgebra.Adjoint{<:Any, ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASReal =
237-
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(adjB), alpha, beta)
238-
LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, <:ROCMatrix{T}}, adjB::LinearAlgebra.Adjoint{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
239-
gemm_wrapper!(C, 'C', 'C', parent(adjA), parent(adjB), alpha, beta)
240-
LinearAlgebra.mul!(C::ROCMatrix{T}, trA::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, adjB::LinearAlgebra.Adjoint{T, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASReal =
241-
gemm_wrapper!(C, 'T', 'T', parent(trA), parent(adjB), alpha, beta)
242-
LinearAlgebra.mul!(C::ROCMatrix{T}, trA::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, adjB::LinearAlgebra.Adjoint{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASFloatWithHalf =
243-
gemm_wrapper!(C, 'T', 'C', parent(trA), parent(adjB), alpha, beta)
244-
LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{T, <:ROCMatrix{T}}, trB::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T<:ROCBLASReal =
245-
gemm_wrapper!(C, 'T', 'T', parent(adjA), parent(trB), alpha, beta)
246-
LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, <:ROCMatrix{T}}, trB::LinearAlgebra.Transpose{<:Any, <:ROCMatrix{T}}, alpha::T = one(T), beta::T = zero(T)) where T <: ROCBLASFloatWithHalf =
247-
gemm_wrapper!(C, 'C', 'T', parent(adjA), parent(trB), alpha, beta)
248-
249-
250-
########
251-
# TRSM
252-
########
253-
254-
# TODO requires ROCm 5.6 for out-of-place trmm:
255-
# https://rocm.docs.amd.com/projects/rocBLAS/en/latest/API_Reference_Guide.html#announced-in-rocblas-3-0
256-
#
257-
# if VERSION ≥ v"1.10-"
258-
# LinearAlgebra.generic_trimatmul!(
259-
# C::ROCMatrix{T}, uploc, isunitc, tfun::Function,
260-
# A::ROCMatrix{T}, B::ROCMatrix{T},
261-
# ) where T <: ROCBLASFloat = trmm!(
262-
# 'L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
263-
# isunitc, one(T), A, B, C)
264-
265-
# LinearAlgebra.generic_mattrimul!(
266-
# C::ROCMatrix{T}, uploc, isunitc, tfun::Function,
267-
# A::ROCMatrix{T}, B::ROCMatrix{T},
268-
# ) where T <: ROCBLASFloat = trmm!(
269-
# 'R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
270-
# isunitc, one(T), B, A, C)
271-
272-
# # tri-tri-mul!
273-
# const AdjOrTransOrROCMatrix{T} = Union{ROCMatrix{T}, AdjOrTrans{<:T,<:ROCMatrix}}
274-
275-
# function LinearAlgebra.generic_trimatmul!(
276-
# C::ROCMatrix{T}, uplocA, isunitcA, tfunA::Function,
277-
# A::ROCMatrix{T}, triB::UpperOrLowerTriangular{T,<:AdjOrTransOrROCMatrix{T}},
278-
# ) where T <: ROCBLASFloat
279-
# uplocB = LinearAlgebra.uplo_char(triB)
280-
# isunitcB = LinearAlgebra.isunit_char(triB)
281-
# B = parent(triB)
282-
# tfunB = LinearAlgebra.wrapperop(B)
283-
# transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
284-
# transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C'
285-
# if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper
286-
# triu!(B)
287-
# trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
288-
# elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower
289-
# tril!(B)
290-
# trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
291-
# elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N'
292-
# # operation is reversed to avoid executing the transpose
293-
# triu!(A)
294-
# trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
295-
# elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N'
296-
# tril!(B)
297-
# trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
298-
# elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N'
299-
# triu!(B)
300-
# trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
301-
# elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N'
302-
# tril!(A)
303-
# trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
304-
# else
305-
# throw("mixed triangular-triangular multiplication") # TODO: rethink
306-
# end
307-
# return C
308-
# end
309-
310-
# LinearAlgebra.generic_trimatdiv!(
311-
# C::ROCMatrix{T}, uploc, isunitc, tfun::Function,
312-
# A::ROCMatrix{T}, B::AbstractMatrix{T},
313-
# ) where T <: ROCBLASFloat = trsm!(
314-
# 'L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
315-
# isunitc, one(T), A, C === B ? C : copyto!(C, B))
316-
317-
# LinearAlgebra.generic_mattridiv!(
318-
# C::ROCMatrix{T}, uploc, isunitc, tfun::Function,
319-
# A::AbstractMatrix{T}, B::ROCMatrix{T},
320-
# ) where T <: ROCBLASFloat = trsm!(
321-
# 'R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
322-
# isunitc, one(T), B, C === A ? C : copyto!(C, A))
323-
# else
237+
if VERSION ≥ v"1.10-"
238+
LinearAlgebra.generic_trimatmul!(
239+
C::StridedROCMatrix{T}, uploc, isunitc, tfun::Function,
240+
A::StridedROCMatrix{T}, B::StridedROCMatrix{T},
241+
) where T <: ROCBLASFloat = trmm!(
242+
'L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
243+
isunitc, one(T), A, B, C)
244+
245+
LinearAlgebra.generic_mattrimul!(
246+
C::StridedROCMatrix{T}, uploc, isunitc, tfun::Function,
247+
A::StridedROCMatrix{T}, B::StridedROCMatrix{T},
248+
) where T <: ROCBLASFloat = trmm!(
249+
'R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
250+
isunitc, one(T), B, A, C)
251+
252+
const AdjOrTransOrROCMatrix{T} = Union{
253+
StridedROCMatrix{T}, AdjOrTrans{<: T, <: StridedROCMatrix}}
254+
255+
function LinearAlgebra.generic_trimatmul!(
256+
C::StridedROCMatrix{T}, uplocA, isunitcA,
257+
tfunA::Function, A::StridedROCMatrix{T},
258+
triB::UpperOrLowerTriangular{T, <: AdjOrTransOrROCMatrix{T}},
259+
) where T <: ROCBLASFloat
260+
uplocB = LinearAlgebra.uplo_char(triB)
261+
isunitcB = LinearAlgebra.isunit_char(triB)
262+
B = parent(triB)
263+
tfunB = LinearAlgebra.wrapperop(B)
264+
transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
265+
transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C'
266+
if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper
267+
triu!(B)
268+
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
269+
elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower
270+
tril!(B)
271+
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
272+
elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N'
273+
# operation is reversed to avoid executing the transpose
274+
triu!(A)
275+
trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
276+
elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N'
277+
tril!(B)
278+
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
279+
elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N'
280+
triu!(B)
281+
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
282+
elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N'
283+
tril!(A)
284+
trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
285+
else
286+
throw("mixed triangular-triangular multiplication") # TODO: rethink
287+
end
288+
return C
289+
end
290+
291+
LinearAlgebra.generic_trimatdiv!(
292+
C::StridedROCMatrix{T}, uploc, isunitc, tfun::Function,
293+
A::StridedROCMatrix{T}, B::AbstractMatrix{T},
294+
) where T <: ROCBLASFloat = trsm!(
295+
'L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
296+
isunitc, one(T), A, C === B ? C : copyto!(C, B))
297+
298+
LinearAlgebra.generic_mattridiv!(
299+
C::StridedROCMatrix{T}, uploc, isunitc, tfun::Function,
300+
A::AbstractMatrix{T}, B::StridedROCMatrix{T},
301+
) where T <: ROCBLASFloat = trsm!(
302+
'R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
303+
isunitc, one(T), B, C === A ? C : copyto!(C, A))
304+
else
324305
for (t, uploc, isunitc) in (
325306
(:LowerTriangular, 'L', 'N'),
326307
(:UnitLowerTriangular, 'L', 'U'),
327308
(:UpperTriangular, 'U', 'N'),
328309
(:UnitUpperTriangular, 'U', 'U'),
329310
)
330311
@eval begin
331-
LinearAlgebra.lmul!(A::$t{T, <: ROCMatrix}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
312+
LinearAlgebra.lmul!(
313+
A::$t{T, <: StridedROCMatrix},
314+
B::StridedROCMatrix{T},
315+
) where T <: ROCBLASFloat =
332316
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B, B)
333-
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{T, <: ROCMatrix}) where T <: ROCBLASFloat =
317+
LinearAlgebra.rmul!(
318+
A::StridedROCMatrix{T},
319+
B::$t{T, <: StridedROCMatrix},
320+
) where T <: ROCBLASFloat =
334321
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A, A)
335322

336-
LinearAlgebra.ldiv!(A::$t{T, <: ROCMatrix}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
323+
LinearAlgebra.ldiv!(
324+
A::$t{T, <: StridedROCMatrix},
325+
B::StridedROCMatrix{T},
326+
) where T <: ROCBLASFloat =
337327
trsm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
338-
LinearAlgebra.rdiv!(A::ROCMatrix{T}, B::$t{T, <: ROCMatrix}) where T <: ROCBLASFloat =
328+
LinearAlgebra.rdiv!(
329+
A::StridedROCMatrix{T},
330+
B::$t{T, <: StridedROCMatrix},
331+
) where T <: ROCBLASFloat =
339332
trsm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A)
340333
end
341334
end
@@ -349,35 +342,71 @@ LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, <:ROCMatr
349342
)
350343
@eval begin
351344
# Multiplication.
352-
LinearAlgebra.lmul!(A::$t{<: Any, <: Transpose{T, <: ROCMatrix}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
345+
LinearAlgebra.lmul!(
346+
A::$t{<: Any, <: Transpose{T, <: StridedROCMatrix}},
347+
B::StridedROCMatrix{T},
348+
) where T <: ROCBLASFloat =
353349
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B, B)
354-
LinearAlgebra.lmul!(A::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
350+
LinearAlgebra.lmul!(
351+
A::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
352+
B::StridedROCMatrix{T},
353+
) where T <: ROCBLASFloat =
355354
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B, B)
356-
LinearAlgebra.lmul!(A::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}, B::ROCMatrix{T}) where T <: ROCBLASComplex =
355+
LinearAlgebra.lmul!(
356+
A::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
357+
B::StridedROCMatrix{T},
358+
) where T <: ROCBLASComplex =
357359
trmm!('L', $uploc, 'C', $isunitc, one(T), parent(parent(A)), B, B)
358360

359-
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{<: Any, <: Transpose{T, <: ROCMatrix}}) where T <: ROCBLASFloat =
361+
LinearAlgebra.rmul!(
362+
A::StridedROCMatrix{T},
363+
B::$t{<: Any, <: Transpose{T, <: StridedROCMatrix}},
364+
) where T <: ROCBLASFloat =
360365
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A, A)
361-
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}) where T <: ROCBLASFloat =
366+
LinearAlgebra.rmul!(
367+
A::StridedROCMatrix{T},
368+
B::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
369+
) where T <: ROCBLASFloat =
362370
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A, A)
363-
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}) where T <: ROCBLASComplex =
371+
LinearAlgebra.rmul!(
372+
A::StridedROCMatrix{T},
373+
B::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
374+
) where T <: ROCBLASComplex =
364375
trmm!('R', $uploc, 'C', $isunitc, one(T), parent(parent(B)), A, A)
365376

366377
# Left division.
367-
LinearAlgebra.ldiv!(A::$t{<: Any, <: Transpose{T, <: ROCMatrix}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
378+
LinearAlgebra.ldiv!(
379+
A::$t{<: Any, <: Transpose{T, <: StridedROCMatrix}},
380+
B::StridedROCMatrix{T},
381+
) where T <: ROCBLASFloat =
368382
trsm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B)
369-
LinearAlgebra.ldiv!(A::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
383+
LinearAlgebra.ldiv!(
384+
A::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
385+
B::StridedROCMatrix{T},
386+
) where T <: ROCBLASFloat =
370387
trsm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B)
371-
LinearAlgebra.ldiv!(A::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}, B::ROCMatrix{T}) where T <: ROCBLASComplex =
388+
LinearAlgebra.ldiv!(
389+
A::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
390+
B::StridedROCMatrix{T},
391+
) where T <: ROCBLASComplex =
372392
trsm!('L', $uploc, 'C', $isunitc, one(T), parent(parent(A)), B)
373393

374394
# Right division.
375-
LinearAlgebra.rdiv!(A::ROCMatrix{T}, B::$t{<: Any, <: Transpose{T, <: ROCMatrix}}) where T <: ROCBLASFloat =
395+
LinearAlgebra.rdiv!(
396+
A::StridedROCMatrix{T},
397+
B::$t{<: Any, <: Transpose{T, <: StridedROCMatrix}},
398+
) where T <: ROCBLASFloat =
376399
trsm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A)
377-
LinearAlgebra.rdiv!(A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}) where T <: ROCBLASFloat =
400+
LinearAlgebra.rdiv!(
401+
A::StridedROCMatrix{T},
402+
B::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
403+
) where T <: ROCBLASFloat =
378404
trsm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A)
379-
LinearAlgebra.rdiv!(A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix}}) where T <: ROCBLASComplex =
405+
LinearAlgebra.rdiv!(
406+
A::StridedROCMatrix{T},
407+
B::$t{<: Any, <: Adjoint{T, <: StridedROCMatrix}},
408+
) where T <: ROCBLASComplex =
380409
trsm!('R', $uploc, 'C', $isunitc, one(T), parent(parent(B)), A)
381410
end
382411
end
383-
# end
412+
end

src/blas/rocBLAS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ module rocBLAS
22

33
using ..AMDGPU
44
import AMDGPU: librocblas, AnyROCArray, StridedROCVector, StridedROCMatrix
5+
import AMDGPU: StridedROCVecOrMat, StridedROCArray
56
import AMDGPU: HandleCache, HIP, library_state
67
import .HIP: HIPContext, HIPStream, hipStream_t, hipEvent_t
78

9+
using GPUArrays
810
using LinearAlgebra
911
using LinearAlgebra: AdjOrTrans, MulAddMul
1012
if VERSION ≥ v"1.10-"

src/blas/wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ for (mmname, smname, elty) in
986986
$(mmname)(
987987
handle, side, uplo, transa, diag, m, n, Ref(alpha),
988988
A, lda, B, ldb, C, ldc) |> check
989-
B
989+
C
990990
end
991991
function trmm(
992992
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),

0 commit comments

Comments
 (0)