Skip to content

Commit f0f7a46

Browse files
authored
Backports release 1.12 (#1217)
Backported PRs: - [x] #1194 - [x] #1207 - [x] #1196 <!-- Explicitly declare type constructor imports --> - [x] #1202 <!-- Add fast path in generic matmul --> - [x] #1203 <!-- Restrict Diagonal sqrt branch to positive diag --> - [x] #1210 <!-- Indirection in matrix multiplication to avoid ambiguities -->
2 parents e6a8ba5 + 6b5bdf2 commit f0f7a46

File tree

11 files changed

+79
-23
lines changed

11 files changed

+79
-23
lines changed

src/LinearAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1616
permutedims, permuterows!, power_by_squaring, promote_rule, real, sec, sech, setindex!,
1717
show, similar, sin, sincos, sinh, size, sqrt, strides, stride, tan, tanh, transpose, trunc,
1818
typed_hcat, vec, view, zero
19+
import Base: AbstractArray, AbstractMatrix, Array, Matrix
1920
using Base: IndexLinear, promote_eltype, promote_op, print_matrix,
2021
@propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
2122
splat, BitInteger

src/adjtrans.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ const AdjointAbsVec{T} = Adjoint{T,<:AbstractVector}
319319
const AdjointAbsMat{T} = Adjoint{T,<:AbstractMatrix}
320320
const TransposeAbsVec{T} = Transpose{T,<:AbstractVector}
321321
const TransposeAbsMat{T} = Transpose{T,<:AbstractMatrix}
322-
const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector}
323-
const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix}
322+
const AdjOrTransAbsVec{T,V<:AbstractVector} = AdjOrTrans{T,V}
323+
const AdjOrTransAbsMat{T,M<:AbstractMatrix} = AdjOrTrans{T,M}
324324

325325
# for internal use below
326326
wrapperop(_) = identity

src/blas.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ export
8484
trsm!,
8585
trsm
8686

87-
using ..LinearAlgebra: libblastrampoline, BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, checksquare, chkstride1
87+
using ..LinearAlgebra: libblastrampoline, BlasReal, BlasComplex, BlasFloat, BlasInt,
88+
DimensionMismatch, checksquare, chkstride1, SingularException
8889

8990
include("lbt.jl")
9091

@@ -1369,6 +1370,11 @@ for (fname, elty) in ((:dtrsv_,:Float64),
13691370
throw(DimensionMismatch(lazy"size of A is $n != length(x) = $(length(x))"))
13701371
end
13711372
chkstride1(A)
1373+
if diag == 'N'
1374+
for i in 1:n
1375+
iszero(A[i,i]) && throw(SingularException(i))
1376+
end
1377+
end
13721378
px, stx = vec_pointer_stride(x, ArgumentError("input vector with 0 stride is not allowed"))
13731379
GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid,
13741380
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
@@ -2217,6 +2223,11 @@ for (mmname, smname, elty) in
22172223
end
22182224
chkstride1(A)
22192225
chkstride1(B)
2226+
if diag == 'N'
2227+
for i in 1:k
2228+
iszero(A[i,i]) && throw(SingularException(i))
2229+
end
2230+
end
22202231
ccall((@blasfunc($smname), libblastrampoline), Cvoid,
22212232
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8},
22222233
Ref{BlasInt}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},

src/dense.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,8 @@ sqrt(::AbstractMatrix)
972972
function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
973973
if checksquare(A) == 0
974974
return copy(float(A))
975-
elseif isdiag(A)
975+
elseif isdiag(A) && (T <: Complex || all(x -> x zero(x), diagview(A)))
976+
# Real Diagonal sqrt requires each diagonal element to be positive
976977
return applydiagonal(sqrt, A)
977978
elseif ishermitian(A)
978979
sqrtHermA = sqrt(Hermitian(A))

src/diagonal.jl

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,28 @@ function (*)(D::Diagonal, V::AbstractVector)
332332
return D.diag .* V
333333
end
334334

