Skip to content

Commit de9c371

Browse files
authored
check lengths in covector-vector products (#36679)
1 parent 0299027 commit de9c371

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,23 @@ Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...)
267267

268268
## multiplication *
269269

270+
function _dot_nonrecursive(u, v)
271+
lu = length(u)
272+
if lu != length(v)
273+
throw(DimensionMismatch("first array has length $(lu) which does not match the length of the second, $(length(v))."))
274+
end
275+
if lu == 0
276+
zero(eltype(u)) * zero(eltype(v))
277+
else
278+
sum(uu*vv for (uu, vv) in zip(u, v))
279+
end
280+
end
281+
270282
# Adjoint/Transpose-vector * vector
271-
*(u::AdjointAbsVec{T}, v::AbstractVector{T}) where {T<:Number} = dot(u.parent, v)
283+
*(u::AdjointAbsVec{<:Number}, v::AbstractVector{<:Number}) = dot(u.parent, v)
272284
*(u::TransposeAbsVec{T}, v::AbstractVector{T}) where {T<:Real} = dot(u.parent, v)
273-
*(u::AdjOrTransAbsVec, v::AbstractVector) = sum(uu*vv for (uu, vv) in zip(u, v))
285+
*(u::AdjOrTransAbsVec, v::AbstractVector) = _dot_nonrecursive(u, v)
286+
274287

275288
# vector * Adjoint/Transpose-vector
276289
*(u::AbstractVector, v::AdjOrTransAbsVec) = broadcast(*, u, v)
@@ -281,14 +294,10 @@ Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...)
281294

282295
# AdjOrTransAbsVec{<:Any,<:AdjOrTransAbsVec} is a lazy conj vectors
283296
# We need to expand the combinations to avoid ambiguities
284-
(*)(u::TransposeAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) =
285-
sum(uu*vv for (uu, vv) in zip(u, v))
286-
(*)(u::AdjointAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) =
287-
sum(uu*vv for (uu, vv) in zip(u, v))
288-
(*)(u::TransposeAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) =
289-
sum(uu*vv for (uu, vv) in zip(u, v))
290-
(*)(u::AdjointAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) =
291-
sum(uu*vv for (uu, vv) in zip(u, v))
297+
(*)(u::TransposeAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) = _dot_nonrecursive(u, v)
298+
(*)(u::AdjointAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) = _dot_nonrecursive(u, v)
299+
(*)(u::TransposeAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) = _dot_nonrecursive(u, v)
300+
(*)(u::AdjointAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) = _dot_nonrecursive(u, v)
292301

293302
## pseudoinversion
294303
pinv(v::AdjointAbsVec, tol::Real = 0) = pinv(v.parent, tol).parent

stdlib/LinearAlgebra/test/adjtrans.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,4 +550,11 @@ end
550550
@test conj(transpose(hermitian)) === hermitian
551551
end
552552

553+
@testset "empty and mismatched lengths" begin
554+
# issue 36678
555+
@test_throws DimensionMismatch [1, 2]' * [1,2,3]
556+
@test Int[]' * Int[] == 0
557+
@test transpose(Int[]) * Int[] == 0
558+
end
559+
553560
end # module TestAdjointTranspose

0 commit comments

Comments
 (0)