Skip to content

Commit e7010b1

Browse files
committed
Merge pull request #208
2 parents 2e9feff + 511fb0f commit e7010b1

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

src/matrix_multiply.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import Base: *, Ac_mul_B, A_mul_Bc, Ac_mul_Bc, At_mul_B, A_mul_Bt, At_mul_Bt
22
import Base: A_mul_B!, Ac_mul_B!, A_mul_Bc!, Ac_mul_Bc!, At_mul_B!, A_mul_Bt!, At_mul_Bt!
33

4-
import Base.LinAlg: BlasFloat
4+
import Base.LinAlg: BlasFloat, matprod
55

66
const StaticVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticMatrix{<:Any, <:Any, T}}
77

8-
# Idea inspired by https://github.com/JuliaLang/julia/pull/18218
9-
promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero(T1)*zero(T2))
10-
118
# TODO Potentially a loop version for rather large arrays? Or try and figure out inference problems?
129

1310
# Deal with A_mul_Bc, etc...
@@ -60,7 +57,7 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
6057
if length(b) != sa[2]
6158
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))"))
6259
end
63-
T = promote_matprod(Ta, Tb)
60+
T = promote_op(matprod,Ta,Tb)
6461
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
6562
end
6663
end
@@ -78,7 +75,7 @@ end
7875

7976
return quote
8077
@_inline_meta
81-
T = promote_matprod(Ta, Tb)
78+
T = promote_op(matprod,Ta,Tb)
8279
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
8380
end
8481
end
@@ -171,7 +168,7 @@ end
171168

172169
return quote
173170
@_inline_meta
174-
T = promote_matprod(Ta, Tb)
171+
T = promote_op(matprod,Ta,Tb)
175172
@inbounds return similar_type(a, T, $S)(tuple($(exprs...)))
176173
end
177174
end
@@ -190,7 +187,7 @@ end
190187

191188
return quote
192189
@_inline_meta
193-
T = promote_matprod(Ta, Tb)
190+
T = promote_op(matprod,Ta,Tb)
194191

195192
@inbounds $(Expr(:block, exprs_init...))
196193
for j = 2:$(sa[2])
@@ -218,7 +215,7 @@ end
218215

219216
return quote
220217
@_inline_meta
221-
T = promote_matprod(Ta, Tb)
218+
T = promote_op(matprod,Ta,Tb)
222219
$(Expr(:block,
223220
vect_exprs...,
224221
:(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
@@ -234,7 +231,7 @@ end
234231
if sa[2] != 0
235232
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(sub2ind(sa, k, j))]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
236233
else
237-
exprs = [:(zero(promote_matprod(Ta,Tb))) for k = 1:sa[1]]
234+
exprs = [:(zero(promote_op(matprod,Ta,Tb))) for k = 1:sa[1]]
238235
end
239236

240237
return quote

test/matrix_multiply.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
@test isa(x, SVector{2,CartesianIndex{2}})
1212
@test x == @SVector [CartesianIndex((7,5)), CartesianIndex((15,13))]
1313

14+
# block matrices
15+
bm = @SMatrix [m m; m m]
16+
bv = @SVector [v,v]
17+
@test (bm*bv)::SVector{2,SVector{2,Int}} == @SVector [[10,22],[10,22]]
18+
1419
# inner product
1520
@test @inferred(v'*v) === 5
1621

@@ -49,6 +54,13 @@
4954
v = @SVector [1, 2]
5055
@test @inferred(v*m) === @SMatrix [1 2 3 4; 2 4 6 8]
5156

57+
# block matrices
58+
m = @SMatrix [1 2; 3 4]
59+
bm = @SMatrix [m m; m m]
60+
bv = @SVector [v,v]
61+
# Broken only because output turns into a normal array:
62+
@test_broken (bv'*bm)'::SVector{2,SVector{2,Int}} == @SVector [[14,20],[14,20]]
63+
5264
# Outer product
5365
v2 = SVector(1, 2)
5466
v3 = SVector(3, 4)
@@ -87,6 +99,11 @@
8799
n = @MArray [2 3; 4 5]
88100
@test (m*n) == @SMatrix [10 13; 22 29]
89101

102+
# block matrices
103+
bm = @SMatrix [m m; m m]
104+
bm2 = @SMatrix [14 20; 30 44]
105+
@test (bm*bm)::SMatrix{2,2,SMatrix{2,2,Int,4}} == @SMatrix [bm2 bm2; bm2 bm2]
106+
90107
# Alternative methods used between 8 < n <= 14 and n > 14
91108
m_array = rand(1:10, 10, 10)
92109
n_array = rand(1:10, 10, 10)

0 commit comments

Comments
 (0)