Skip to content

Commit 836f874

Browse files
authored
Merge pull request #140 from JuliaArrays/StaticMatrix-Vector-matvec
Reinstate StaticMatrix * AbstractVector -> StaticVector
2 parents 22b9812 + e7309a5 commit 836f874

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

src/matrix_multiply.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
3131
# Manage dispatch of * and A_mul_B!
3232
# TODO RowVector? (Inner product?)
3333

34+
@inline *(A::StaticMatrix, B::AbstractVector) = _A_mul_B(Size(A), A, B)
3435
@inline *(A::StaticMatrix, B::StaticVector) = _A_mul_B(Size(A), Size(B), A, B)
3536
@inline *(A::StaticMatrix, B::StaticMatrix) = _A_mul_B(Size(A), Size(B), A, B)
3637
@inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B)
@@ -45,6 +46,23 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
4546

4647
# Implementations
4748

49+
@generated function _A_mul_B(::Size{sa}, a::StaticMatrix{Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
50+
if sa[2] != 0
51+
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(sub2ind(sa, k, j))]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
52+
else
53+
exprs = [:(zero(T)) for k = 1:sa[1]]
54+
end
55+
56+
return quote
57+
@_inline_meta
58+
if length(b) != sa[2]
59+
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))"))
60+
end
61+
T = promote_matprod(Ta, Tb)
62+
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
63+
end
64+
end
65+
4866
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticMatrix{Ta}, b::StaticVector{Tb}) where {sa, sb, Ta, Tb}
4967
if sb[1] != sa[2]
5068
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
@@ -63,8 +81,6 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
6381
end
6482
end
6583

66-
# TODO: I removed StaticMatrix * AbstractVector. Reinstate?
67-
6884
# outer product
6985
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{Ta}, b::RowVector{Tb, <:StaticVector}) where {sa, sb, Ta, Tb}
7086
newsize = (sa[1], sb[2])

test/matrix_multiply.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@test @inferred(v*v') === @SMatrix [1 2; 2 4]
1717

1818
v3 = [1, 2]
19-
@test_broken m*v3 === @SVector [5, 11]
19+
@test m*v3 === @SVector [5, 11]
2020

2121
m2 = @MMatrix [1 2; 3 4]
2222
v4 = @MVector [1, 2]
@@ -32,11 +32,11 @@
3232

3333
m5 = @SMatrix [1.0 2.0; 3.0 4.0]
3434
v7 = [1.0, 2.0]
35-
@test_broken (m5*v7)::SVector @SVector [5.0, 11.0]
35+
@test (m5*v7)::SVector @SVector [5.0, 11.0]
3636

3737
m6 = @SMatrix Float32[1.0 2.0; 3.0 4.0]
3838
v8 = Float64[1.0, 2.0]
39-
@test_broken (m6*v8)::SVector{2,Float64} @SVector [5.0, 11.0]
39+
@test (m6*v8)::SVector{2,Float64} @SVector [5.0, 11.0]
4040

4141
end
4242

0 commit comments

Comments
 (0)