Skip to content

Commit fd8cd5c

Browse files
committed
modified matrix multiplication heuristics
1 parent 0ff9b55 commit fd8cd5c

File tree

2 files changed

+130
-14
lines changed

2 files changed

+130
-14
lines changed

benchmark/bench_mat_mul.jl

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ mul_wrappers = [
1717
(m -> Transpose(m), "transpo"),
1818
(m -> Diagonal(m), "diag ")]
1919

20+
# Reduced list of `(wrapper, label)` tuples: each wrapper maps a matrix to the
# operand that is actually multiplied, and the label is embedded in benchmark names.
mul_wrappers_reduced = [
    (mat -> mat, "ident "),
    (mat -> Symmetric(mat, :U), "sym-u "),
    (mat -> UpperTriangular(mat), "up-tri "),
    (mat -> Transpose(mat), "transpo"),
    (mat -> Diagonal(mat), "diag "),
]
26+
2027
for N in [2, 4, 8, 10, 16]
2128

2229
matvecstr = @sprintf("mat-vec %2d", N)
@@ -41,7 +48,7 @@ for N in [2, 4, 8, 10, 16]
4148
thrown = true
4249
end
4350
if !thrown
44-
suite[matvecstr][wrapper_name] = @benchmarkable $(wrapper_a(A)) * $bv
51+
suite[matvecstr][wrapper_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(bv))[]
4552
end
4653
end
4754

@@ -53,7 +60,7 @@ for N in [2, 4, 8, 10, 16]
5360
thrown = true
5461
end
5562
if !thrown
56-
suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $(wrapper_a(A)) * $(wrapper_b(B))
63+
suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(wrapper_b(B)))[]
5764
end
5865
end
5966

@@ -68,7 +75,7 @@ for N in [2, 4, 8, 10, 16]
6875
thrown = true
6976
end
7077
if !thrown
71-
suite[matvec_mut_str][wrapper_name] = @benchmarkable mul!($cv, $(wrapper_a(A)), $bv)
78+
suite[matvec_mut_str][wrapper_name] = @benchmarkable mul!($cv, $(Ref(wrapper_a(A)))[], $(Ref(bv))[])
7279
end
7380
end
7481

@@ -80,7 +87,7 @@ for N in [2, 4, 8, 10, 16]
8087
thrown = true
8188
end
8289
if !thrown
83-
suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul!($C, $(wrapper_a(A)), $(wrapper_b(B)))
90+
suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul!($C, $(Ref(wrapper_a(A)))[], $(Ref(wrapper_b(B)))[])
8491
end
8592
end
8693
end
@@ -111,3 +118,94 @@ function judge_results(m1, m2)
111118
end
112119
return results
113120
end
121+
122+
"""
    generic_mul(size_a, size_b, a, b)

Compute `a * b` by `invoke`-ing the `*` method that matches the types returned by
`StaticArrays._unstatic_array` for `typeof(a)` and `typeof(b)`.

`size_a` and `size_b` are accepted but not used, so this function can be called with
the same argument list as the `StaticArrays` kernels benchmarked alongside it.
"""
function generic_mul(size_a, size_b, a, b)
    dispatch_types = Tuple{
        StaticArrays._unstatic_array(typeof(a)),
        StaticArrays._unstatic_array(typeof(b)),
    }
    return invoke(*, dispatch_types, a, b)
