@@ -397,13 +397,13 @@ function lmul!(D::Diagonal, T::Tridiagonal)
397
397
return T
398
398
end
399
399
400
- @inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B, _add :: MulAddMul )
400
+ @inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B, alpha :: Number , beta :: Number )
401
401
@inbounds for j in axes (B, 2 )
402
402
@simd for i in axes (B, 1 )
403
- _modify! (_add , D. diag[i] * B[i,j], out, (i,j))
403
+ @stable_muladdmul _modify! (MulAddMul (alpha,beta) , D. diag[i] * B[i,j], out, (i,j))
404
404
end
405
405
end
406
- out
406
+ return out
407
407
end
408
408
_has_matching_zeros (out:: UpperOrUnitUpperTriangular , A:: UpperOrUnitUpperTriangular ) = true
409
409
_has_matching_zeros (out:: LowerOrUnitLowerTriangular , A:: LowerOrUnitLowerTriangular ) = true
@@ -418,116 +418,118 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
418
418
end
419
419
_rowrange_tri_zeros (B:: UpperOrUnitUpperTriangular , col) = col+ 1 : size (B,1 )
420
420
_rowrange_tri_zeros (B:: LowerOrUnitLowerTriangular , col) = 1 : col- 1
421
- function __muldiag_nonzeroalpha! (out, D:: Diagonal , B:: UpperOrLowerTriangular , _add :: MulAddMul )
421
+ function __muldiag_nonzeroalpha! (out, D:: Diagonal , B:: UpperOrLowerTriangular , alpha :: Number , beta :: Number )
422
422
isunit = B isa UnitUpperOrUnitLowerTriangular
423
423
out_maybeparent, B_maybeparent = _has_matching_zeros (out, B) ? (parent (out), parent (B)) : (out, B)
424
424
for j in axes (B, 2 )
425
425
# store the diagonal separately for unit triangular matrices
426
426
if isunit
427
- @inbounds _modify! (_add , D. diag[j] * B[j,j], out, (j,j))
427
+ @inbounds @stable_muladdmul _modify! (MulAddMul (alpha,beta) , D. diag[j] * B[j,j], out, (j,j))
428
428
end
429
429
# The indices of out corresponding to the stored indices of B
430
430
rowrange = _rowrange_tri_stored (B, j)
431
431
@inbounds @simd for i in rowrange
432
- _modify! (_add , D. diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
432
+ @stable_muladdmul _modify! (MulAddMul (alpha,beta) , D. diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
433
433
end
434
434
# Fill the indices of out corresponding to the zeros of B
435
435
# we only fill these if out and B don't have matching zeros
436
436
if ! _has_matching_zeros (out, B)
437
437
rowrange = _rowrange_tri_zeros (B, j)
438
438
@inbounds @simd for i in rowrange
439
- _modify! (_add , D. diag[i] * B[i,j], out, (i,j))
439
+ @stable_muladdmul _modify! (MulAddMul (alpha,beta) , D. diag[i] * B[i,j], out, (i,j))
440
440
end
441
441
end
442
442
end
443
443
return out
444
444
end
445
445
446
- @inline function __muldiag_nonzeroalpha! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
447
- beta = _add. beta
448
- _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
446
+ @inline function __muldiag_nonzeroalpha_right! (out, A, D:: Diagonal , alpha:: Number , beta:: Number )
449
447
@inbounds for j in axes (A, 2 )
450
- dja = _add (D. diag[j])
448
+ dja = @stable_muladdmul MulAddMul (alpha, false ) (D. diag[j])
451
449
@simd for i in axes (A, 1 )
452
- _modify! (_add_aisone , A[i,j] * dja, out, (i,j))
450
+ @stable_muladdmul _modify! (MulAddMul ( true ,beta) , A[i,j] * dja, out, (i,j))
453
451
end
454
452
end
455
- out
453
+ return out
454
+ end
455
+
456
+ function __muldiag_nonzeroalpha! (out, A, D:: Diagonal , alpha:: Number , beta:: Number )
457
+ __muldiag_nonzeroalpha_right! (out, A, D, alpha, beta)
456
458
end
457
- function __muldiag_nonzeroalpha! (out, A:: UpperOrLowerTriangular , D:: Diagonal , _add :: MulAddMul{ais1,bis0} ) where {ais1,bis0}
459
+ function __muldiag_nonzeroalpha! (out, A:: UpperOrLowerTriangular , D:: Diagonal , alpha :: Number , beta :: Number )
458
460
isunit = A isa UnitUpperOrUnitLowerTriangular
459
- beta = _add. beta
460
- # since alpha is multiplied to the diagonal element of D,
461
- # we may skip alpha in the second multiplication by setting ais1 to true
462
- _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
463
461
# if both A and out have the same upper/lower triangular structure,
464
462
# we may directly read and write from the parents
465
- out_maybeparent, A_maybeparent = _has_matching_zeros (out, A) ? (parent (out), parent (A)) : (out, A)
463
+ out_maybeparent, A_maybeparent = _has_matching_zeros (out, A) ? (parent (out), parent (A)) : (out, A)
466
464
for j in axes (A, 2 )
467
- dja = _add (@inbounds D. diag[j])
465
+ dja = @stable_muladdmul MulAddMul (alpha, false ) (@inbounds D. diag[j])
468
466
# store the diagonal separately for unit triangular matrices
469
467
if isunit
470
- @inbounds _modify! (_add_aisone, A[j,j] * dja, out, (j,j))
468
+ # since alpha is multiplied to the diagonal element of D,
469
+ # we may skip alpha in the second multiplication by setting ais1 to true
470
+ @inbounds @stable_muladdmul _modify! (MulAddMul (true ,beta), A[j,j] * dja, out, (j,j))
471
471
end
472
472
# indices of out corresponding to the stored indices of A
473
473
rowrange = _rowrange_tri_stored (A, j)
474
474
@inbounds @simd for i in rowrange
475
- _modify! (_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
475
+ # since alpha is multiplied to the diagonal element of D,
476
+ # we may skip alpha in the second multiplication by setting ais1 to true
477
+ @stable_muladdmul _modify! (MulAddMul (true ,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
476
478
end
477
479
# Fill the indices of out corresponding to the zeros of A
478
480
# we only fill these if out and A don't have matching zeros
479
481
if ! _has_matching_zeros (out, A)
480
482
rowrange = _rowrange_tri_zeros (A, j)
481
483
@inbounds @simd for i in rowrange
482
- _modify! (_add_aisone , A[i,j] * dja, out, (i,j))
484
+ @stable_muladdmul _modify! (MulAddMul ( true ,beta) , A[i,j] * dja, out, (i,j))
483
485
end
484
486
end
485
487
end
486
- out
488
+ return out
489
+ end
490
+
491
+ # ambiguity resolution
492
+ function __muldiag_nonzeroalpha! (out, D1:: Diagonal , D2:: Diagonal , alpha:: Number , beta:: Number )
493
+ __muldiag_nonzeroalpha_right! (out, D1, D2, alpha, beta)
487
494
end
488
495
489
- @inline function __muldiag_nonzeroalpha! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add :: MulAddMul )
496
+ @inline function __muldiag_nonzeroalpha! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , alpha :: Number , beta :: Number )
490
497
d1 = D1. diag
491
498
d2 = D2. diag
492
499
outd = out. diag
493
500
@inbounds @simd for i in eachindex (d1, d2, outd)
494
- _modify! (_add , d1[i] * d2[i], outd, i)
501
+ @stable_muladdmul _modify! (MulAddMul (alpha,beta) , d1[i] * d2[i], outd, i)
495
502
end
496
- out
497
- end
498
-
499
- # ambiguity resolution
500
- @inline function __muldiag_nonzeroalpha! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
501
- @inbounds for j in axes (D2, 2 ), i in axes (D2, 1 )
502
- _modify! (_add, D1. diag[i] * D2[i,j], out, (i,j))
503
- end
504
- out
503
+ return out
505
504
end
506
505
507
- # muldiag mainly handles the zero-alpha case, so that we need only
506
+ # muldiag handles the zero-alpha case, so that we need only
508
507
# specialize the non-trivial case
509
- function _mul_diag! (out, A, B, _add )
508
+ function _mul_diag! (out, A, B, alpha, beta )
510
509
require_one_based_indexing (out, A, B)
511
510
_muldiag_size_check (size (out), size (A), size (B))
512
- alpha, beta = _add. alpha, _add. beta
513
511
if iszero (alpha)
514
512
_rmul_or_fill! (out, beta)
515
513
else
516
- __muldiag_nonzeroalpha! (out, A, B, _add )
514
+ __muldiag_nonzeroalpha! (out, A, B, alpha, beta )
517
515
end
518
516
return out
519
517
end
520
518
521
- _mul! (out:: AbstractVecOrMat , D:: Diagonal , V:: AbstractVector , _add) =
522
- _mul_diag! (out, D, V, _add)
523
- _mul! (out:: AbstractMatrix , D:: Diagonal , B:: AbstractMatrix , _add) =
524
- _mul_diag! (out, D, B, _add)
525
- _mul! (out:: AbstractMatrix , A:: AbstractMatrix , D:: Diagonal , _add) =
526
- _mul_diag! (out, A, D, _add)
527
- _mul! (C:: Diagonal , Da:: Diagonal , Db:: Diagonal , _add) =
528
- _mul_diag! (C, Da, Db, _add)
529
- _mul! (C:: AbstractMatrix , Da:: Diagonal , Db:: Diagonal , _add) =
530
- _mul_diag! (C, Da, Db, _add)
519
+ _mul! (out:: AbstractVector , D:: Diagonal , V:: AbstractVector , alpha:: Number , beta:: Number ) =
520
+ _mul_diag! (out, D, V, alpha, beta)
521
+ _mul! (out:: AbstractMatrix , D:: Diagonal , V:: AbstractVector , alpha:: Number , beta:: Number ) =
522
+ _mul_diag! (out, D, V, alpha, beta)
523
+ for MT in (:AbstractMatrix , :AbstractTriangular )
524
+ @eval begin
525
+ _mul! (out:: AbstractMatrix , D:: Diagonal , B:: $MT , alpha:: Number , beta:: Number ) =
526
+ _mul_diag! (out, D, B, alpha, beta)
527
+ _mul! (out:: AbstractMatrix , A:: $MT , D:: Diagonal , alpha:: Number , beta:: Number ) =
528
+ _mul_diag! (out, A, D, alpha, beta)
529
+ end
530
+ end
531
+ _mul! (C:: AbstractMatrix , Da:: Diagonal , Db:: Diagonal , alpha:: Number , beta:: Number ) =
532
+ _mul_diag! (C, Da, Db, alpha, beta)
531
533
532
534
function (* )(Da:: Diagonal , A:: AbstractMatrix , Db:: Diagonal )
533
535
_muldiag_size_check (size (Da), size (A))
0 commit comments