Skip to content

Commit 3fa31c5

Browse files
committed
Bounds-checking in triangular indexing branches
1 parent 3537c3a commit 3fa31c5

File tree

2 files changed

+98
-14
lines changed

2 files changed

+98
-14
lines changed

src/triangular.jl

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238238
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
239239
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false
240240

241-
@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
242-
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
243-
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
244-
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
241+
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
242+
if _shouldforwardindex(A, i, j)
243+
A.data[i,j]
244+
else
245+
@boundscheck checkbounds(A, i, j)
246+
ifelse(i == j, oneunit(T), zero(T))
247+
end
248+
end
249+
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
250+
if _shouldforwardindex(A, i, j)
251+
A.data[i,j]
252+
else
253+
@boundscheck checkbounds(A, i, j)
254+
@inbounds diagzero(A,i,j)
255+
end
256+
end
245257

246258
_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
247259
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250262

251263
# these specialized getindex methods enable constant-propagation of the band
252264
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
253-
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
265+
if _shouldforwardindex(A, b)
266+
A.data[b]
267+
else
268+
@boundscheck checkbounds(A, b)
269+
ifelse(b.band == 0, oneunit(T), zero(T))
270+
end
254271
end
255272
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
256-
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
273+
if _shouldforwardindex(A, b)
274+
A.data[b]
275+
else
276+
@boundscheck checkbounds(A, b)
277+
@inbounds diagzero(A, b)
278+
end
257279
end
258280

