@@ -17,6 +17,13 @@ mul_wrappers = [
17
17
(m -> Transpose (m), " transpo" ),
18
18
(m -> Diagonal (m), " diag " )]
19
19
20
+ mul_wrappers_reduced = [
21
+ (m -> m, " ident " ),
22
+ (m -> Symmetric (m, :U ), " sym-u " ),
23
+ (m -> UpperTriangular (m), " up-tri " ),
24
+ (m -> Transpose (m), " transpo" ),
25
+ (m -> Diagonal (m), " diag " )]
26
+
20
27
for N in [2 , 4 , 8 , 10 , 16 ]
21
28
22
29
matvecstr = @sprintf (" mat-vec %2d" , N)
@@ -41,7 +48,7 @@ for N in [2, 4, 8, 10, 16]
41
48
thrown = true
42
49
end
43
50
if ! thrown
44
- suite[matvecstr][wrapper_name] = @benchmarkable $ (wrapper_a (A)) * $ bv
51
+ suite[matvecstr][wrapper_name] = @benchmarkable $ (Ref ( wrapper_a (A)))[] * $ ( Ref (bv))[]
45
52
end
46
53
end
47
54
@@ -53,7 +60,7 @@ for N in [2, 4, 8, 10, 16]
53
60
thrown = true
54
61
end
55
62
if ! thrown
56
- suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $ (wrapper_a (A)) * $ (wrapper_b (B))
63
+ suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $ (Ref ( wrapper_a (A)))[] * $ (Ref ( wrapper_b (B)))[]
57
64
end
58
65
end
59
66
@@ -68,7 +75,7 @@ for N in [2, 4, 8, 10, 16]
68
75
thrown = true
69
76
end
70
77
if ! thrown
71
- suite[matvec_mut_str][wrapper_name] = @benchmarkable mul! ($ cv, $ (wrapper_a (A)), $ bv )
78
+ suite[matvec_mut_str][wrapper_name] = @benchmarkable mul! ($ cv, $ (Ref ( wrapper_a (A)))[] , $ ( Ref (bv))[] )
72
79
end
73
80
end
74
81
@@ -80,7 +87,7 @@ for N in [2, 4, 8, 10, 16]
80
87
thrown = true
81
88
end
82
89
if ! thrown
83
- suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul! ($ C, $ (wrapper_a (A)), $ (wrapper_b (B)))
90
+ suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul! ($ C, $ (Ref ( wrapper_a (A)))[] , $ (Ref ( wrapper_b (B)))[] )
84
91
end
85
92
end
86
93
end
@@ -111,3 +118,94 @@ function judge_results(m1, m2)
111
118
end
112
119
return results
113
120
end
121
+
122
+ function generic_mul (size_a, size_b, a, b)
123
+ return invoke (* , Tuple{StaticArrays. _unstatic_array (typeof (a)),StaticArrays. _unstatic_array (typeof (b))}, a, b)
124
+ end
125
+
126
+ function full_benchmark (mul_wrappers, size_iter = 1 : 4 , T = Float64)
127
+ suite_full = BenchmarkGroup ()
128
+ for N in size_iter
129
+ for M in size_iter
130
+ a = randn (SMatrix{N,M,T})
131
+ wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1 ]]
132
+ sa = Size (a)
133
+ for K in size_iter
134
+ b = randn (SMatrix{M,K,T})
135
+ wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1 ]]
136
+ sb = Size (b)
137
+ for (w_a, w_a_name) in wrappers_a
138
+ for (w_b, w_b_name) in wrappers_b
139
+ cur_str = @sprintf (" mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
140
+ suite_full[cur_str] = @benchmarkable generic_mul ($ sa, $ sb, $ (Ref (w_a (a)))[], $ (Ref (w_b (b)))[])
141
+ cur_str = @sprintf (" mat-mat %s %s default (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
142
+ suite_full[cur_str] = @benchmarkable StaticArrays. _mul ($ sa, $ sb, $ (Ref (w_a (a)))[], $ (Ref (w_b (b)))[])
143
+ cur_str = @sprintf (" mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
144
+ suite_full[cur_str] = @benchmarkable StaticArrays. mul_unrolled ($ sa, $ sb, $ (Ref (w_a (a)))[], $ (Ref (w_b (b)))[])
145
+ if w_a_name != " diag " && w_b_name != " diag "
146
+ cur_str = @sprintf (" mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
147
+ suite_full[cur_str] = @benchmarkable StaticArrays. mul_unrolled_chunks ($ sa, $ sb, $ (Ref (w_a (a)))[], $ (Ref (w_b (b)))[])
148
+ end
149
+ if w_a_name == " ident " && w_b_name == " ident "
150
+ cur_str = @sprintf (" mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
151
+ suite_full[cur_str] = @benchmarkable StaticArrays. mul_loop ($ sa, $ sb, $ (Ref (w_a (a)))[], $ (Ref (w_b (b)))[])
152
+ end
153
+ end
154
+ end
155
+ end
156
+ end
157
+ end
158
+ results = run (suite_full, verbose = true )
159
+ results_median = map (collect (results)) do res
160
+ return (res[1 ], median (res[2 ]). time)
161
+ end
162
+ return results_median
163
+ end
164
+
165
+ function judge_this (new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which)
166
+ if new_time* tol < old_time
167
+ msg = @sprintf (" better for %s %s (%2d, %2d) x (%2d, %2d): %s" , w_a_name, w_b_name, N, M, M, K, which)
168
+ println (msg)
169
+ println (" >> " , new_time, " | " , old_time)
170
+ end
171
+ end
172
+
173
+ function pick_best (results, mul_wrappers, size_iter; tol = 1.2 )
174
+ for N in size_iter
175
+ for M in size_iter
176
+ wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1 ]]
177
+ for K in size_iter
178
+ wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1 ]]
179
+ for (w_a, w_a_name) in wrappers_a
180
+ for (w_b, w_b_name) in wrappers_b
181
+ cur_default = @sprintf (" mat-mat %s %s default (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
182
+ default_time = results[cur_default]
183
+
184
+ cur_generic = @sprintf (" mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
185
+ generic_time = results[cur_generic]
186
+ judge_this (generic_time, default_time, tol, w_a_name, w_b_name, N, M, K, " generic" )
187
+
188
+ cur_unrolled = @sprintf (" mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
189
+ unrolled_time = results[cur_unrolled]
190
+ judge_this (unrolled_time, default_time, tol, w_a_name, w_b_name, N, M, K, " unrolled" )
191
+
192
+ if w_a_name != " diag " && w_b_name != " diag "
193
+ cur_chunks = @sprintf (" mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
194
+ chunk_time = results[cur_chunks]
195
+ judge_this (chunk_time, default_time, tol, w_a_name, w_b_name, N, M, K, " chunks" )
196
+ end
197
+ if w_a_name == " ident " && w_b_name == " ident "
198
+ cur_loop = @sprintf (" mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)" , w_a_name, w_b_name, N, M, M, K)
199
+ loop_time = results[cur_loop]
200
+ judge_this (loop_time, default_time, tol, w_a_name, w_b_name, N, M, K, " loop" )
201
+ end
202
+ end
203
+ end
204
+ end
205
+ end
206
+ end
207
+ end
208
+
209
+ function run_1 ()
210
+ return full_benchmark (mul_wrappers_reduced, [2 , 3 , 4 , 5 , 8 , 9 , 14 , 16 ])
211
+ end
0 commit comments