Skip to content

Commit 8d9b14f

Browse files
authored
Higher-order functions in factorize to obtain structure (#1135)
Instead of looping over the entire array to check the matrix structure, we may use higher-order querying functions like `istriu` and `issymmetric`. The advantage of this is that matrices might have optimized methods for these functions. We still loop over the entire matrix for `StridedMatrix`es as before, and obtain the structure in one-pass. The performance therefore isn't impacted in this case. An example with a sparse array wrapper: ```julia julia> using LinearAlgebra julia> struct MyMatrix{T,M<:AbstractMatrix{T}} <: AbstractMatrix{T} A::M end julia> Base.size(M::MyMatrix) = size(M.A) julia> Base.getindex(M::MyMatrix, i::Int, j::Int) = M.A[i, j] julia> LinearAlgebra.istriu(M::MyMatrix, k::Integer=0) = istriu(M.A, k) julia> LinearAlgebra.istril(M::MyMatrix, k::Integer=0) = istril(M.A, k) julia> LinearAlgebra.issymmetric(M::MyMatrix) = issymmetric(M.A) julia> LinearAlgebra.ishermitian(M::MyMatrix) = ishermitian(M.A) julia> using SparseArrays julia> S = sparse(1:4000, 1:4000, 1:4000); julia> M = MyMatrix(S); julia> @Btime factorize($M); 178.231 ms (4 allocations: 31.34 KiB) # master 22.165 ms (10 allocations: 94.04 KiB) # this PR ```
1 parent 7c0ecd6 commit 8d9b14f

File tree

2 files changed

+64
-31
lines changed

2 files changed

+64
-31
lines changed

src/dense.jl

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,38 +1541,8 @@ function factorize(A::AbstractMatrix{T}) where T
15411541
m, n = size(A)
15421542
if m == n
15431543
if m == 1 return A[1] end
1544-
utri = true
1545-
utri1 = true
1546-
herm = true
1547-
sym = true
1548-
for j = 1:n-1, i = j+1:m
1549-
if utri1
1550-
if A[i,j] != 0
1551-
utri1 = i == j + 1
1552-
utri = false
1553-
end
1554-
end
1555-
if sym
1556-
sym &= A[i,j] == A[j,i]
1557-
end
1558-
if herm
1559-
herm &= A[i,j] == conj(A[j,i])
1560-
end
1561-
if !(utri1|herm|sym) break end
1562-
end
1563-
ltri = true
1564-
ltri1 = true
1565-
for j = 3:n, i = 1:j-2
1566-
ltri1 &= A[i,j] == 0
1567-
if !ltri1 break end
1568-
end
1544+
utri, utri1, ltri, ltri1, sym, herm = getstructure(A)
15691545
if ltri1
1570-
for i = 1:n-1
1571-
if A[i,i+1] != 0
1572-
ltri &= false
1573-
break
1574-
end
1575-
end
15761546
if ltri
15771547
if utri
15781548
return Diagonal(A)
@@ -1610,6 +1580,66 @@ factorize(A::Adjoint) = adjoint(factorize(parent(A)))
16101580
factorize(A::Transpose) = transpose(factorize(parent(A)))
16111581
factorize(a::Number) = a # same as how factorize behaves on Diagonal types
16121582

1583+
function getstructure(A::StridedMatrix)
1584+
m, n = size(A)
1585+
if m == 1 return A[1] end
1586+
utri = true
1587+
utri1 = true
1588+
herm = true
1589+
sym = true
1590+
for j = 1:n-1, i = j+1:m
1591+
if utri1
1592+
if A[i,j] != 0
1593+
utri1 = i == j + 1
1594+
utri = false
1595+
end
1596+
end
1597+
if sym
1598+
sym &= A[i,j] == A[j,i]
1599+
end
1600+
if herm
1601+
herm &= A[i,j] == conj(A[j,i])
1602+
end
1603+
if !(utri1|herm|sym) break end
1604+
end
1605+
ltri = true
1606+
ltri1 = true
1607+
for j = 3:n, i = 1:j-2
1608+
ltri1 &= A[i,j] == 0
1609+
if !ltri1 break end
1610+
end
1611+
if ltri1
1612+
for i = 1:n-1
1613+
if A[i,i+1] != 0
1614+
ltri &= false
1615+
break
1616+
end
1617+
end
1618+
end
1619+
return (utri, utri1, ltri, ltri1, sym, herm)
1620+
end
1621+
_check_sym_herm(A) = (issymmetric(A), ishermitian(A))
1622+
_check_sym_herm(A::AbstractMatrix{<:Real}) = (sym = issymmetric(A); (sym,sym))
1623+
function getstructure(A::AbstractMatrix)
1624+
utri1 = istriu(A,-1)
1625+
# utri = istriu(A), but since we've already checked istriu(A,-1),
1626+
# we only need to check that the subdiagonal band is zero
1627+
utri = utri1 && iszero(diag(A,-1))
1628+
sym, herm = _check_sym_herm(A)
1629+
if sym || herm
1630+
# in either case, the lower and upper triangular halves have identical band structures
1631+
# in this case, istril(A,1) == istriu(A,-1) and istril(A) == istriu(A)
1632+
ltri1 = utri1
1633+
ltri = utri
1634+
else
1635+
ltri1 = istril(A,1)
1636+
# ltri = istril(A), but since we've already checked istril(A,1),
1637+
# we only need to check the superdiagonal band is zero
1638+
ltri = ltri1 && iszero(diag(A,1))
1639+
end
1640+
return (utri, utri1, ltri, ltri1, sym, herm)
1641+
end
1642+
16131643
## Moore-Penrose pseudoinverse
16141644

16151645
"""

test/dense.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ module TestDense
44

55
using Test, LinearAlgebra, Random
66
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal
7+
using Test: GenericArray
78

89
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
910
isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl"))
@@ -191,6 +192,8 @@ bimg = randn(n,2)/2
191192
f = rand(eltya,n-2)
192193
A = diagm(0 => d)
193194
@test factorize(A) == Diagonal(d)
195+
# test that the generic structure-evaluation method works
196+
@test factorize(A) == factorize(GenericArray(A))
194197
A += diagm(-1 => e)
195198
@test factorize(A) == Bidiagonal(d,e,:L)
196199
A += diagm(-2 => f)

0 commit comments

Comments
 (0)