@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238
238
Base. isstored (A:: UpperOrLowerTriangular , i:: Int , j:: Int ) =
239
239
_shouldforwardindex (A, i, j) ? Base. isstored (A. data, i, j) : false
240
240
241
- @propagate_inbounds getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T} =
242
- _shouldforwardindex (A, i, j) ? A. data[i,j] : ifelse (i == j, oneunit (T), zero (T))
243
- @propagate_inbounds getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int ) =
244
- _shouldforwardindex (A, i, j) ? A. data[i,j] : diagzero (A,i,j)
241
+ @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T}
242
+ if _shouldforwardindex (A, i, j)
243
+ A. data[i,j]
244
+ else
245
+ @boundscheck checkbounds (A, i, j)
246
+ ifelse (i == j, oneunit (T), zero (T))
247
+ end
248
+ end
249
+ @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int )
250
+ if _shouldforwardindex (A, i, j)
251
+ A. data[i,j]
252
+ else
253
+ @boundscheck checkbounds (A, i, j)
254
+ @inbounds diagzero (A,i,j)
255
+ end
256
+ end
245
257
246
258
_shouldforwardindex (U:: UpperTriangular , b:: BandIndex ) = b. band >= 0
247
259
_shouldforwardindex (U:: LowerTriangular , b:: BandIndex ) = b. band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250
262
251
263
# these specialized getindex methods enable constant-propagation of the band
252
264
Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , b:: BandIndex ) where {T}
253
- _shouldforwardindex (A, b) ? A. data[b] : ifelse (b. band == 0 , oneunit (T), zero (T))
265
+ if _shouldforwardindex (A, b)
266
+ A. data[b]
267
+ else
268
+ @boundscheck checkbounds (A, b)
269
+ ifelse (b. band == 0 , oneunit (T), zero (T))
270
+ end
254
271
end
255
272
Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , b:: BandIndex )
256
- _shouldforwardindex (A, b) ? A. data[b] : diagzero (A. data, b)
273
+ if _shouldforwardindex (A, b)
274
+ A. data[b]
275
+ else
276
+ @boundscheck checkbounds (A, b)
277
+ @inbounds diagzero (A, b)
278
+ end
257
279
end
258
280
259
281
_zero_triangular_half_str (:: Type{<:UpperOrUnitUpperTriangular} ) = " lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265
287
throw (ArgumentError (
266
288
lazy " cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)" ))
267
289
end
268
- @noinline function throw_nononeerror (T, @nospecialize (x), i, j)
290
+ @noinline function throw_nonuniterror (T, @nospecialize (x), i, j)
291
+ check_compatible_type (T, x)
269
292
Tn = nameof (T)
270
293
throw (ArgumentError (
271
294
lazy " cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)" ))
272
295
end
296
+ function check_compatible_type (T, @nospecialize (x))
297
+ ET = eltype (T)
298
+ convert (ET, x) # check that the types are compatible with setindex!
299
+ end
273
300
274
301
@propagate_inbounds function setindex! (A:: UpperTriangular , x, i:: Integer , j:: Integer )
275
302
if i > j
303
+ @boundscheck checkbounds (A, i, j)
276
304
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
277
305
else
278
306
A. data[i,j] = x
282
310
283
311
@propagate_inbounds function setindex! (A:: UnitUpperTriangular , x, i:: Integer , j:: Integer )
284
312
if i > j
313
+ @boundscheck checkbounds (A, i, j)
285
314
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
286
315
elseif i == j
287
- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
316
+ @boundscheck checkbounds (A, i, j)
317
+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
288
318
else
289
319
A. data[i,j] = x
290
320
end
293
323
294
324
@propagate_inbounds function setindex! (A:: LowerTriangular , x, i:: Integer , j:: Integer )
295
325
if i < j
326
+ @boundscheck checkbounds (A, i, j)
296
327
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
297
328
else
298
329
A. data[i,j] = x
302
333
303
334
@propagate_inbounds function setindex! (A:: UnitLowerTriangular , x, i:: Integer , j:: Integer )
304
335
if i < j
336
+ @boundscheck checkbounds (A, i, j)
305
337
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
306
338
elseif i == j
307
- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
339
+ @boundscheck checkbounds (A, i, j)
340
+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
308
341
else
309
342
A. data[i,j] = x
310
343
end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560
593
@eval @inline function _copy! (A:: $UT , B:: $T )
561
594
for dind in diagind (A, IndexStyle (A))
562
595
if A[dind] != B[dind]
563
- throw_nononeerror (typeof (A), B[dind], Tuple (dind)... )
596
+ throw_nonuniterror (typeof (A), B[dind], Tuple (dind)... )
564
597
end
565
598
end
566
599
_copy! ($ T (parent (A)), B)
@@ -740,7 +773,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
740
773
checksize1 (A, B)
741
774
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
742
775
for j in axes (B. data,2 )
743
- @inbounds _modify! (_add, c, A, (j,j))
776
+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
744
777
for i in firstindex (B. data,1 ): (j - 1 )
745
778
@inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
746
779
end
@@ -751,7 +784,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
751
784
checksize1 (A, B)
752
785
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
753
786
for j in axes (B. data,2 )
754
- @inbounds _modify! (_add, c, A, (j,j))
787
+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
755
788
for i in firstindex (B. data,1 ): (j - 1 )
756
789
@inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
757
790
end
@@ -782,7 +815,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
782
815
checksize1 (A, B)
783
816
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
784
817
for j in axes (B. data,2 )
785
- @inbounds _modify! (_add, c, A, (j,j))
818
+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
786
819
for i in (j + 1 ): lastindex (B. data,1 )
787
820
@inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
788
821
end
@@ -793,7 +826,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
793
826
checksize1 (A, B)
794
827
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
795
828
for j in axes (B. data,2 )
796
- @inbounds _modify! (_add, c, A, (j,j))
829
+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
797
830
for i in (j + 1 ): lastindex (B. data,1 )
798
831
@inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
799
832
end
@@ -803,36 +836,52 @@ end
803
836
804
837
function _trirdiv! (A:: UpperTriangular , B:: UpperOrUnitUpperTriangular , c:: Number )
805
838
checksize1 (A, B)
839
+ isunit = B isa UnitUpperTriangular
806
840
for j in axes (B,2 )
807
- for i in firstindex (B,1 ): j
808
- @inbounds A[i, j] = B[i, j] / c
841
+ for i in firstindex (B,1 ): j- isunit
842
+ @inbounds A. data[i, j] = B. data[i, j] / c
843
+ end
844
+ if isunit
845
+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
809
846
end
810
847
end
811
848
return A
812
849
end
813
850
function _trirdiv! (A:: LowerTriangular , B:: LowerOrUnitLowerTriangular , c:: Number )
814
851
checksize1 (A, B)
852
+ isunit = B isa UnitLowerTriangular
815
853
for j in axes (B,2 )
816
- for i in j: lastindex (B,1 )
817
- @inbounds A[i, j] = B[i, j] / c
854
+ if isunit
855
+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
856
+ end
857
+ for i in j+ isunit: lastindex (B,1 )
858
+ @inbounds A. data[i, j] = B. data[i, j] / c
818
859
end
819
860
end
820
861
return A
821
862
end
822
863
function _trildiv! (A:: UpperTriangular , c:: Number , B:: UpperOrUnitUpperTriangular )
823
864
checksize1 (A, B)
865
+ isunit = B isa UnitUpperTriangular
824
866
for j in axes (B,2 )
825
- for i in firstindex (B,1 ): j
826
- @inbounds A[i, j] = c \ B[i, j]
867
+ for i in firstindex (B,1 ): j- isunit
868
+ @inbounds A. data[i, j] = c \ B. data[i, j]
869
+ end
870
+ if isunit
871
+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
827
872
end
828
873
end
829
874
return A
830
875
end
831
876
function _trildiv! (A:: LowerTriangular , c:: Number , B:: LowerOrUnitLowerTriangular )
832
877
checksize1 (A, B)
878
+ isunit = B isa UnitLowerTriangular
833
879
for j in axes (B,2 )
834
- for i in j: lastindex (B,1 )
835
- @inbounds A[i, j] = c \ B[i, j]
880
+ if isunit
881
+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
882
+ end
883
+ for i in j+ isunit: lastindex (B,1 )
884
+ @inbounds A. data[i, j] = c \ B. data[i, j]
836
885
end
837
886
end
838
887
return A
0 commit comments