Skip to content

Commit e451927

Browse files
authored
Replace MulAddMul by alpha,beta in __muldiag (#56360)
This PR replaces `MulAddMul` arguments with `alpha, beta` pairs in the multiplication methods involving `Diagonal` matrices, and constructs the `MulAddMul` objects exactly where they are required. This approach improves latency.

```julia
julia> D = Diagonal(1:2000); A = rand(size(D)...); C = similar(A);

julia> @time mul!(C, A, D, 1, 2); # first-run latency is reduced
  0.129741 seconds (180.18 k allocations: 9.607 MiB, 88.87% compilation time) # nightly v"1.12.0-DEV.1505"
  0.083005 seconds (146.68 k allocations: 7.442 MiB, 82.94% compilation time) # this PR

julia> @btime mul!($C, $A, $D, 1, 2); # runtime performance is unaffected
  4.983 ms (0 allocations: 0 bytes) # nightly
  4.938 ms (0 allocations: 0 bytes) # this PR
```

This PR sets the stage for a similar change for `Bidiagonal`/`Tridiagonal` matrices, which would lead to a bigger reduction in latencies.
1 parent 782490a commit e451927

File tree

2 files changed

+63
-53
lines changed

2 files changed

+63
-53
lines changed

src/bidiag.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,14 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
472472
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
473473
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) =
474474
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
475-
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) =
476-
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
477-
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
478-
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
475+
for T in (:AbstractMatrix, :Diagonal)
476+
@eval begin
477+
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::$T, alpha::Number, beta::Number) =
478+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
479+
@inline _mul!(C::AbstractMatrix, A::$T, B::BandedMatrix, alpha::Number, beta::Number) =
480+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
481+
end
482+
end
479483
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
480484
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
481485

@@ -831,6 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
831835
C
832836
end
833837

838+
_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) =
839+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
834840
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
835841
require_one_based_indexing(C)
836842
check_A_mul_B!_sizes(size(C), size(A), size(B))
@@ -1067,6 +1073,8 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10671073
C
10681074
end
10691075

1076+
_mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) =
1077+
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
10701078
_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul) =
10711079
_dibimul!(C, A, B, _add)
10721080
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =

src/diagonal.jl

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,13 @@ function lmul!(D::Diagonal, T::Tridiagonal)
397397
return T
398398
end
399399

400-
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
400+
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
401401
@inbounds for j in axes(B, 2)
402402
@simd for i in axes(B, 1)
403-
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
403+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
404404
end
405405
end
406-
out
406+
return out
407407
end
408408
_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
409409
_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
@@ -418,116 +418,118 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
418418
end
419419
_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
420420
_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1
421-
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
421+
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, alpha::Number, beta::Number)
422422
isunit = B isa UnitUpperOrUnitLowerTriangular
423423
out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B)
424424
for j in axes(B, 2)
425425
# store the diagonal separately for unit triangular matrices
426426
if isunit
427-
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
427+
@inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j))
428428
end
429429
# The indices of out corresponding to the stored indices of B
430430
rowrange = _rowrange_tri_stored(B, j)
431431
@inbounds @simd for i in rowrange
432-
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
432+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
433433
end
434434
# Fill the indices of out corresponding to the zeros of B
435435
# we only fill these if out and B don't have matching zeros
436436
if !_has_matching_zeros(out, B)
437437
rowrange = _rowrange_tri_zeros(B, j)
438438
@inbounds @simd for i in rowrange
439-
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
439+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
440440
end
441441
end
442442
end
443443
return out
444444
end
445445