259281
_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265287
throw(ArgumentError(
266288
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
267289
end
268-
@noinline function throw_nononeerror(T, @nospecialize(x), i, j)
290+
@noinline function throw_nonuniterror(T, @nospecialize(x), i, j)
291+
check_compatible_type(T, x)
269292
Tn = nameof(T)
270293
throw(ArgumentError(
271294
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
272295
end
296+
function check_compatible_type(T, @nospecialize(x))
297+
ET = eltype(T)
298+
convert(ET, x) # check that the types are compatible with setindex!
299+
end
273300

274301
@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
275302
if i > j
303+
@boundscheck checkbounds(A, i, j)
276304
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
277305
else
278306
A.data[i,j] = x
@@ -282,9 +310,11 @@ end
282310

283311
@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
284312
if i > j
313+
@boundscheck checkbounds(A, i, j)
285314
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
286315
elseif i == j
287-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
316+
@boundscheck checkbounds(A, i, j)
317+
x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j)
288318
else
289319
A.data[i,j] = x
290320
end
@@ -293,6 +323,7 @@ end
293323

294324
@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
295325
if i < j
326+
@boundscheck checkbounds(A, i, j)
296327
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
297328
else
298329
A.data[i,j] = x
@@ -302,9 +333,11 @@ end
302333

303334
@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
304335
if i < j
336+
@boundscheck checkbounds(A, i, j)
305337
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
306338
elseif i == j
307-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
339+
@boundscheck checkbounds(A, i, j)
340+
x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j)
308341
else
309342
A.data[i,j] = x
310343
end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560593
@eval @inline function _copy!(A::$UT, B::$T)
561594
for dind in diagind(A, IndexStyle(A))
562595
if A[dind] != B[dind]
563-
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
596+
throw_nonuniterror(typeof(A), B[dind], Tuple(dind)...)
564597
end
565598
end
566599
_copy!($T(parent(A)), B)
@@ -741,7 +774,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
741774
checksize1(A, B)
742775
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
743776
for j in axes(B.data,2)
744-
@inbounds _modify!(_add, c, A, (j,j))
777+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
745778
for i in firstindex(B.data,1):(j - 1)
746779
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
747780
end
@@ -752,7 +785,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
752785
checksize1(A, B)
753786
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
754787
for j in axes(B.data,2)
755-
@inbounds _modify!(_add, c, A, (j,j))
788+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
756789
for i in firstindex(B.data,1):(j - 1)
757790
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
758791
end
@@ -783,7 +816,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
783816
checksize1(A, B)
784817
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
785818
for j in axes(B.data,2)
786-
@inbounds _modify!(_add, c, A, (j,j))
819+
@inbounds _modify!(_add, B[BandIndex(0,j)] *c, A, (j,j))
787820
for i in (j + 1):lastindex(B.data,1)
788821
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
789822
end
@@ -794,7 +827,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
794827
checksize1(A, B)
795828
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
796829
for j in axes(B.data,2)
797-
@inbounds _modify!(_add, c, A, (j,j))
830+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
798831
for i in (j + 1):lastindex(B.data,1)
799832
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
800833
end
@@ -804,6 +837,7 @@ end
804837

805838
function _trirdiv!(A::UpperTriangular, B::UpperTriangular, c::Number)
806839
checksize1(A, B)
840+
isunit = B isa UnitUpperTriangular
807841
for j in axes(B,2)
808842
for i in firstindex(B,1):j
809843
@inbounds A.data[i, j] = B.data[i, j] / c
@@ -813,6 +847,7 @@ function _trirdiv!(A::UpperTriangular, B::UpperTriangular, c::Number)
813847
end
814848
function _trirdiv!(A::LowerTriangular, B::LowerTriangular, c::Number)
815849
checksize1(A, B)
850+
isunit = B isa UnitLowerTriangular
816851
for j in axes(B,2)
817852
for i in j:lastindex(B,1)
818853
@inbounds A.data[i, j] = B.data[i, j] / c
@@ -822,6 +857,7 @@ function _trirdiv!(A::LowerTriangular, B::LowerTriangular, c::Number)
822857
end
823858
function _trildiv!(A::UpperTriangular, c::Number, B::UpperTriangular)
824859
checksize1(A, B)
860+
isunit = B isa UnitUpperTriangular
825861
for j in axes(B,2)
826862
for i in firstindex(B,1):j
827863
@inbounds A.data[i, j] = c \ B.data[i, j]
@@ -831,6 +867,7 @@ function _trildiv!(A::UpperTriangular, c::Number, B::UpperTriangular)
831867
end
832868
function _trildiv!(A::LowerTriangular, c::Number, B::LowerTriangular)
833869
checksize1(A, B)
870+
isunit = B isa UnitLowerTriangular
834871
for j in axes(B,2)
835872
for i in j:lastindex(B,1)
836873
@inbounds A.data[i, j] = c \ B.data[i, j]

test/triangular.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,4 +966,51 @@ end
966966
end
967967
end
968968

969+
@testset "indexing checks" begin
970+
@testset "getindex" begin
971+
U = UnitUpperTriangular(P)
972+
@test_throws BoundsError U[0,0]
973+
@test_throws BoundsError U[1,0]
974+
@test_throws BoundsError U[BandIndex(0,0)]
975+
@test_throws BoundsError U[BandIndex(-1,0)]
976+
977+
U = UpperTriangular(P)
978+
@test_throws BoundsError U[1,0]
979+
@test_throws BoundsError U[BandIndex(-1,0)]
980+
981+
L = UnitLowerTriangular(P)
982+
@test_throws BoundsError L[0,0]
983+
@test_throws BoundsError L[0,1]
984+
@test_throws BoundsError U[BandIndex(0,0)]
985+
@test_throws BoundsError U[BandIndex(1,0)]
986+
987+
L = LowerTriangular(P)
988+
@test_throws BoundsError L[0,1]
989+
@test_throws BoundsError L[BandIndex(1,0)]
990+
end
991+
@testset "setindex!" begin
992+
P = [1 2; 3 4]
993+
A = SizedArrays.SizedArray{(2,2)}(P)
994+
M = fill(A, 2, 2)
995+
U = UnitUpperTriangular(M)
996+
@test_throws "Cannot `convert` an object of type Int64" U[1,1] = 1
997+
L = UnitLowerTriangular(M)
998+
@test_throws "Cannot `convert` an object of type Int64" L[1,1] = 1
999+
1000+
U = UnitUpperTriangular(P)
1001+
@test_throws BoundsError U[0,0] = 1
1002+
@test_throws BoundsError U[1,0] = 0
1003+
1004+
U = UpperTriangular(P)
1005+
@test_throws BoundsError U[1,0] = 0
1006+
1007+
L = UnitLowerTriangular(P)
1008+
@test_throws BoundsError L[0,0] = 1
1009+
@test_throws BoundsError L[0,1] = 0
1010+
1011+
L = LowerTriangular(P)
1012+
@test_throws BoundsError L[0,1] = 0
1013+
end
1014+
end
1015+
9691016
end # module TestTriangular

0 commit comments

Comments
 (0)