end
125+
126+
"""
    full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64)

Benchmark the available matrix-product kernels on random `SMatrix` operands.

For every `(N, M, K)` taken from `size_iter`, an `N×M` and an `M×K` matrix with element
type `T` are drawn with `randn`. Each entry of `mul_wrappers` is a `(wrapper, label)`
tuple; the full list is applied only to square operands, otherwise just the first entry
is used. For each wrapper pair the `generic`, `default` and `unrolled` kernels are
registered; `chunks` is skipped when either label is `"diag "` and `loop` is registered
only when both labels are `"ident "`.

Returns a collection of `(benchmark name, median time)` tuples.
"""
function full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64)
    suite = BenchmarkGroup()
    for N in size_iter, M in size_iter
        lhs = randn(SMatrix{N,M,T})
        lhs_wrappers = N == M ? mul_wrappers : [mul_wrappers[1]]
        # Sizes are taken from the unwrapped operands.
        lhs_size = Size(lhs)
        for K in size_iter
            rhs = randn(SMatrix{M,K,T})
            rhs_wrappers = M == K ? mul_wrappers : [mul_wrappers[1]]
            rhs_size = Size(rhs)
            for (wrap_lhs, lhs_name) in lhs_wrappers, (wrap_rhs, rhs_name) in rhs_wrappers
                # Operands are passed through `Ref(...)[]` inside every benchmark expression.
                key = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", lhs_name, rhs_name, N, M, M, K)
                suite[key] = @benchmarkable generic_mul($lhs_size, $rhs_size, $(Ref(wrap_lhs(lhs)))[], $(Ref(wrap_rhs(rhs)))[])

                key = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", lhs_name, rhs_name, N, M, M, K)
                suite[key] = @benchmarkable StaticArrays._mul($lhs_size, $rhs_size, $(Ref(wrap_lhs(lhs)))[], $(Ref(wrap_rhs(rhs)))[])

                key = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", lhs_name, rhs_name, N, M, M, K)
                suite[key] = @benchmarkable StaticArrays.mul_unrolled($lhs_size, $rhs_size, $(Ref(wrap_lhs(lhs)))[], $(Ref(wrap_rhs(rhs)))[])

                # The chunked kernel is not benchmarked for diagonal operands.
                if !(lhs_name == "diag " || rhs_name == "diag ")
                    key = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", lhs_name, rhs_name, N, M, M, K)
                    suite[key] = @benchmarkable StaticArrays.mul_unrolled_chunks($lhs_size, $rhs_size, $(Ref(wrap_lhs(lhs)))[], $(Ref(wrap_rhs(rhs)))[])
                end

                # The loop kernel is benchmarked for unwrapped operands only.
                if lhs_name == "ident " && rhs_name == "ident "
                    key = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", lhs_name, rhs_name, N, M, M, K)
                    suite[key] = @benchmarkable StaticArrays.mul_loop($lhs_size, $rhs_size, $(Ref(wrap_lhs(lhs)))[], $(Ref(wrap_rhs(rhs)))[])
                end
            end
        end
    end

    results = run(suite, verbose = true)
    return [(entry[1], median(entry[2]).time) for entry in collect(results)]
end
164+
165+
"""
    judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which)

Print a two-line report when `new_time * tol < old_time`, i.e. when the candidate
kernel `which` beats the reference time by more than the factor `tol`. The report names
the wrapper labels and the operand sizes `(N, M) x (M, K)`, followed by both timings.
Nothing is printed otherwise. Returns `nothing`.
"""
function judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which)
    # Guard clause: stay silent unless the candidate is clearly faster.
    new_time * tol < old_time || return nothing
    println(@sprintf("better for %s %s (%2d, %2d) x (%2d, %2d): %s", w_a_name, w_b_name, N, M, M, K, which))
    println(">> ", new_time, " | ", old_time)
    return nothing
end
172+
173+
"""
    pick_best(results, mul_wrappers, size_iter; tol = 1.2)

Report, via `judge_this`, every kernel whose time beats the `default` kernel by more
than the factor `tol`.

# Arguments
- `results`: benchmark times keyed by benchmark name. Either a container indexable by
  name, or a vector of `(name, time)` tuples as returned by `full_benchmark`, which is
  converted to a `Dict` first.
- `mul_wrappers`: collection of `(wrapper, label)` tuples. The full list is used for
  square operands only; otherwise just the first entry is used.
- `size_iter`: iterable of sizes from which `N`, `M` and `K` are drawn.

# Keywords
- `tol`: factor passed on to `judge_this`.

Returns `nothing`. Looking up a benchmark name that is absent from `results` raises the
lookup error of the underlying container.
"""
function pick_best(results, mul_wrappers, size_iter; tol = 1.2)
    # A vector of `(name, time)` tuples cannot be indexed by name, so turn it into a
    # lookup table; anything else is used as given.
    times = results isa AbstractVector ? Dict(results) : results
    for N in size_iter, M in size_iter
        wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]]
        for K in size_iter
            wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]]
            for (w_a, w_a_name) in wrappers_a, (w_b, w_b_name) in wrappers_b
                cur_default = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                default_time = times[cur_default]

                # (kernel label, benchmark name) for each candidate compared to the default.
                candidates = [
                    ("generic", @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)),
                    ("unrolled", @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)),
                ]
                # Mirrors `full_benchmark`: no chunked benchmark exists for diagonal operands.
                if w_a_name != "diag " && w_b_name != "diag "
                    push!(candidates, ("chunks", @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)))
                end
                # Mirrors `full_benchmark`: the loop benchmark exists for unwrapped operands only.
                if w_a_name == "ident " && w_b_name == "ident "
                    push!(candidates, ("loop", @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)))
                end

                for (which, name) in candidates
                    judge_this(times[name], default_time, tol, w_a_name, w_b_name, N, M, K, which)
                end
            end
        end
    end
    return nothing
