Skip to content

Commit 11f371d

Browse files
authored
Merge pull request #324 from JuliaArrays/fix-nonsquare-matrix-mult
Fix sizes in non-square medium sized matrix multiply
2 parents 8aacd2b + cae40b8 commit 11f371d

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/matrix_multiply.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,11 @@ end
207207
S = Size(sa[1], sb[2])
208208

209209
# Do a custom b[:, k2] to return a SVector (an isbits type) rather than (possibly) a mutable type. Avoids allocation == faster
210-
tmp_type_in = :(SVector{$(sa[1]), T})
211-
tmp_type_out = :(SVector{$(sb[1]), T})
212-
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(Size(a), Size($(sa[1])), a, $(Expr(:call, tmp_type_in, [Expr(:ref, :b, sub2ind(S, i, k2)) for i = 1:sb[1]]...)))::$tmp_type_out) for k2 = 1:sb[2]]
210+
tmp_type_in = :(SVector{$(sb[1]), T})
211+
tmp_type_out = :(SVector{$(sa[1]), T})
212+
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(Size(a), Size($(sb[1])), a,
213+
$(Expr(:call, tmp_type_in, [Expr(:ref, :b, sub2ind(sb, i, k2)) for i = 1:sb[1]]...)))::$tmp_type_out)
214+
for k2 = 1:sb[2]]
213215

214216
exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
215217

test/matrix_multiply.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@
121121
n2 = SMatrix{16,16}(n_array2)
122122
@test m2*n2 === SMatrix{16,16}(a_array2)
123123

124+
# Non-square version
125+
m_array3 = rand(1:10, 9, 10)
126+
n_array3 = rand(1:10, 10, 11)
127+
a_array3 = m_array3*n_array3
128+
m3 = SMatrix{9,10}(m_array3)
129+
n3 = SMatrix{10,11}(n_array3)
130+
@test m3*n3 === SMatrix{9,11}(a_array3)
131+
124132
# Mutating types follow different behaviour
125133
m_array = rand(1:10, 10, 10)
126134
n_array = rand(1:10, 10, 10)

0 commit comments

Comments
 (0)