Skip to content

Commit 2eba36d

Browse files
committed
structured matrix multiplication pt 1
1 parent a59b6cc commit 2eba36d

File tree

2 files changed

+155
-15
lines changed

2 files changed

+155
-15
lines changed

src/matrix_multiply.jl

Lines changed: 131 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,129 @@ import LinearAlgebra: BlasFloat, matprod, mul!
44
# Manage dispatch of * and mul!
55
# TODO Adjoint? (Inner product?)
66

7-
@inline *(A::StaticMatrix, B::AbstractVector) = _mul(Size(A), A, B)
8-
@inline *(A::StaticMatrix, B::StaticVector) = _mul(Size(A), Size(B), A, B)
9-
@inline *(A::StaticMatrix, B::StaticMatrix) = _mul(Size(A), Size(B), A, B)
10-
@inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B)
7+
const StaticMatMulLike{s1, s2, T} = Union{
8+
StaticMatrix{s1, s2, T},
9+
Symmetric{T, <:StaticMatrix{s1, s2, T}},
10+
Hermitian{T, <:StaticMatrix{s1, s2, T}}}
11+
12+
@inline *(A::StaticMatMulLike, B::AbstractVector) = _mul(Size(A), A, B)
13+
@inline *(A::StaticMatMulLike, B::StaticVector) = _mul(Size(A), Size(B), A, B)
14+
@inline *(A::StaticMatMulLike, B::StaticMatMulLike) = _mul(Size(A), Size(B), A, B)
15+
@inline *(A::StaticVector, B::StaticMatMulLike) = *(reshape(A, Size(Size(A)[1], 1)), B)
1116
@inline *(A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B)
1217
@inline *(A::StaticVector, B::Adjoint{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B)
1318
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
1419
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
1520

21+
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, asym = :a)
22+
return expr_gen(:any)
23+
end
24+
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :a)
25+
return quote
26+
if $(asym).uplo == 'U'
27+
$(expr_gen(:up))
28+
else
29+
$(expr_gen(:lo))
30+
end
31+
end
32+
end
33+
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :a)
34+
return quote
35+
if $(asym).uplo == 'U'
36+
$(expr_gen(:up_herm))
37+
else
38+
$(expr_gen(:lo_herm))
39+
end
40+
end
41+
end
42+
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type{<:StaticMatrix})
43+
return expr_gen(:any, :any)
44+
end
45+
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
46+
return gen_by_access(a) do access_a
47+
return quote
48+
return $(gen_by_access(b, :b) do access_b
49+
expr_gen(:any, access_b)
50+
end)
51+
end
52+
end
53+
end
54+
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type)
55+
return gen_by_access(a) do access_a
56+
return quote
57+
if a.uplo == 'U'
58+
return $(gen_by_access(b, :b) do access_b
59+
expr_gen(:up, access_b)
60+
end)
61+
else
62+
return $(gen_by_access(b, :b) do access_b
63+
expr_gen(:lo, access_b)
64+
end)
65+
end
66+
end
67+
end
68+
end
69+
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type)
70+
return gen_by_access(a) do access_a
71+
return quote
72+
if a.uplo == 'U'
73+
return $(gen_by_access(b, :b) do access_b
74+
expr_gen(:up_herm, access_b)
75+
end)
76+
else
77+
return $(gen_by_access(b, :b) do access_b
78+
expr_gen(:lo_herm, access_b)
79+
end)
80+
end
81+
end
82+
end
83+
end
84+
85+
function uplo_access(sa, asym, k, j, uplo)
86+
if uplo == :any
87+
return :($asym[$(LinearIndices(sa)[k, j])])
88+
elseif uplo == :up
89+
if k <= j
90+
return :($asym[$(LinearIndices(sa)[k, j])])
91+
else
92+
return :($asym[$(LinearIndices(sa)[j, k])])
93+
end
94+
elseif uplo == :lo
95+
if j <= k
96+
return :($asym[$(LinearIndices(sa)[k, j])])
97+
else
98+
return :($asym[$(LinearIndices(sa)[j, k])])
99+
end
100+
elseif uplo == :up_herm
101+
if k <= j
102+
return :($asym[$(LinearIndices(sa)[k, j])])
103+
else
104+
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
105+
end
106+
elseif uplo == :lo_herm
107+
if j <= k
108+
return :($asym[$(LinearIndices(sa)[k, j])])
109+
else
110+
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
111+
end
112+
end
113+
end
16114

17115
# Implementations
18116