end
208+
209+
"""
    run_1()

Run `full_benchmark` with `mul_wrappers_reduced` over the sizes 2, 3, 4, 5, 8, 9, 14
and 16, and return its result.
"""
run_1() = full_benchmark(mul_wrappers_reduced, [2, 3, 4, 5, 8, 9, 14, 16])

src/matrix_multiply.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -389,31 +389,33 @@ function combine_products(expr_list)
389389
end
390390

391391
# Select the matrix-product kernel for `a * b` from the operand types and sizes.
# Inside this generated function `a` and `b` are bound to the argument *types*, and
# `sa`/`sb` are the size tuples taken from `Sa`/`Sb`. The returned expression forwards
# to exactly one of `mul_loop`, `mul_unrolled`, `mul_unrolled_chunks` or `mul_generic`.
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
    # Heuristic choice for amount of codegen
    a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 4 : 1
    b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 4 : 1
    # Extra factor when both operands are triangular. This compares the multipliers
    # computed above; comparing the types `a`/`b` themselves with 4 is always false.
    ab_tri_mul = (a_tri_mul == 4 && b_tri_mul == 4) ? 2 : 1
    if a <: StaticMatrix && b <: StaticMatrix
        # Julia unrolls these loops pretty well
        return quote
            @_inline_meta
            return mul_loop(Sa, Sb, a, b)
        end
    elseif sa[1]*sa[2]*sb[2] <= 4*8*8*8*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal
        return quote
            @_inline_meta
            return mul_unrolled(Sa, Sb, a, b)
        end
    elseif (sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14) || !(a <: StaticMatrix) || !(b <: StaticMatrix)
        # NOTE(review): once the first branch has taken every StaticMatrix × StaticMatrix
        # pair, at least one of the two negated subtype tests is true here, so this
        # condition always holds and the `else` branch below cannot be reached.
        # Confirm whether that is intended.
        return quote
            @_inline_meta
            return mul_unrolled_chunks(Sa, Sb, a, b)
        end
    else
        # we don't have any special code for handling this case so let's fall back to
        # the generic implementation of matrix multiplication
        return quote
            @_inline_meta
            return mul_generic(Sa, Sb, a, b)
        end
    end
end
@@ -468,6 +470,22 @@ end
468470
end
469471
end
470472

473+
# Fallback matrix product: multiplies the parents of the wrapped operands through the
# `*` method selected for the types returned by `_unstatic_array`, converts the product
# with `similar_type(a, T, S)` and re-applies the structure given by
# `mul_result_structure`.
# Throws `DimensionMismatch` when the inner dimensions `sa[2]` and `sb[1]` differ.
@generated function mul_generic(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
    if sb[1] != sa[2]
        throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
    end

    # Size of the product, spliced into the generated body below.
    S = Size(sa[1], sb[2])

    return quote
        @_inline_meta
        T = promote_op(matprod, Ta, Tb)
        a = mul_parent(wrapped_a)
        b = mul_parent(wrapped_b)
        # `a` and `b` exist only inside this generated body, not in the generator
        # itself, so the dispatch tuple is built here from their types rather than
        # spliced in at generation time.
        generic_types = Tuple{_unstatic_array(typeof(a)), _unstatic_array(typeof(b))}
        product = invoke(*, generic_types, a, b)
        return (mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(product))
    end
end
488+
471489
# Concatenate a series of matrix-vector multiplications
472490
# Each function is N^2 not N^3 - aids in compile time.
473491
@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}

0 commit comments

Comments
 (0)