Skip to content

Commit 56c9ca5

Browse files
committed
more tests for multiplication
1 parent 6eaf1e1 commit 56c9ca5

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

src/matrix_multiply.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ end
218218

219219
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
220220
# Heuristic choice for amount of codegen
221-
if sa[1]*sa[2]*sb[2] <= 8*8*8
221+
if sa[1]*sa[2]*sb[2] <= 8*8*8 || !(a <: StaticMatrix) || !(b <: StaticMatrix)
222222
return quote
223223
@_inline_meta
224224
return mul_unrolled(Sa, Sb, a, b)

test/matrix_multiply.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,26 @@ mul_wrappers = [
146146
@test m*transpose(n) === @SMatrix [8 14; 18 32]
147147
@test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28]
148148

149-
for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
150-
wm = wrapper_m(m)
151-
wn = wrapper_n(n)
152-
res_structure = StaticArrays.mul_result_structure(wm, wn)
153-
expected_type = if res_structure == identity
154-
SMatrix{2,2,Int,4}
155-
elseif res_structure == LowerTriangular
156-
LowerTriangular{Int,SMatrix{2,2,Int,4}}
157-
elseif res_structure == UpperTriangular
158-
UpperTriangular{Int,SMatrix{2,2,Int,4}}
159-
else
160-
error("Unknown structure: ", res_structure)
149+
# check different sizes because there are multiple implementations for matrices of different sizes
150+
for (mm, nn) in [
151+
(m, n),
152+
#(SMatrix{10, 10}(collect(1:100)), SMatrix{10, 10}(collect(1:100))),
153+
(SMatrix{15, 15}(collect(1:225)), SMatrix{15, 15}(collect(1:225)))]
154+
for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
155+
wm = wrapper_m(mm)
156+
wn = wrapper_n(nn)
157+
res_structure = StaticArrays.mul_result_structure(wm, wn)
158+
expected_type = if res_structure == identity
159+
typeof(mm)
160+
elseif res_structure == LowerTriangular
161+
LowerTriangular{Int,typeof(mm)}
162+
elseif res_structure == UpperTriangular
163+
UpperTriangular{Int,typeof(mm)}
164+
else
165+
error("Unknown structure: ", res_structure)
166+
end
167+
@test (@inferred wm * wn)::expected_type == wrapper_m(Array(mm)) * wrapper_n(Array(nn))
161168
end
162-
@test (@inferred wm * wn)::expected_type == wrapper_m(Array(m)) * wrapper_n(Array(n))
163169
end
164170

165171
m = @MMatrix [1 2; 3 4]

0 commit comments

Comments
 (0)