Skip to content

Commit 8e11852

Browse files
committed
blas mul! fix and matrix multiplication benchmarks
1 parent 3e96bbc commit 8e11852

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

benchmark/bench_mat_mul.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
using StaticArrays
2+
using BenchmarkTools
3+
using LinearAlgebra
4+
using Printf
5+
6+
suite = BenchmarkGroup()
7+
8+
mul_wrappers = [
9+
(m -> m, "ident "),
10+
(m -> Symmetric(m, :U), "sym-u "),
11+
(m -> Hermitian(m, :U), "herm-u "),
12+
(m -> UpperTriangular(m), "up-tri "),
13+
(m -> LowerTriangular(m), "lo-tri "),
14+
(m -> UnitUpperTriangular(m), "uup-tri"),
15+
(m -> UnitLowerTriangular(m), "ulo-tri"),
16+
(m -> Adjoint(m), "adjoint"),
17+
(m -> Transpose(m), "transpo")]
18+
19+
for N in [2, 4, 8, 10, 16]
20+
21+
matvecstr = @sprintf("mat-vec %2d", N)
22+
matmatstr = @sprintf("mat-mat %2d", N)
23+
matvec_mut_str = @sprintf("mat-vec! %2d", N)
24+
matmat_mut_str = @sprintf("mat-mat! %2d", N)
25+
26+
suite[matvecstr] = BenchmarkGroup()
27+
suite[matmatstr] = BenchmarkGroup()
28+
suite[matvec_mut_str] = BenchmarkGroup()
29+
suite[matmat_mut_str] = BenchmarkGroup()
30+
31+
32+
A = randn(SMatrix{N,N,Float64})
33+
B = randn(SMatrix{N,N,Float64})
34+
bv = randn(SVector{N,Float64})
35+
for (wrapper_a, wrapper_name) in mul_wrappers
36+
thrown = false
37+
try
38+
wrapper_a(A) * bv
39+
catch e
40+
thrown = true
41+
end
42+
if !thrown
43+
suite[matvecstr][wrapper_name] = @benchmarkable $(wrapper_a(A)) * $bv
44+
end
45+
end
46+
47+
for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers
48+
thrown = false
49+
try
50+
wrapper_a(A) * wrapper_b(B)
51+
catch e
52+
thrown = true
53+
end
54+
if !thrown
55+
suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $(wrapper_a(A)) * $(wrapper_b(B))
56+
end
57+
end
58+
59+
C = randn(MMatrix{N,N,Float64})
60+
cv = randn(MVector{N,Float64})
61+
62+
for (wrapper_a, wrapper_name) in mul_wrappers
63+
thrown = false
64+
try
65+
mul!(cv, wrapper_a(A), bv)
66+
catch e
67+
thrown = true
68+
end
69+
if !thrown
70+
suite[matvec_mut_str][wrapper_name] = @benchmarkable mul!($cv, $(wrapper_a(A)), $bv)
71+
end
72+
end
73+
74+
for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers
75+
thrown = false
76+
try
77+
wrapper_a(A) * wrapper_b(B)
78+
catch e
79+
thrown = true
80+
end
81+
if !thrown
82+
suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul!($C, $(wrapper_a(A)), $(wrapper_b(B)))
83+
end
84+
end
85+
end
86+
87+
function run_and_save(fname, make_params = true)
88+
if make_params
89+
tune!(suite)
90+
BenchmarkTools.save("params.json", params(suite))
91+
else
92+
loadparams!(suite, BenchmarkTools.load("params.json")[1], :evals, :samples)
93+
end
94+
results = run(suite, verbose = true)
95+
BenchmarkTools.save(fname, results)
96+
end
97+
98+
function judge_results(m1, m2)
99+
results = Any[]
100+
for key1 in keys(m1)
101+
if !haskey(m2, key1)
102+
continue
103+
end
104+
for key2 in keys(m1[key1])
105+
if !haskey(m2[key1], key2)
106+
continue
107+
end
108+
push!(results, (key1, key2, judge(median(m1[key1][key2]), median(m2[key1][key2]))))
109+
end
110+
end
111+
return results
112+
end

src/matrix_multiply_add.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ end
185185
a::StaticMatMulLike, b::StaticMatMulLike,
186186
_add::MulAddMul) where {sa, sb, sc}
187187
Ta,Tb,Tc = eltype(a), eltype(b), eltype(c)
188-
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
188+
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat && a <: Union{StaticMatrix,Transpose} && b <: Union{StaticMatrix,Transpose}
189189

190190
mult_dim = multiplied_dimension(a,b)
191191
if mult_dim < 4*4*4
@@ -316,6 +316,8 @@ end
316316

317317
@inline _get_raw_data(A::SizedArray) = A.data
318318
@inline _get_raw_data(A::StaticArray) = A
319+
# we need something heap-allocated to make sure BLAS calls are safe
320+
@inline _get_raw_data(A::SArray) = MArray(A)
319321

320322
function mul_blas!(::TSize{<:Any,:any}, c::StaticMatrix,
321323
Sa::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}}, Sb::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}},

test/matrix_multiply_add.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,4 +243,5 @@ end
243243
@testset "Testing different wrappers" begin
244244
test_wrappers_for_size(2, true)
245245
test_wrappers_for_size(8, false)
246+
test_wrappers_for_size(16, false)
246247
end

0 commit comments

Comments
 (0)