Skip to content

Commit ba99c6e

Browse files
committed
Bounds-checking in triangular indexing branches
1 parent 3393398 commit ba99c6e

File tree

2 files changed

+118
-22
lines changed

2 files changed

+118
-22
lines changed

src/triangular.jl

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238238
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
239239
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false
240240

241-
@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
242-
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
243-
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
244-
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
241+
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
242+
if _shouldforwardindex(A, i, j)
243+
A.data[i,j]
244+
else
245+
@boundscheck checkbounds(A, i, j)
246+
ifelse(i == j, oneunit(T), zero(T))
247+
end
248+
end
249+
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
250+
if _shouldforwardindex(A, i, j)
251+
A.data[i,j]
252+
else
253+
@boundscheck checkbounds(A, i, j)
254+
@inbounds diagzero(A,i,j)
255+
end
256+
end
245257

246258
_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
247259
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250262

251263
# these specialized getindex methods enable constant-propagation of the band
252264
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
253-
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
265+
if _shouldforwardindex(A, b)
266+
A.data[b]
267+
else
268+
@boundscheck checkbounds(A, b)
269+
ifelse(b.band == 0, oneunit(T), zero(T))
270+
end
254271
end
255272
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
256-
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
273+
if _shouldforwardindex(A, b)
274+
A.data[b]
275+
else
276+
@boundscheck checkbounds(A, b)
277+
@inbounds diagzero(A, b)
278+
end
257279
end
258280

