@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238
238
Base. isstored (A:: UpperOrLowerTriangular , i:: Int , j:: Int ) =
239
239
_shouldforwardindex (A, i, j) ? Base. isstored (A. data, i, j) : false
240
240
241
- @propagate_inbounds getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T} =
242
- _shouldforwardindex (A, i, j) ? A. data[i,j] : ifelse (i == j, oneunit (T), zero (T))
243
- @propagate_inbounds getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int ) =
244
- _shouldforwardindex (A, i, j) ? A. data[i,j] : diagzero (A,i,j)
241
+ @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T}
242
+ if _shouldforwardindex (A, i, j)
243
+ A. data[i,j]
244
+ else
245
+ @boundscheck checkbounds (A, i, j)
246
+ ifelse (i == j, oneunit (T), zero (T))
247
+ end
248
+ end
249
+ @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int )
250
+ if _shouldforwardindex (A, i, j)
251
+ A. data[i,j]
252
+ else
253
+ @boundscheck checkbounds (A, i, j)
254
+ @inbounds diagzero (A,i,j)
255
+ end
256
+ end
245
257
246
258
_shouldforwardindex (U:: UpperTriangular , b:: BandIndex ) = b. band >= 0
247
259
_shouldforwardindex (U:: LowerTriangular , b:: BandIndex ) = b. band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250
262
251
263
# these specialized getindex methods enable constant-propagation of the band
252
264
Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , b:: BandIndex ) where {T}
253
- _shouldforwardindex (A, b) ? A. data[b] : ifelse (b. band == 0 , oneunit (T), zero (T))
265
+ if _shouldforwardindex (A, b)
266
+ A. data[b]
267
+ else
268
+ @boundscheck checkbounds (A, b)
269
+ ifelse (b. band == 0 , oneunit (T), zero (T))
270
+ end
254
271
end
255
272
Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , b:: BandIndex )
256
- _shouldforwardindex (A, b) ? A. data[b] : diagzero (A. data, b)
273
+ if _shouldforwardindex (A, b)
274
+ A. data[b]
275
+ else
276
+ @boundscheck checkbounds (A, b)
277
+ @inbounds diagzero (A, b)
278
+ end
257
279
end
258
280
259
281
_zero_triangular_half_str (:: Type{<:UpperOrUnitUpperTriangular} ) = " lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265
287
throw (ArgumentError (
266
288
lazy " cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)" ))
267
289
end
268
- @noinline function throw_nononeerror (T, @nospecialize (x), i, j)
290
+ @noinline function throw_nonuniterror (T, @nospecialize (x), i, j)
291
+ check_compatible_type (T, x)
269
292
Tn = nameof (T)
270
293
throw (ArgumentError (
271
294
lazy " cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)" ))
272
295
end
296
+ function check_compatible_type (T, @nospecialize (x))
297
+ ET = eltype (T)
298
+ convert (ET, x) # check that the types are compatible with setindex!
299
+ end
273
300
274
301
@propagate_inbounds function setindex! (A:: UpperTriangular , x, i:: Integer , j:: Integer )
275
302
if i > j
303
+ @boundscheck checkbounds (A, i, j)
276
304
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
277
305
else
278
306
A. data[i,j] = x
282
310
283
311
@propagate_inbounds function setindex! (A:: UnitUpperTriangular , x, i:: Integer , j:: Integer )
284
312
if i > j
313
+ @boundscheck checkbounds (A, i, j)
285
314
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
286
315
elseif i == j
287
- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
316
+ @boundscheck checkbounds (A, i, j)
317
+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
288
318
else
289
319
A. data[i,j] = x
290
320
end
293
323
294
324
@propagate_inbounds function setindex! (A:: LowerTriangular , x, i:: Integer , j:: Integer )
295
325
if i < j
326
+ @boundscheck checkbounds (A, i, j)
296
327
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
297
328
else
298
329
A. data[i,j] = x
302
333
303
334
@propagate_inbounds function setindex! (A:: UnitLowerTriangular , x, i:: Integer , j:: Integer )
304
335
if i < j
336
+ @boundscheck checkbounds (A, i, j)
305
337
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
306
338
elseif i == j
307
- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
339
+ @boundscheck checkbounds (A, i, j)
340
+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
308
341
else
309
342
A. data[i,j] = x
310
343
end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560
593
@eval @inline function _copy! (A:: $UT , B:: $T )
561
594
for dind in diagind (A, IndexStyle (A))
562
595
if A[dind] != B[dind]
563
- throw_nononeerror (typeof (A), B[dind], Tuple (dind)... )
596
+ throw_nonuniterror (typeof (A), B[dind], Tuple (dind)... )
564
597
end
565
598
end
566
599
_copy! ($ T (parent (A)), B)
@@ -741,7 +774,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
741
774
checksize1 (A, B)
742
775
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
743
776
for j in axes (B. data,2 )
744
- @inbounds _modify! (_add, c, A, (j,j))
777
+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
745
778
for i in firstindex (B. data,1 ): (j - 1 )
746
779
@inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
747
780
end
@@ -752,7 +785,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
752
785
checksize1 (A, B)
753
786
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
754
787
for j in axes (B. data,2 )
755
- @inbounds _modify! (_add, c, A, (j,j))
788
+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
756
789
for i in firstindex (B. data,1 ): (j - 1 )
757
790
@inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
758
791
end
@@ -783,7 +816,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
783
816
checksize1 (A, B)
784
817
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
785
818
for j in axes (B. data,2 )
786
- @inbounds _modify! (_add, c, A, (j,j))
819
+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
787
820
for i in (j + 1 ): lastindex (B. data,1 )
788
821
@inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
789
822
end
@@ -794,7 +827,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
794
827
checksize1 (A, B)
795
828
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
796
829
for j in axes (B. data,2 )
797
- @inbounds _modify! (_add, c, A, (j,j))
830
+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
798
831
for i in (j + 1 ): lastindex (B. data,1 )
799
832
@inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
800
833
end
804
837
805
838
function _trirdiv! (A:: UpperTriangular , B:: UpperTriangular , c:: Number )
806
839
checksize1 (A, B)
840
+ isunit = B isa UnitUpperTriangular
807
841
for j in axes (B,2 )
808
842
for i in firstindex (B,1 ): j
809
843
@inbounds A. data[i, j] = B. data[i, j] / c
@@ -813,6 +847,7 @@ function _trirdiv!(A::UpperTriangular, B::UpperTriangular, c::Number)
813
847
end
814
848
function _trirdiv! (A:: LowerTriangular , B:: LowerTriangular , c:: Number )
815
849
checksize1 (A, B)
850
+ isunit = B isa UnitLowerTriangular
816
851
for j in axes (B,2 )
817
852
for i in j: lastindex (B,1 )
818
853
@inbounds A. data[i, j] = B. data[i, j] / c
@@ -822,6 +857,7 @@ function _trirdiv!(A::LowerTriangular, B::LowerTriangular, c::Number)
822
857
end
823
858
function _trildiv! (A:: UpperTriangular , c:: Number , B:: UpperTriangular )
824
859
checksize1 (A, B)
860
+ isunit = B isa UnitUpperTriangular
825
861
for j in axes (B,2 )
826
862
for i in firstindex (B,1 ): j
827
863
@inbounds A. data[i, j] = c \ B. data[i, j]
@@ -831,6 +867,7 @@ function _trildiv!(A::UpperTriangular, c::Number, B::UpperTriangular)
831
867
end
832
868
function _trildiv! (A:: LowerTriangular , c:: Number , B:: LowerTriangular )
833
869
checksize1 (A, B)
870
+ isunit = B isa UnitLowerTriangular
834
871
for j in axes (B,2 )
835
872
for i in j: lastindex (B,1 )
836
873
@inbounds A. data[i, j] = c \ B. data[i, j]
0 commit comments