Skip to content

Commit 092f60b

Browse files
authored
Add lmul! and rmul! for Bidiagonal (#51777)
1 parent fd15da0 commit 092f60b

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

src/bidiag.jl

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,13 +426,17 @@ end
426426

427427
const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or BiDiTriSym
428428
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
429+
const TriSym = Union{Tridiagonal,SymTridiagonal}
429430
const BiTri = Union{Bidiagonal,Tridiagonal}
430431
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
431432
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
432433
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
433434
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
434435
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
435436

437+
lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
438+
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
439+
436440
function check_A_mul_B!_sizes(C, A, B)
437441
mA, nA = size(A)
438442
mB, nB = size(B)
@@ -460,7 +464,11 @@ function _diag(A::Bidiagonal, k)
460464
end
461465
end
462466

463-
function _mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, _add::MulAddMul = MulAddMul())
467+
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul = MulAddMul()) =
468+
_bibimul!(C, A, B, _add)
469+
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
470+
_bibimul!(C, A, B, _add)
471+
function _bibimul!(C, A, B, _add)
464472
check_A_mul_B!_sizes(C, A, B)
465473
n = size(A,1)
466474
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
@@ -583,7 +591,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
583591
C
584592
end
585593

586-
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMul = MulAddMul())
594+
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul = MulAddMul())
587595
require_one_based_indexing(C, A)
588596
check_A_mul_B!_sizes(C, A, B)
589597
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
@@ -618,7 +626,37 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMu
618626
C
619627
end
620628

621-
function _mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, _add::MulAddMul = MulAddMul())
629+
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul = MulAddMul())
630+
require_one_based_indexing(C, A)
631+
check_A_mul_B!_sizes(C, A, B)
632+
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
633+
if size(A, 1) <= 3 || size(B, 2) <= 1
634+
return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
635+
end
636+
m, n = size(A)
637+
@inbounds if B.uplo == 'U'
638+
for i in 1:m
639+
for j in n:-1:2
640+
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
641+
end
642+
_modify!(_add, A[i,1] * B.dv[1], C, (i, 1))
643+
end
644+
else # uplo == 'L'
645+
for i in 1:m
646+
for j in 1:n-1
647+
_modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j))
648+
end
649+
_modify!(_add, A[i,n] * B.dv[n], C, (i, n))
650+
end
651+
end
652+
C
653+
end
654+
655+
_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
656+
_dibimul!(C, A, B, _add)
657+
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul = MulAddMul()) =
658+
_dibimul!(C, A, B, _add)
659+
function _dibimul!(C, A, B, _add)
622660
require_one_based_indexing(C)
623661
check_A_mul_B!_sizes(C, A, B)
624662
n = size(A,1)

test/bidiag.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,14 +439,18 @@ Random.seed!(1)
439439
for op in (+, -, *)
440440
@test Array(op(T, T2)) op(Tfull, Tfull2)
441441
end
442+
A = kron(T.dv, T.dv')
443+
@test T * A lmul!(T, copy(A))
444+
@test A * T rmul!(copy(A), T)
442445
end
443446
# test pass-through of mul! for SymTridiagonal*Bidiagonal
444447
TriSym = SymTridiagonal(T.dv, T.ev)
445448
@test Array(TriSym*T) Array(TriSym)*Array(T)
446449
# test pass-through of mul! for AbstractTriangular*Bidiagonal
447450
Tri = UpperTriangular(diagm(1 => T.ev))
448451
Dia = Diagonal(T.dv)
449-
@test Array(Tri*T) Array(Tri)*Array(T)
452+
@test Array(Tri*T) Array(Tri)*Array(T) rmul!(copy(Tri), T)
453+
@test Array(T*Tri) Array(T)*Array(Tri) lmul!(T, copy(Tri))
450454
# test mul! itself for these types
451455
for AA in (Tri, Dia)
452456
for f in (identity, transpose, adjoint)
@@ -459,8 +463,10 @@ Random.seed!(1)
459463
for f in (identity, transpose, adjoint)
460464
C = relty == Int ? rand(float(elty), n, n) : rand(elty, n, n)
461465
B = rand(elty, n, n)
462-
D = copy(C) + 2.0 * Array(T*f(B))
463-
mul!(C, T, f(B), 2.0, 1.0) D
466+
D = C + 2.0 * Array(T*f(B))
467+
@test mul!(C, T, f(B), 2.0, 1.0) D
468+
@test lmul!(T, copy(f(B))) T * f(B)
469+
@test rmul!(copy(f(B)), T) f(B) * T
464470
end
465471

466472
# Issue #31870

0 commit comments

Comments
 (0)