Skip to content

Commit 03b2ce8

Browse files
authored
Merge pull request #555 from zygmuntszpak/dispatch_rules
Adds dispatch rules for column vector times transpose(row vector)
2 parents 53a19c3 + 1ebd02b commit 03b2ce8

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/matrix_multiply.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1010
@inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B)
1111
@inline *(A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B)
1212
@inline *(A::StaticVector, B::Adjoint{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B)
13+
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
14+
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
1315

1416
@inline mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVector) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
1517
@inline mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticMatrix) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
1618
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::StaticMatrix) = mul!(dest, reshape(A, Size(Size(A)[1], 1)), B)
1719
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
1820
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Adjoint{<:Any, <:StaticVector}) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
19-
2021
#@inline *{TA<:LinearAlgebra.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
2122

2223

24+
2325
# Implementations
2426

2527
@generated function _mul(::Size{sa}, a::StaticMatrix{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}

test/matrix_multiply.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,41 @@ using StaticArrays, Test, LinearAlgebra
7777
@test v4 * transpose(v5) === @SMatrix [3+0im 4+0im; 6+0im 8+0im]
7878
end
7979

80+
@testset "Column vector-vector" begin
81+
cv_array = rand(4,1)
82+
rv_array = rand(4)
83+
a_array = cv_array * rv_array'
84+
85+
cv = SMatrix{4,1}(cv_array)
86+
rv = SVector{4}(rv_array)
87+
@test (cv*adjoint(rv))::SMatrix a_array
88+
89+
cv = MMatrix{4,1}(cv_array)
90+
rv = MVector{4}(rv_array)
91+
@test (cv*adjoint(rv))::SMatrix a_array
92+
93+
cv = SMatrix{4,1}(cv_array)
94+
rv = SVector{4}(rv_array)
95+
@test (cv*transpose(rv))::SMatrix a_array
96+
97+
cv = MMatrix{4,1}(cv_array)
98+
rv = MVector{4}(rv_array)
99+
@test (cv*transpose(rv))::SMatrix a_array
100+
101+
cv_bad = @SMatrix rand(4,2)
102+
rv = @SVector rand(4)
103+
@test_throws DimensionMismatch cv_bad*transpose(rv)
104+
@test_throws DimensionMismatch cv_bad*adjoint(rv)
105+
@test_throws DimensionMismatch cv_bad*rv
106+
107+
cv_bad = @MMatrix rand(4,2)
108+
rv = @MVector rand(4)
109+
110+
@test_throws DimensionMismatch cv_bad*transpose(rv)
111+
@test_throws DimensionMismatch cv_bad*adjoint(rv)
112+
@test_throws DimensionMismatch cv_bad*rv
113+
end
114+
80115
@testset "Matrix-matrix" begin
81116
m = @SMatrix [1 2; 3 4]
82117
n = @SMatrix [2 3; 4 5]

0 commit comments

Comments
 (0)