335+
function _diag_adj_mul(A::AdjOrTransAbsMat, D::Diagonal)
336+
adj = wrapperop(A)
337+
copy(adj(adj(D) * adj(A)))
338+
end
339+
function _diag_adj_mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number})
340+
@invoke *(A::AbstractMatrix, D::AbstractMatrix)
341+
end
342+
function _diag_adj_mul(D::Diagonal, A::AdjOrTransAbsMat)
343+
adj = wrapperop(A)
344+
copy(adj(adj(A) * adj(D)))
345+
end
346+
function _diag_adj_mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix})
347+
@invoke *(D::AbstractMatrix, A::AbstractMatrix)
348+
end
349+
350+
function (*)(A::AdjOrTransAbsMat, D::Diagonal)
351+
_diag_adj_mul(A, D)
352+
end
353+
function (*)(D::Diagonal, A::AdjOrTransAbsMat)
354+
_diag_adj_mul(D, A)
355+
end
356+
335357
function rmul!(A::AbstractMatrix, D::Diagonal)
336358
matmul_size_check(size(A), size(D))
337359
for I in CartesianIndices(A)
@@ -671,22 +693,24 @@ end
671693
for Tri in (:UpperTriangular, :LowerTriangular)
672694
UTri = Symbol(:Unit, Tri)
673695
# 2 args
674-
for (fun, f) in zip((:*, :rmul!, :rdiv!, :/), (:identity, :identity, :inv, :inv))
675-
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
676-
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
696+
for (fun, f) in zip((:mul, :rmul!, :rdiv!, :/), (:identity, :identity, :inv, :inv))
697+
g = fun == :mul ? :* : fun
698+
@eval $fun(A::$Tri, D::Diagonal) = $Tri($g(A.data, D))
699+
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($g(A.data, D), $f, D.diag))
677700
end
678-
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
679-
@invoke *(A::AbstractMatrix, D::Diagonal)
680-
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
681-
@invoke *(A::AbstractMatrix, D::Diagonal)
682-
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
683-
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
684-
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
701+
@eval mul(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
702+
@invoke mul(A::AbstractMatrix, D::Diagonal)
703+
@eval mul(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
704+
@invoke mul(A::AbstractMatrix, D::Diagonal)
705+
for (fun, f) in zip((:mul, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
706+
g = fun == :mul ? :* : fun
707+
@eval $fun(D::Diagonal, A::$Tri) = $Tri($g(D, A.data))
708+
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($g(D, A.data), $f, D.diag))
685709
end
686-
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
687-
@invoke *(D::Diagonal, A::AbstractMatrix)
688-
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
689-
@invoke *(D::Diagonal, A::AbstractMatrix)
710+
@eval mul(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
711+
@invoke mul(D::Diagonal, A::AbstractMatrix)
712+
@eval mul(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
713+
@invoke mul(D::Diagonal, A::AbstractMatrix)
690714
# 3-arg ldiv!
691715
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
692716
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))

src/matmul.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ julia> [1 1; 0 1] * [1 0; 1 1]
111111
1 1
112112
```
113113
"""
114-
function (*)(A::AbstractMatrix, B::AbstractMatrix)
114+
(*)(A::AbstractMatrix, B::AbstractMatrix) = mul(A, B)
115+
# we add an extra level of indirection to avoid ambiguities in *
116+
function mul(A::AbstractMatrix, B::AbstractMatrix)
115117
TS = promote_op(matprod, eltype(A), eltype(B))
116118
mul!(matprod_dest(A, B, TS), A, B)
117119
end
@@ -1021,6 +1023,7 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
10211023
@inbounds for n in axes(B, 2), k in axes(B, 1)
10221024
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
10231025
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
1026+
!ismissing(Balpha) && iszero(Balpha) && continue
10241027
@simd for m in axes(A, 1)
10251028
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
10261029
end

src/triangular.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,11 +1223,13 @@ function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function,
12231223
end
12241224
end
12251225
# division
1226-
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
1226+
generic_trimatdiv!(C::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVector{T}) where {T<:BlasFloat} =
1227+
BLAS.trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1228+
function generic_trimatdiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat}
12271229
if stride(C,1) == stride(A,1) == 1
1228-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1230+
BLAS.trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
12291231
else # incompatible with LAPACK
1230-
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
1232+
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12311233
end
12321234
end
12331235
function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}

test/dense.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,12 @@ end
984984
@testset "sqrt for diagonal" begin
985985
A = diagm(0 => [1, 2, 3])
986986
@test sqrt(A)^2 A
987+
988+
A = diagm(0 => [1.0, -1.0])
989+
@test sqrt(A) == diagm(0 => [1.0, 1.0im])
990+
@test sqrt(A)^2 A
991+
B = im*A
992+
@test sqrt(B)^2 B
987993
end
988994

989995
@testset "issue #40141" begin

test/matmul.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ import LinearAlgebra: Adjoint, Transpose
767767
(*)(x::RootInt, y::Integer) = x.i * y
768768
adjoint(x::RootInt) = x
769769
transpose(x::RootInt) = x
770+
Base.zero(::RootInt) = RootInt(0)
770771

771772
@test Base.promote_op(*, RootInt, RootInt) === Int
772773

test/testtriag.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,8 @@ function test_triangular(elty1_types)
493493
@test_throws DimensionMismatch transpose(Ann) \ bm
494494
if t1 == UpperTriangular || t1 == LowerTriangular
495495
@test_throws SingularException ldiv!(t1(zeros(elty1, n, n)), fill(eltyB(1), n))
496+
@test_throws SingularException ldiv!(t1(zeros(elty1, n, n)), fill(eltyB(1), n, 2))
497+
@test_throws SingularException rdiv!(fill(eltyB(1), n, n), t1(zeros(elty1, n, n)))
496498
end
497499
@test B / A1 B / M1
498500
@test B / transpose(A1) B / transpose(M1)

0 commit comments

Comments
 (0)