Skip to content

Commit 85919e6

Browse files
authored
Merge pull request #1108 from JuliaLang/jishnub/tri_muldiv_stride
Non-contiguous matrices in triangular mul! and div!
2 parents aecb714 + cd0da66 commit 85919e6

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

src/triangular.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,15 +1205,35 @@ end
12051205
# multiplication
12061206
generic_trimatmul!(c::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, b::AbstractVector{T}) where {T<:BlasFloat} =
12071207
BLAS.trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b))
1208-
generic_trimatmul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat} =
1209-
BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
1210-
generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} =
1211-
BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1208+
function generic_trimatmul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat}
1209+
if stride(C,1) == stride(A,1) == 1
1210+
BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
1211+
else # incompatible with BLAS
1212+
@invoke generic_trimatmul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
1213+
end
1214+
end
1215+
function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
1216+
if stride(C,1) == stride(B,1) == 1
1217+
BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1218+
else # incompatible with BLAS
1219+
@invoke generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
1220+
end
1221+
end
12121222
# division
1213-
generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat} =
1214-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1215-
generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} =
1216-
BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1223+
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
1224+
if stride(C,1) == stride(A,1) == 1
1225+
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1226+
else # incompatible with LAPACK
1227+
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
1228+
end
1229+
end
1230+
function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
1231+
if stride(C,1) == stride(B,1) == 1
1232+
BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1233+
else # incompatible with BLAS
1234+
@invoke generic_mattridiv!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
1235+
end
1236+
end
12171237

12181238
errorbounds(A::AbstractTriangular{T}, X::AbstractVecOrMat{T}, B::AbstractVecOrMat{T}) where {T<:Union{BigFloat,Complex{BigFloat}}} =
12191239
error("not implemented yet! Please submit a pull request.")

test/triangular.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,4 +1427,17 @@ end
14271427
end
14281428
end
14291429

1430+
@testset "(l/r)mul! and (l/r)div! for non-contiguous matrices" begin
1431+
U = UpperTriangular(reshape(collect(3:27.0),5,5))
1432+
B = float.(collect(reshape(1:100, 10,10)))
1433+
B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v)
1434+
@test lmul!(U, B2v) == lmul!(U, B2vc)
1435+
B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v)
1436+
@test rmul!(B2v, U) == rmul!(B2vc, U)
1437+
B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v)
1438+
@test ldiv!(U, B2v) ≈ ldiv!(U, B2vc)
1439+
B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v)
1440+
@test rdiv!(B2v, U) ≈ rdiv!(B2vc, U)
1441+
end
1442+
14301443
end # module TestTriangular

0 commit comments

Comments
 (0)