Skip to content

Commit 710cf28

Browse files
authored
Diagonal: Check dimensions when multiplying with vector (#36841)
1 parent f130d9b commit 710cf28

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,22 @@ end
174174
(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
175175
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
176176
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
177-
(*)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .* Db.diag)
178-
(*)(D::Diagonal, V::AbstractVector) = D.diag .* V
177+
178+
function (*)(Da::Diagonal, Db::Diagonal)
179+
nDa, mDb = size(Da, 2), size(Db, 1)
180+
if nDa != mDb
181+
throw(DimensionMismatch("second dimension of Da, $nDa, does not match first dimension of Db, $mDb"))
182+
end
183+
return Diagonal(Da.diag .* Db.diag)
184+
end
185+
186+
function (*)(D::Diagonal, V::AbstractVector)
187+
nD = size(D, 2)
188+
if nD != length(V)
189+
throw(DimensionMismatch("second dimension of D, $nD, does not match length of V, $(length(V))"))
190+
end
191+
return D.diag .* V
192+
end
179193

180194
(*)(A::AbstractTriangular, D::Diagonal) =
181195
rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A), D)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,15 @@ end
644644
@test yt*D*y == (yt*D)*y == (yt*A)*y
645645
end
646646

647+
@testset "Multiplication of single element Diagonal (#36746)" begin
648+
@test_throws DimensionMismatch Diagonal(randn(1)) * randn(5)
649+
@test_throws DimensionMismatch Diagonal(randn(1)) * Diagonal(randn(3, 3))
650+
A = [1 0; 0 2]
651+
v = [3, 4]
652+
@test Diagonal(A) * v == A * v
653+
@test Diagonal(A) * Diagonal(A) == A * A
654+
end
655+
647656
@testset "Triangular division by Diagonal #27989" begin
648657
K = 5
649658
for elty in (Float32, Float64, ComplexF32, ComplexF64)

0 commit comments

Comments
 (0)