259281
_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265287
throw(ArgumentError(
266288
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
267289
end
268-
@noinline function throw_nononeerror(T, @nospecialize(x), i, j)
290+
@noinline function throw_nonuniterror(T, @nospecialize(x), i, j)
291+
check_compatible_type(T, x)
269292
Tn = nameof(T)
270293
throw(ArgumentError(
271294
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
272295
end
296+
function check_compatible_type(T, @nospecialize(x))
297+
ET = eltype(T)
298+
convert(ET, x) # check that the types are compatible with setindex!
299+
end
273300

274301
@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
275302
if i > j
303+
@boundscheck checkbounds(A, i, j)
276304
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
277305
else
278306
A.data[i,j] = x
@@ -282,9 +310,11 @@ end
282310

283311
@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
284312
if i > j
313+
@boundscheck checkbounds(A, i, j)
285314
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
286315
elseif i == j
287-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
316+
@boundscheck checkbounds(A, i, j)
317+
x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j)
288318
else
289319
A.data[i,j] = x
290320
end
@@ -293,6 +323,7 @@ end
293323

294324
@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
295325
if i < j
326+
@boundscheck checkbounds(A, i, j)
296327
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
297328
else
298329
A.data[i,j] = x
@@ -302,9 +333,11 @@ end
302333

303334
@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
304335
if i < j
336+
@boundscheck checkbounds(A, i, j)
305337
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
306338
elseif i == j
307-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
339+
@boundscheck checkbounds(A, i, j)
340+
x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j)
308341
else
309342
A.data[i,j] = x
310343
end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560593
@eval @inline function _copy!(A::$UT, B::$T)
561594
for dind in diagind(A, IndexStyle(A))
562595
if A[dind] != B[dind]
563-
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
596+
throw_nonuniterror(typeof(A), B[dind], Tuple(dind)...)
564597
end
565598
end
566599
_copy!($T(parent(A)), B)
@@ -740,7 +773,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
740773
checksize1(A, B)
741774
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
742775
for j in axes(B.data,2)
743-
@inbounds _modify!(_add, c, A, (j,j))
776+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
744777
for i in firstindex(B.data,1):(j - 1)
745778
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
746779
end
@@ -751,7 +784,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
751784
checksize1(A, B)
752785
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
753786
for j in axes(B.data,2)
754-
@inbounds _modify!(_add, c, A, (j,j))
787+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
755788
for i in firstindex(B.data,1):(j - 1)
756789
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
757790
end
@@ -782,7 +815,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
782815
checksize1(A, B)
783816
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
784817
for j in axes(B.data,2)
785-
@inbounds _modify!(_add, c, A, (j,j))
818+
@inbounds _modify!(_add, B[BandIndex(0,j)] *c, A, (j,j))
786819
for i in (j + 1):lastindex(B.data,1)
787820
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
788821
end
@@ -793,7 +826,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
793826
checksize1(A, B)
794827
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
795828
for j in axes(B.data,2)
796-
@inbounds _modify!(_add, c, A, (j,j))
829+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
797830
for i in (j + 1):lastindex(B.data,1)
798831
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
799832
end
@@ -803,36 +836,52 @@ end
803836

804837
function _trirdiv!(A::UpperTriangular, B::UpperOrUnitUpperTriangular, c::Number)
805838
checksize1(A, B)
839+
isunit = B isa UnitUpperTriangular
806840
for j in axes(B,2)
807-
for i in firstindex(B,1):j
808-
@inbounds A[i, j] = B[i, j] / c
841+
for i in firstindex(B,1):j-isunit
842+
@inbounds A.data[i, j] = B.data[i, j] / c
843+
end
844+
if isunit
845+
@inbounds A.data[j, j] = B[BandIndex(0,j)] / c
809846
end
810847
end
811848
return A
812849
end
813850
function _trirdiv!(A::LowerTriangular, B::LowerOrUnitLowerTriangular, c::Number)
814851
checksize1(A, B)
852+
isunit = B isa UnitLowerTriangular
815853
for j in axes(B,2)
816-
for i in j:lastindex(B,1)
817-
@inbounds A[i, j] = B[i, j] / c
854+
if isunit
855+
@inbounds A.data[j, j] = B[BandIndex(0,j)] / c
856+
end
857+
for i in j+isunit:lastindex(B,1)
858+
@inbounds A.data[i, j] = B.data[i, j] / c
818859
end
819860
end
820861
return A
821862
end
822863
function _trildiv!(A::UpperTriangular, c::Number, B::UpperOrUnitUpperTriangular)
823864
checksize1(A, B)
865+
isunit = B isa UnitUpperTriangular
824866
for j in axes(B,2)
825-
for i in firstindex(B,1):j
826-
@inbounds A[i, j] = c \ B[i, j]
867+
for i in firstindex(B,1):j-isunit
868+
@inbounds A.data[i, j] = c \ B.data[i, j]
869+
end
870+
if isunit
871+
@inbounds A.data[j, j] = c \ B[BandIndex(0,j)]
827872
end
828873
end
829874
return A
830875
end
831876
function _trildiv!(A::LowerTriangular, c::Number, B::LowerOrUnitLowerTriangular)
832877
checksize1(A, B)
878+
isunit = B isa UnitLowerTriangular
833879
for j in axes(B,2)
834-
for i in j:lastindex(B,1)
835-
@inbounds A[i, j] = c \ B[i, j]
880+
if isunit
881+
@inbounds A.data[j, j] = c \ B[BandIndex(0,j)]
882+
end
883+
for i in j+isunit:lastindex(B,1)
884+
@inbounds A.data[i, j] = c \ B.data[i, j]
836885
end
837886
end
838887
return A

test/triangular.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,4 +944,51 @@ end
944944
@test 2*U == 2*M
945945
end
946946

947+
@testset "indexing checks" begin
948+
@testset "getindex" begin
949+
U = UnitUpperTriangular(P)
950+
@test_throws BoundsError U[0,0]
951+
@test_throws BoundsError U[1,0]
952+
@test_throws BoundsError U[BandIndex(0,0)]
953+
@test_throws BoundsError U[BandIndex(-1,0)]
954+
955+
U = UpperTriangular(P)
956+
@test_throws BoundsError U[1,0]
957+
@test_throws BoundsError U[BandIndex(-1,0)]
958+
959+
L = UnitLowerTriangular(P)
960+
@test_throws BoundsError L[0,0]
961+
@test_throws BoundsError L[0,1]
962+
@test_throws BoundsError U[BandIndex(0,0)]
963+
@test_throws BoundsError U[BandIndex(1,0)]
964+
965+
L = LowerTriangular(P)
966+
@test_throws BoundsError L[0,1]
967+
@test_throws BoundsError L[BandIndex(1,0)]
968+
end
969+
@testset "setindex!" begin
970+
P = [1 2; 3 4]
971+
A = SizedArrays.SizedArray{(2,2)}(P)
972+
M = fill(A, 2, 2)
973+
U = UnitUpperTriangular(M)
974+
@test_throws "Cannot `convert` an object of type Int64" U[1,1] = 1
975+
L = UnitLowerTriangular(M)
976+
@test_throws "Cannot `convert` an object of type Int64" L[1,1] = 1
977+
978+
U = UnitUpperTriangular(P)
979+
@test_throws BoundsError U[0,0] = 1
980+
@test_throws BoundsError U[1,0] = 0
981+
982+
U = UpperTriangular(P)
983+
@test_throws BoundsError U[1,0] = 0
984+
985+
L = UnitLowerTriangular(P)
986+
@test_throws BoundsError L[0,0] = 1
987+
@test_throws BoundsError L[0,1] = 0
988+
989+
L = LowerTriangular(P)
990+
@test_throws BoundsError L[0,1] = 0
991+
end
992+
end
993+
947994
end # module TestTriangular

0 commit comments

Comments
 (0)