19-
@generated function _mul(::Size{sa}, a::StaticMatrix{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
117+
function mul_smat_vec_exprs(sa, access_a)
118+
return [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
119+
end
120+
121+
@generated function _mul(::Size{sa}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
20122
if sa[2] != 0
21-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
123+
retexpr = gen_by_access(a) do access_a
124+
exprs = mul_smat_vec_exprs(sa, access_a)
125+
return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
126+
end
22127
else
23128
exprs = [:(zero(T)) for k = 1:sa[1]]
129+
retexpr = :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
24130
end
25131

26132
return quote
@@ -29,28 +135,33 @@ import LinearAlgebra: BlasFloat, matprod, mul!
29135
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))"))
30136
end
31137
T = promote_op(matprod,Ta,Tb)
32-
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
138+
$retexpr
33139
end
34140
end
35141

36-
@generated function _mul(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
142+
@generated function _mul(::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
37143
if sb[1] != sa[2]
38144
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
39145
end
40146

41147
if sa[2] != 0
42-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
148+
retexpr = gen_by_access(a) do access_a
149+
exprs = mul_smat_vec_exprs(sa, access_a)
150+
return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
151+
end
43152
else
44153
exprs = [:(zero(T)) for k = 1:sa[1]]
154+
retexpr = :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
45155
end
46156

47157
return quote
48158
@_inline_meta
49159
T = promote_op(matprod,Ta,Tb)
50-
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
160+
$retexpr
51161
end
52162
end
53163

164+
54165
# outer product
55166
@generated function _mul(::Size{sa}, ::Size{sb}, a::StaticVector{<: Any, Ta},
56167
b::Union{Transpose{Tb, <:StaticVector}, Adjoint{Tb, <:StaticVector}}) where {sa, sb, Ta, Tb}
@@ -64,7 +175,7 @@ end
64175
end
65176
end
66177

67-
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
178+
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
68179
# Heuristic choice for amount of codegen
69180
if sa[1]*sa[2]*sb[2] <= 8*8*8
70181
return quote
@@ -117,27 +228,32 @@ end
117228
end
118229
end
119230

120-
@generated function mul_unrolled(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
231+
@generated function mul_unrolled(::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
121232
if sb[1] != sa[2]
122233
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
123234
end
124235

125236
S = Size(sa[1], sb[2])
126237

127238
if sa[2] != 0
128-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k1, j])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[2]]) for k1 = 1:sa[1], k2 = 1:sb[2]]
239+
retexpr = gen_by_access(a, b) do access_a, access_b
240+
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)),
241+
[:($(uplo_access(sa, :a, k1, j, access_a))*$(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
242+
) for k1 = 1:sa[1], k2 = 1:sb[2]]
243+
return :(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
244+
end
129245
else
130246
exprs = [:(zero(T)) for k1 = 1:sa[1], k2 = 1:sb[2]]
247+
retexpr = :(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
131248
end
132249

133250
return quote
134251
@_inline_meta
135252
T = promote_op(matprod,Ta,Tb)
136-
@inbounds return similar_type(a, T, $S)(tuple($(exprs...)))
253+
$retexpr
137254
end
138255
end
139256

140-
141257
@generated function mul_loop(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
142258
if sb[1] != sa[2]
143259
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))

test/matrix_multiply.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
using StaticArrays, Test, LinearAlgebra
22

3+
mul_wrappers = [
4+
m -> m,
5+
m -> Symmetric(m, :U),
6+
m -> Symmetric(m, :L),
7+
m -> Hermitian(m, :U),
8+
m -> Hermitian(m, :L)]
9+
310
@testset "Matrix multiplication" begin
411
@testset "Matrix-vector" begin
512
m = @SMatrix [1 2; 3 4]
613
v = @SVector [1, 2]
714
v_bad = @SVector [1, 2, 3]
815
@test m*v === @SVector [5, 11]
16+
for wrapper in mul_wrappers
17+
@test (@inferred wrapper(m)*v)::SVector{2} == wrapper(Array(m))*Array(v)
18+
end
919
@test_throws DimensionMismatch m*v_bad
1020
# More complicated eltype inference
1121
v2 = @SVector [CartesianIndex((1,3)), CartesianIndex((3,1))]
@@ -17,6 +27,12 @@ using StaticArrays, Test, LinearAlgebra
1727
bm = @SMatrix [m m; m m]
1828
bv = @SVector [v,v]
1929
@test (bm*bv)::SVector{2,SVector{2,Int}} == @SVector [[10,22],[10,22]]
30+
for wrapper in mul_wrappers
31+
# there may be some problems with inferring the result type of symmetric block matrices
32+
if !isa(wrapper(bm), UpperTriangular)
33+
@test wrapper(bm)*bv == wrapper(Array(bm))*Array(bv)
34+
end
35+
end
2036

2137
# inner product
2238
@test @inferred(v'*v) === 5
@@ -128,6 +144,10 @@ using StaticArrays, Test, LinearAlgebra
128144
@test m*transpose(n) === @SMatrix [8 14; 18 32]
129145
@test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28]
130146

147+
for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
148+
@test (@inferred wrapper_m(m) * wrapper_n(n))::SMatrix{2,2,Int} == wrapper_m(Array(m)) * wrapper_n(Array(n))
149+
end
150+
131151
m = @MMatrix [1 2; 3 4]
132152
n = @MMatrix [2 3; 4 5]
133153
@test (m*n) == @MMatrix [10 13; 22 29]
@@ -289,6 +309,10 @@ using StaticArrays, Test, LinearAlgebra
289309
@test a::MMatrix{2,2,Int,4} == @MMatrix [8 14; 18 32]
290310
mul!(a, transpose(m), transpose(n))
291311
@test a::MMatrix{2,2,Int,4} == @MMatrix [11 19; 16 28]
312+
#=for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
313+
mul!(a, wrapper_m(m), wrapper_n(n))
314+
@test a::MMatrix{2,2,Int,4} == wrapper_m(Array(m))*wrapper_n(Array(n))
315+
end=#
292316

293317
a2 = MArray{Tuple{2,2},Int,2,4}(undef)
294318
mul!(a2, m, n)

0 commit comments

Comments
 (0)