Skip to content

Commit 782490a

Browse files
jishnubdkarrasch
andauthored
Specialize isbanded for StridedMatrix (#56487)
This improves performance, as the loops in `istriu` and `istril` may be fused to improve cache-locality. This also changes the quick-return behavior, and only returns after the check over all the upper or lower bands for a column is complete. ```julia julia> using LinearAlgebra julia> A = zeros(2, 10_000); julia> @Btime isdiag($A); 32.682 μs (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1593" 9.481 μs (0 allocations: 0 bytes) # this PR julia> A = zeros(10_000, 2); julia> @Btime isdiag($A); 10.288 μs (0 allocations: 0 bytes) # nightly 2.579 μs (0 allocations: 0 bytes) # this PR julia> A = zeros(100, 100); julia> @Btime isdiag($A); 6.616 μs (0 allocations: 0 bytes) # nightly 3.075 μs (0 allocations: 0 bytes) # this PR julia> A = diagm(0=>1:100); A[3,4] = 1; julia> @Btime isdiag($A); 2.759 μs (0 allocations: 0 bytes) # nightly 85.371 ns (0 allocations: 0 bytes) # this PR ``` A similar change is added to `istriu`/`istril` as well, so that ```julia julia> A = zeros(2, 10_000); julia> @Btime istriu($A); # trivial 7.358 ns (0 allocations: 0 bytes) # nightly 13.779 ns (0 allocations: 0 bytes) # this PR julia> @Btime istril($A); 33.464 μs (0 allocations: 0 bytes) # nightly 9.476 μs (0 allocations: 0 bytes) # this PR julia> A = zeros(10_000, 2); julia> @Btime istriu($A); 10.020 μs (0 allocations: 0 bytes) # nightly 2.620 μs (0 allocations: 0 bytes) # this PR julia> @Btime istril($A); # trivial 6.793 ns (0 allocations: 0 bytes) # nightly 14.473 ns (0 allocations: 0 bytes) # this PR julia> A = zeros(100, 100); julia> @Btime istriu($A); 3.435 μs (0 allocations: 0 bytes) # nightly 1.637 μs (0 allocations: 0 bytes) # this PR julia> @Btime istril($A); 3.353 μs (0 allocations: 0 bytes) # nightly 1.661 μs (0 allocations: 0 bytes) # this PR ``` --------- Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
1 parent f2d16e4 commit 782490a

File tree

6 files changed

+170
-28
lines changed

6 files changed

+170
-28
lines changed

src/generic.jl

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,6 +1353,14 @@ end
13531353

13541354
ishermitian(x::Number) = (x == conj(x))
13551355

1356+
# helper function equivalent to `iszero(v)`, but potentially without the fast exit feature
1357+
# of `all` if this improves performance
1358+
_iszero(V) = iszero(V)
1359+
# A Base.FastContiguousSubArray view of a StridedArray
1360+
FastContiguousSubArrayStrided{T,N,P<:StridedArray,I<:Tuple{AbstractUnitRange, Vararg{Any}}} = Base.SubArray{T,N,P,I,true}
1361+
# using mapreduce instead of all permits vectorization
1362+
_iszero(V::FastContiguousSubArrayStrided) = mapreduce(iszero, &, V, init=true)
1363+
13561364
"""
13571365
istriu(A::AbstractMatrix, k::Integer = 0) -> Bool
13581366
@@ -1384,20 +1392,9 @@ julia> istriu(c, -1)
13841392
true
13851393
```
13861394
"""
1387-
function istriu(A::AbstractMatrix, k::Integer = 0)
1388-
require_one_based_indexing(A)
1389-
return _istriu(A, k)
1390-
end
1395+
istriu(A::AbstractMatrix, k::Integer = 0) = _isbanded_impl(A, k, size(A,2)-1)
13911396
istriu(x::Number) = true
13921397

1393-
@inline function _istriu(A::AbstractMatrix, k)
1394-
m, n = size(A)
1395-
for j in 1:min(n, m + k - 1)
1396-
all(iszero, view(A, max(1, j - k + 1):m, j)) || return false
1397-
end
1398-
return true
1399-
end
1400-
14011398
"""
14021399
istril(A::AbstractMatrix, k::Integer = 0) -> Bool
14031400
@@ -1429,20 +1426,9 @@ julia> istril(c, 1)
14291426
true
14301427
```
14311428
"""
1432-
function istril(A::AbstractMatrix, k::Integer = 0)
1433-
require_one_based_indexing(A)
1434-
return _istril(A, k)
1435-
end
1429+
istril(A::AbstractMatrix, k::Integer = 0) = _isbanded_impl(A, -size(A,1)+1, k)
14361430
istril(x::Number) = true
14371431

1438-
@inline function _istril(A::AbstractMatrix, k)
1439-
m, n = size(A)
1440-
for j in max(1, k + 2):n
1441-
all(iszero, view(A, 1:min(j - k - 1, m), j)) || return false
1442-
end
1443-
return true
1444-
end
1445-
14461432
"""
14471433
isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) -> Bool
14481434
@@ -1474,7 +1460,66 @@ julia> LinearAlgebra.isbanded(b, -1, 0)
14741460
true
14751461
```
14761462
"""
1477-
isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = istriu(A, kl) && istril(A, ku)
1463+
isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = _isbanded(A, kl, ku)
1464+
_isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = istriu(A, kl) && istril(A, ku)
1465+
# Performance optimization for StridedMatrix by better utilizing cache locality
1466+
# The istriu and istril loops are merged
1467+
# the additional indirection allows us to reuse the isbanded loop within istriu/istril
1468+
# without encountering cycles
1469+
_isbanded(A::StridedMatrix, kl::Integer, ku::Integer) = _isbanded_impl(A, kl, ku)
1470+
function _isbanded_impl(A, kl, ku)
1471+
Base.require_one_based_indexing(A)
1472+
1473+
#=
1474+
We split the column range into four possible groups, depending on the values of kl and ku.
1475+
1476+
The first is the bottom left triangle, where bands below kl must be zero,
1477+
but there are no bands above ku in that column.
1478+
1479+
The second is where there are both bands below kl and above ku in the column.
1480+
These are the middle columns typically.
1481+
1482+
The third is the top right, where there are bands above ku but no bands below kl
1483+
in the column.
1484+
1485+
The fourth is mainly relevant for wide matrices, where there is a block to the right
1486+
beyond ku, where the elements should all be zero. The reason we separate this from the
1487+
third group is that we may loop over all the rows using A[:, col] instead of A[rowrange, col],
1488+
which is usually faster.
1489+
=#
1490+
1491+
last_col_nonzeroblocks = size(A,1) + ku # fully zero rectangular block beyond this column
1492+
last_col_emptytoprows = ku + 1 # empty top rows before this column
1493+
last_col_nonemptybottomrows = size(A,1) + kl - 1 # empty bottom rows after this column
1494+
1495+
colrange_onlybottomrows = firstindex(A,2):min(last_col_nonemptybottomrows, last_col_emptytoprows)
1496+
colrange_topbottomrows = max(last_col_emptytoprows, last(colrange_onlybottomrows))+1:last_col_nonzeroblocks
1497+
colrange_onlytoprows_nonzero = last(colrange_topbottomrows)+1:last_col_nonzeroblocks
1498+
colrange_zero_block = last_col_nonzeroblocks+1:lastindex(A,2)
1499+
1500+
for col in intersect(axes(A,2), colrange_onlybottomrows) # only loop over the bottom rows
1501+
botrowinds = max(firstindex(A,1), col-kl+1):lastindex(A,1)
1502+
bottomrows = @view A[botrowinds, col]
1503+
_iszero(bottomrows) || return false
1504+
end
1505+
for col in intersect(axes(A,2), colrange_topbottomrows)
1506+
toprowinds = firstindex(A,1):min(col-ku-1, lastindex(A,1))
1507+
toprows = @view A[toprowinds, col]
1508+
_iszero(toprows) || return false
1509+
botrowinds = max(firstindex(A,1), col-kl+1):lastindex(A,1)
1510+
bottomrows = @view A[botrowinds, col]
1511+
_iszero(bottomrows) || return false
1512+
end
1513+
for col in intersect(axes(A,2), colrange_onlytoprows_nonzero)
1514+
toprowinds = firstindex(A,1):min(col-ku-1, lastindex(A,1))
1515+
toprows = @view A[toprowinds, col]
1516+
_iszero(toprows) || return false
1517+
end
1518+
for col in intersect(axes(A,2), colrange_zero_block)
1519+
_iszero(@view A[:, col]) || return false
1520+
end
1521+
return true
1522+
end
14781523

14791524
"""
14801525
isdiag(A) -> Bool

src/hessenberg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ Base.@constprop :aggressive function istriu(A::UpperHessenberg, k::Integer=0)
7777
k <= -1 && return true
7878
return _istriu(A, k)
7979
end
80+
# additional indirection to dispatch to optimized method for banded parents (defined in special.jl)
81+
@inline function _istriu(A::UpperHessenberg, k)
82+
P = parent(A)
83+
m = size(A, 1)
84+
for j in firstindex(P,2):min(m + k - 1, lastindex(P,2))
85+
Prows = @view P[max(begin, j - k + 1):min(j+1,end), j]
86+
_iszero(Prows) || return false
87+
end
88+
return true
89+
end
8090

8191
function Matrix{T}(H::UpperHessenberg) where T
8292
m,n = size(H)

src/special.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,4 @@ end
592592
# istriu/istril for triangular wrappers of structured matrices
593593
_istril(A::LowerTriangular{<:Any, <:BandedMatrix}, k) = istril(parent(A), k)
594594
_istriu(A::UpperTriangular{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k)
595+
_istriu(A::UpperHessenberg{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k)

src/triangular.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,25 +348,29 @@ Base.@constprop :aggressive function istril(A::LowerTriangular, k::Integer=0)
348348
k >= 0 && return true
349349
return _istril(A, k)
350350
end
351+
# additional indirection to dispatch to optimized method for banded parents (defined in special.jl)
351352
@inline function _istril(A::LowerTriangular, k)
352353
P = parent(A)
353354
for j in max(firstindex(P,2), k + 2):lastindex(P,2)
354-
all(iszero, @view(P[j:min(j - k - 1, end), j])) || return false
355+
_iszero(@view P[max(j, begin):min(j - k - 1, end), j]) || return false
355356
end
356357
return true
357358
end
359+
358360
Base.@constprop :aggressive function istriu(A::UpperTriangular, k::Integer=0)
359361
k <= 0 && return true
360362
return _istriu(A, k)
361363
end
364+
# additional indirection to dispatch to optimized method for banded parents (defined in special.jl)
362365
@inline function _istriu(A::UpperTriangular, k)
363366
P = parent(A)
364367
m = size(A, 1)
365368
for j in firstindex(P,2):min(m + k - 1, lastindex(P,2))
366-
all(iszero, @view(P[max(begin, j - k + 1):j, j])) || return false
369+
_iszero(@view P[max(begin, j - k + 1):min(j, end), j]) || return false
367370
end
368371
return true
369372
end
373+
370374
istril(A::Adjoint, k::Integer=0) = istriu(A.parent, -k)
371375
istril(A::Transpose, k::Integer=0) = istriu(A.parent, -k)
372376
istriu(A::Adjoint, k::Integer=0) = istril(A.parent, -k)

test/generic.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
module TestGeneric
44

55
using Test, LinearAlgebra, Random
6+
using Test: GenericArray
7+
using LinearAlgebra: isbanded
68

79
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
810

@@ -511,56 +513,110 @@ end
511513
end
512514

513515
@testset "generic functions for checking whether matrices have banded structure" begin
514-
using LinearAlgebra: isbanded
515516
pentadiag = [1 2 3; 4 5 6; 7 8 9]
516517
tridiag = [1 2 0; 4 5 6; 0 8 9]
518+
tridiagG = GenericArray([1 2 0; 4 5 6; 0 8 9])
519+
Tridiag = Tridiagonal(tridiag)
517520
ubidiag = [1 2 0; 0 5 6; 0 0 9]
521+
ubidiagG = GenericArray([1 2 0; 0 5 6; 0 0 9])
522+
uBidiag = Bidiagonal(ubidiag, :U)
518523
lbidiag = [1 0 0; 4 5 0; 0 8 9]
524+
lbidiagG = GenericArray([1 0 0; 4 5 0; 0 8 9])
525+
lBidiag = Bidiagonal(lbidiag, :L)
519526
adiag = [1 0 0; 0 5 0; 0 0 9]
527+
adiagG = GenericArray([1 0 0; 0 5 0; 0 0 9])
528+
aDiag = Diagonal(adiag)
520529
@testset "istriu" begin
521530
@test !istriu(pentadiag)
522531
@test istriu(pentadiag, -2)
523532
@test !istriu(tridiag)
533+
@test istriu(tridiag) == istriu(tridiagG) == istriu(Tridiag)
524534
@test istriu(tridiag, -1)
535+
@test istriu(tridiag, -1) == istriu(tridiagG, -1) == istriu(Tridiag, -1)
525536
@test istriu(ubidiag)
537+
@test istriu(ubidiag) == istriu(ubidiagG) == istriu(uBidiag)
526538
@test !istriu(ubidiag, 1)
539+
@test istriu(ubidiag, 1) == istriu(ubidiagG, 1) == istriu(uBidiag, 1)
527540
@test !istriu(lbidiag)
541+
@test istriu(lbidiag) == istriu(lbidiagG) == istriu(lBidiag)
528542
@test istriu(lbidiag, -1)
543+
@test istriu(lbidiag, -1) == istriu(lbidiagG, -1) == istriu(lBidiag, -1)
529544
@test istriu(adiag)
545+
@test istriu(adiag) == istriu(adiagG) == istriu(aDiag)
530546
end
531547
@testset "istril" begin
532548
@test !istril(pentadiag)
533549
@test istril(pentadiag, 2)
534550
@test !istril(tridiag)
551+
@test istril(tridiag) == istril(tridiagG) == istril(Tridiag)
535552
@test istril(tridiag, 1)
553+
@test istril(tridiag, 1) == istril(tridiagG, 1) == istril(Tridiag, 1)
536554
@test !istril(ubidiag)
555+
@test istril(ubidiag) == istril(ubidiagG) == istril(ubidiagG)
537556
@test istril(ubidiag, 1)
557+
@test istril(ubidiag, 1) == istril(ubidiagG, 1) == istril(uBidiag, 1)
538558
@test istril(lbidiag)
559+
@test istril(lbidiag) == istril(lbidiagG) == istril(lBidiag)
539560
@test !istril(lbidiag, -1)
561+
@test istril(lbidiag, -1) == istril(lbidiagG, -1) == istril(lBidiag, -1)
540562
@test istril(adiag)
563+
@test istril(adiag) == istril(adiagG) == istril(aDiag)
541564
end
542565
@testset "isbanded" begin
543566
@test isbanded(pentadiag, -2, 2)
544567
@test !isbanded(pentadiag, -1, 2)
545568
@test !isbanded(pentadiag, -2, 1)
546569
@test isbanded(tridiag, -1, 1)
570+
@test isbanded(tridiag, -1, 1) == isbanded(tridiagG, -1, 1) == isbanded(Tridiag, -1, 1)
547571
@test !isbanded(tridiag, 0, 1)
572+
@test isbanded(tridiag, 0, 1) == isbanded(tridiagG, 0, 1) == isbanded(Tridiag, 0, 1)
548573
@test !isbanded(tridiag, -1, 0)
574+
@test isbanded(tridiag, -1, 0) == isbanded(tridiagG, -1, 0) == isbanded(Tridiag, -1, 0)
549575
@test isbanded(ubidiag, 0, 1)
576+
@test isbanded(ubidiag, 0, 1) == isbanded(ubidiagG, 0, 1) == isbanded(uBidiag, 0, 1)
550577
@test !isbanded(ubidiag, 1, 1)
578+
@test isbanded(ubidiag, 1, 1) == isbanded(ubidiagG, 1, 1) == isbanded(uBidiag, 1, 1)
551579
@test !isbanded(ubidiag, 0, 0)
580+
@test isbanded(ubidiag, 0, 0) == isbanded(ubidiagG, 0, 0) == isbanded(uBidiag, 0, 0)
552581
@test isbanded(lbidiag, -1, 0)
582+
@test isbanded(lbidiag, -1, 0) == isbanded(lbidiagG, -1, 0) == isbanded(lBidiag, -1, 0)
553583
@test !isbanded(lbidiag, 0, 0)
584+
@test isbanded(lbidiag, 0, 0) == isbanded(lbidiagG, 0, 0) == isbanded(lBidiag, 0, 0)
554585
@test !isbanded(lbidiag, -1, -1)
586+
@test isbanded(lbidiag, -1, -1) == isbanded(lbidiagG, -1, -1) == isbanded(lBidiag, -1, -1)
555587
@test isbanded(adiag, 0, 0)
588+
@test isbanded(adiag, 0, 0) == isbanded(adiagG, 0, 0) == isbanded(aDiag, 0, 0)
556589
@test !isbanded(adiag, -1, -1)
590+
@test isbanded(adiag, -1, -1) == isbanded(adiagG, -1, -1) == isbanded(aDiag, -1, -1)
557591
@test !isbanded(adiag, 1, 1)
592+
@test isbanded(adiag, 1, 1) == isbanded(adiagG, 1, 1) == isbanded(aDiag, 1, 1)
558593
end
559594
@testset "isdiag" begin
560595
@test !isdiag(tridiag)
596+
@test isdiag(tridiag) == isdiag(tridiagG) == isdiag(Tridiag)
561597
@test !isdiag(ubidiag)
598+
@test isdiag(ubidiag) == isdiag(ubidiagG) == isdiag(uBidiag)
562599
@test !isdiag(lbidiag)
600+
@test isdiag(lbidiag) == isdiag(lbidiagG) == isdiag(lBidiag)
563601
@test isdiag(adiag)
602+
@test isdiag(adiag) ==isdiag(adiagG) == isdiag(aDiag)
603+
end
604+
end
605+
606+
@testset "isbanded/istril/istriu with rectangular matrices" begin
607+
@testset "$(size(A))" for A in [zeros(0,4), zeros(2,5), zeros(5,2), zeros(4,0)]
608+
@testset for m in -(size(A,1)-1):(size(A,2)-1)
609+
A .= 0
610+
A[diagind(A, m)] .= 1
611+
G = GenericArray(A)
612+
@testset for (kl,ku) in Iterators.product(-6:6, -6:6)
613+
@test isbanded(A, kl, ku) == isbanded(G, kl, ku) == isempty(A) || (m in (kl:ku))
614+
end
615+
@testset for k in -6:6
616+
@test istriu(A,k) == istriu(G,k) == isempty(A) || (k <= m)
617+
@test istril(A,k) == istril(G,k) == isempty(A) || (k >= m)
618+
end
619+
end
564620
end
565621
end
566622

test/hessenberg.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,30 @@ end
279279
@test H.H == D
280280
end
281281

282+
@testset "istriu/istril forwards to parent" begin
283+
n = 10
284+
@testset "$(nameof(typeof(M)))" for M in [Tridiagonal(rand(n-1), rand(n), rand(n-1)),
285+
Tridiagonal(zeros(n-1), zeros(n), zeros(n-1)),
286+
Diagonal(randn(n)),
287+
Diagonal(zeros(n)),
288+
]
289+
U = UpperHessenberg(M)
290+
A = Array(U)
291+
for k in -n:n
292+
@test istriu(U, k) == istriu(A, k)
293+
@test istril(U, k) == istril(A, k)
294+
end
295+
end
296+
z = zeros(n,n)
297+
P = Matrix{BigFloat}(undef, n, n)
298+
copytrito!(P, z, 'U')
299+
P[diagind(P,-1)] .= 0
300+
U = UpperHessenberg(P)
301+
A = Array(U)
302+
@testset for k in -n:n
303+
@test istriu(U, k) == istriu(A, k)
304+
@test istril(U, k) == istril(A, k)
305+
end
306+
end
307+
282308
end # module TestHessenberg

0 commit comments

Comments
 (0)