446-
@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
447-
beta = _add.beta
448-
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
446+
@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number)
449447
@inbounds for j in axes(A, 2)
450-
dja = _add(D.diag[j])
448+
dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j])
451449
@simd for i in axes(A, 1)
452-
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
450+
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
453451
end
454452
end
455-
out
453+
return out
454+
end
455+
456+
function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number)
457+
__muldiag_nonzeroalpha_right!(out, A, D, alpha, beta)
456458
end
457-
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
459+
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number)
458460
isunit = A isa UnitUpperOrUnitLowerTriangular
459-
beta = _add.beta
460-
# since alpha is multiplied to the diagonal element of D,
461-
# we may skip alpha in the second multiplication by setting ais1 to true
462-
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
463461
# if both A and out have the same upper/lower triangular structure,
464462
# we may directly read and write from the parents
465-
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
463+
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
466464
for j in axes(A, 2)
467-
dja = _add(@inbounds D.diag[j])
465+
dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j])
468466
# store the diagonal separately for unit triangular matrices
469467
if isunit
470-
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
468+
# since alpha is multiplied to the diagonal element of D,
469+
# we may skip alpha in the second multiplication by setting ais1 to true
470+
@inbounds @stable_muladdmul _modify!(MulAddMul(true,beta), A[j,j] * dja, out, (j,j))
471471
end
472472
# indices of out corresponding to the stored indices of A
473473
rowrange = _rowrange_tri_stored(A, j)
474474
@inbounds @simd for i in rowrange
475-
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
475+
# since alpha is multiplied to the diagonal element of D,
476+
# we may skip alpha in the second multiplication by setting ais1 to true
477+
@stable_muladdmul _modify!(MulAddMul(true,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
476478
end
477479
# Fill the indices of out corresponding to the zeros of A
478480
# we only fill these if out and A don't have matching zeros
479481
if !_has_matching_zeros(out, A)
480482
rowrange = _rowrange_tri_zeros(A, j)
481483
@inbounds @simd for i in rowrange
482-
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
484+
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
483485
end
484486
end
485487
end
486-
out
488+
return out
489+
end
490+
491+
# ambiguity resolution
492+
function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number)
493+
__muldiag_nonzeroalpha_right!(out, D1, D2, alpha, beta)
487494
end
488495

489-
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
496+
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number)
490497
d1 = D1.diag
491498
d2 = D2.diag
492499
outd = out.diag
493500
@inbounds @simd for i in eachindex(d1, d2, outd)
494-
_modify!(_add, d1[i] * d2[i], outd, i)
501+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i)
495502
end
496-
out
497-
end
498-
499-
# ambiguity resolution
500-
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
501-
@inbounds for j in axes(D2, 2), i in axes(D2, 1)
502-
_modify!(_add, D1.diag[i] * D2[i,j], out, (i,j))
503-
end
504-
out
503+
return out
505504
end
506505

507-
# muldiag mainly handles the zero-alpha case, so that we need only
506+
# muldiag handles the zero-alpha case, so that we need only
508507
# specialize the non-trivial case
509-
function _mul_diag!(out, A, B, _add)
508+
function _mul_diag!(out, A, B, alpha, beta)
510509
require_one_based_indexing(out, A, B)
511510
_muldiag_size_check(size(out), size(A), size(B))
512-
alpha, beta = _add.alpha, _add.beta
513511
if iszero(alpha)
514512
_rmul_or_fill!(out, beta)
515513
else
516-
__muldiag_nonzeroalpha!(out, A, B, _add)
514+
__muldiag_nonzeroalpha!(out, A, B, alpha, beta)
517515
end
518516
return out
519517
end
520518

521-
_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) =
522-
_mul_diag!(out, D, V, _add)
523-
_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) =
524-
_mul_diag!(out, D, B, _add)
525-
_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) =
526-
_mul_diag!(out, A, D, _add)
527-
_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) =
528-
_mul_diag!(C, Da, Db, _add)
529-
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
530-
_mul_diag!(C, Da, Db, _add)
519+
_mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
520+
_mul_diag!(out, D, V, alpha, beta)
521+
_mul!(out::AbstractMatrix, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
522+
_mul_diag!(out, D, V, alpha, beta)
523+
for MT in (:AbstractMatrix, :AbstractTriangular)
524+
@eval begin
525+
_mul!(out::AbstractMatrix, D::Diagonal, B::$MT, alpha::Number, beta::Number) =
526+
_mul_diag!(out, D, B, alpha, beta)
527+
_mul!(out::AbstractMatrix, A::$MT, D::Diagonal, alpha::Number, beta::Number) =
528+
_mul_diag!(out, A, D, alpha, beta)
529+
end
530+
end
531+
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
532+
_mul_diag!(C, Da, Db, alpha, beta)
531533

532534
function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
533535
_muldiag_size_check(size(Da), size(A))

0 commit comments

Comments
 (0)