Skip to content

Commit f41d74c

Browse files
committed
some work on in-place structured multiplication
1 parent f39afe0 commit f41d74c

File tree

4 files changed

+175
-73
lines changed

4 files changed

+175
-73
lines changed

src/matrix_multiply.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -301,13 +301,17 @@ end
301301
end
302302
end
303303

304-
305304
# outer product
306305
@generated function _mul(::Size{sa}, ::Size{sb}, a::StaticVector{<: Any, Ta},
307306
b::Union{Transpose{Tb, <:StaticVector}, Adjoint{Tb, <:StaticVector}}) where {sa, sb, Ta, Tb}
308307
newsize = (sa[1], sb[2])
309-
exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]]
310-
308+
conjugate_b = b <: Adjoint
309+
if conjugate_b
310+
exprs = [:(a[$i] * adjoint(b[$j])) for i = 1:sa[1], j = 1:sb[2]]
311+
else
312+
exprs = [:(a[$i] * transpose(b[$j])) for i = 1:sa[1], j = 1:sb[2]]
313+
end
314+
311315
return quote
312316
@_inline_meta
313317
T = promote_op(*, Ta, Tb)
@@ -327,7 +331,7 @@ end
327331
@_inline_meta
328332
return mul_unrolled(Sa, Sb, a, b)
329333
end
330-
elseif a <: StaticMatrix && b <:StaticMatrix && sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
334+
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
331335
return quote
332336
@_inline_meta
333337
return mul_unrolled_chunks(Sa, Sb, a, b)
@@ -400,7 +404,7 @@ end
400404

401405
# Concatenate a series of matrix-vector multiplications
402406
# Each function is N^2 not N^3 - aids in compile time.
403-
@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
407+
@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
404408
if sb[1] != sa[2]
405409
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
406410
end
@@ -410,19 +414,27 @@ end
410414
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster
411415
tmp_type_in = :(SVector{$(sb[1]), T})
412416
tmp_type_out = :(SVector{$(sa[1]), T})
413-
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(TSize(a), TSize($(sb[1])), a,
414-
$(Expr(:call, tmp_type_in, [Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)))::$tmp_type_out)
415-
for k2 = 1:sb[2]]
416417

417-
exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
418+
retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
419+
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()),
420+
a, $(Expr(:call, tmp_type_in, [uplo_access(sb, :b, i, k2, access_b) for i = 1:sb[1]]...)), $(Val(access_a)))::$tmp_type_out) for k2 = 1:sb[2]]
421+
422+
exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
423+
424+
return quote
425+
@inbounds $(Expr(:block, vect_exprs...))
426+
$(Expr(:block,
427+
:(@inbounds return (mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...)))))
428+
))
429+
end
430+
end
418431

419432
return quote
420433
@_inline_meta
421-
T = promote_op(matprod,Ta,Tb)
422-
$(Expr(:block,
423-
vect_exprs...,
424-
:(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
425-
))
434+
T = promote_op(matprod, Ta, Tb)
435+
a = mul_parent(wrapped_a)
436+
b = mul_parent(wrapped_b)
437+
$retexpr
426438
end
427439
end
428440

src/matrix_multiply_add.jl

Lines changed: 80 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
# import LinearAlgebra.MulAddMul
22

3-
abstract type MulAddMul{T} end
3+
abstract type MulAddMul{TA,TB} end
44

5-
struct AlphaBeta{T} <: MulAddMul{T}
6-
α::T
7-
β::T
8-
function AlphaBeta{T}(α,β) where T <: Real
9-
new{T}(α,β)
10-
end
5+
struct AlphaBeta{TA,TB} <: MulAddMul{TA,TB}
6+
α::TA
7+
β::TB
118
end
12-
@inline AlphaBeta::A::B) where {A,B} = AlphaBeta{promote_type(A,B)}(α,β)
139
@inline alpha(ab::AlphaBeta) = ab.α
1410
@inline beta(ab::AlphaBeta) = ab.β
1511

16-
struct NoMulAdd{T} <: MulAddMul{T} end
17-
@inline alpha(ma::NoMulAdd{T}) where T = one(T)
18-
@inline beta(ma::NoMulAdd{T}) where T = zero(T)
12+
struct NoMulAdd{TA,TB} <: MulAddMul{TA,TB} end
13+
@inline alpha(ma::NoMulAdd{TA,TB}) where {TA,TB} = one(TA)
14+
@inline beta(ma::NoMulAdd{TA,TB}) where {TA,TB} = zero(TB)
1915

2016
"""
2117
StaticMatMulLike
@@ -63,12 +59,14 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
6359
# 5-argument matrix multiplication
6460
# To avoid allocations, strip away Transpose type and store tranpose info in Size
6561
@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike, B::StaticVecOrMatLike,
66-
α::Real, β::Real) = _mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B),
62+
α::Real, β::Real) = _mul!(TSize(dest), mul_parent(dest), Size(A), Size(B), A, B,
6763
AlphaBeta(α,β))
6864

69-
@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike{T},
70-
B::StaticVecOrMatLike{T}) where T =
71-
_mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B), NoMulAdd{T}())
65+
@inline function LinearAlgebra.mul!(dest::StaticVecOrMatLike{TDest}, A::StaticVecOrMatLike{TA},
66+
B::StaticVecOrMatLike{TB}) where {TDest,TA,TB}
67+
TMul = typeof(one(TA)*one(TB)+one(TA)*one(TB))
68+
return _mul!(TSize(dest), mul_parent(dest), Size(A), Size(B), A, B, NoMulAdd{TMul, TDest}())
69+
end
7270

7371

7472
"Calculate the product of the dimensions being multiplied. Useful as a heuristic for unrolling."
@@ -112,55 +110,58 @@ end
112110
end
113111

114112
"Obtain an expression for the linear index of var[k,j], taking transposes into account"
115-
@inline _lind(A::Type{<:TSize}, k::Int, j::Int) = _lind(:a, A, k, j)
116113
function _lind(var::Symbol, A::Type{TSize{sa,tA}}, k::Int, j::Int) where {sa,tA}
117114
return uplo_access(sa, var, k, j, tA)
118115
end
119116

120117

121118

122119
# Matrix-vector multiplication
123-
@generated function _mul!(Sc::TSize{sc}, c::StaticVecOrMatLike, Sa::TSize{sa}, Sb::TSize{sb},
124-
a::StaticMatrix, b::StaticVector, _add::MulAddMul,
125-
::Val{col}=Val(1)) where {sa, sb, sc, col}
120+
@generated function _mul!(Sc::TSize{sc}, c::StaticVecOrMatLike, Sa::Size{sa}, Sb::Size{sb},
121+
wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}, _add::MulAddMul,
122+
::Val{col}=Val(1)) where {sa, sb, sc, col, Ta, Tb}
126123
if sa[2] != sb[1] || sc[1] != sa[1]
127124
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
128125
end
129126

130127
if sa[2] != 0
131-
lhs = [:($(_lind(:c,Sc,k,col))) for k = 1:sa[1]]
132-
ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
133-
[:($(_lind(Sa,k,j))*b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]]
134-
exprs = _muladd_expr(lhs, ab, _add)
128+
assign_expr = gen_by_access(wrapped_a) do access_a
129+
lhs = [:($(_lind(:c,Sc,k,col))) for k = 1:sa[1]]
130+
ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
131+
[:($(uplo_access(sa, :a, k, j, access_a)) * b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]]
132+
exprs = _muladd_expr(lhs, ab, _add)
133+
134+
return :(@inbounds $(Expr(:block, exprs...)))
135+
end
135136
else
136137
exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sa[1]]
138+
assign_expr = :(@inbounds $(Expr(:block, exprs...)))
137139
end
138140

139141
return quote
140142
# @_inline_meta
141-
# α = _add.alpha
142-
# β = _add.beta
143143
α = alpha(_add)
144144
β = beta(_add)
145-
@inbounds $(Expr(:block, exprs...))
145+
a = mul_parent(wrapped_a)
146+
$assign_expr
146147
return c
147148
end
148149
end
149150

150151
# Outer product
151-
@generated function _mul!(::TSize{sc}, c::StaticMatrix, ::TSize{sa,:any}, tsb::Union{TSize{sb,:transpose},TSize{sb,:adjoint}},
152-
a::StaticVector, b::StaticVector, _add::MulAddMul) where {sa, sb, sc}
152+
@generated function _mul!(::TSize{sc}, c::StaticMatrix, tsa::Size{sa}, tsb::Size{sb},
153+
a::StaticVector, b::Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}}, _add::MulAddMul) where {sa, sb, sc}
153154
if sc[1] != sa[1] || sc[2] != sb[2]
154155
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
155156
end
156157

157-
conjugate_b = isa(tsb, TSize{sb,:adjoint})
158+
conjugate_b = b <: Adjoint
158159

159160
lhs = [:(c[$(LinearIndices(sc)[i,j])]) for i = 1:sa[1], j = 1:sb[2]]
160161
if conjugate_b
161162
ab = [:(a[$i] * adjoint(b[$j])) for i = 1:sa[1], j = 1:sb[2]]
162163
else
163-
ab = [:(a[$i] * b[$j]) for i = 1:sa[1], j = 1:sb[2]]
164+
ab = [:(a[$i] * transpose(b[$j])) for i = 1:sa[1], j = 1:sb[2]]
164165
end
165166

166167
exprs = _muladd_expr(lhs, ab, _add)
@@ -175,9 +176,9 @@ end
175176
end
176177

177178
# Matrix-matrix multiplication
178-
@generated function _mul!(Sc::TSize{sc}, c::StaticMatrixLike,
179-
Sa::TSize{sa}, Sb::TSize{sb},
180-
a::StaticMatrixLike, b::StaticMatrixLike,
179+
@generated function _mul!(Sc::TSize{sc}, c::StaticMatMulLike,
180+
Sa::Size{sa}, Sb::Size{sb},
181+
a::StaticMatMulLike, b::StaticMatMulLike,
181182
_add::MulAddMul) where {sa, sb, sc}
182183
Ta,Tb,Tc = eltype(a), eltype(b), eltype(c)
183184
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
@@ -199,7 +200,7 @@ end
199200
if can_blas
200201
return quote
201202
@_inline_meta
202-
mul_blas!(Sc, c, Sa, Sb, a, b, _add)
203+
mul_blas!(Sc, c, TSize(a), TSize(b), mul_parent(a), mul_parent(b), _add)
203204
return c
204205
end
205206
else
@@ -213,18 +214,27 @@ end
213214
end
214215

215216

216-
@generated function muladd_unrolled_all!(Sc::TSize{sc}, c::StaticMatrixLike, Sa::TSize{sa}, Sb::TSize{sb},
217-
a::StaticMatrixLike, b::StaticMatrixLike, _add::MulAddMul) where {sa, sb, sc}
217+
@generated function muladd_unrolled_all!(Sc::TSize{sc}, wrapped_c::StaticMatMulLike, Sa::Size{sa}, Sb::Size{sb},
218+
wrapped_a::StaticMatMulLike{<:Any,<:Any,Ta}, wrapped_b::StaticMatMulLike{<:Any,<:Any,Tb}, _add::MulAddMul) where {sa, sb, sc, Ta, Tb}
218219
if !check_dims(Size(sc),Size(sa),Size(sb))
219220
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
220221
end
221222

222223
if sa[2] != 0
223224
lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]]
224-
ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
225-
[:($(_lind(:a, Sa, k1, j)) * $(_lind(:b, Sb, j, k2))) for j = 1:sa[2]]
226-
))) for k1 = 1:sa[1], k2 = 1:sb[2]]
227-
exprs = _muladd_expr(lhs, ab, _add)
225+
226+
assign_expr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
227+
228+
ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
229+
[:($(uplo_access(sa, :a, k1, j, access_a)) * $(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
230+
))) for k1 = 1:sa[1], k2 = 1:sb[2]]
231+
232+
exprs = _muladd_expr(lhs, ab, _add)
233+
return :(@inbounds $(Expr(:block, exprs...)))
234+
end
235+
else
236+
exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sc[1]*sc[2]]
237+
assign_expr = :(@inbounds $(Expr(:block, exprs...)))
228238
end
229239

230240
return quote
@@ -233,49 +243,63 @@ end
233243
# β = _add.beta
234244
α = alpha(_add)
235245
β = beta(_add)
236-
@inbounds $(Expr(:block, exprs...))
246+
c = mul_parent(wrapped_c)
247+
a = mul_parent(wrapped_a)
248+
b = mul_parent(wrapped_b)
249+
$assign_expr
250+
return c
237251
end
238252
end
239253

240254

241-
@generated function muladd_unrolled_chunks!(Sc::TSize{sc}, c::StaticMatrix, ::TSize{sa,tA}, Sb::TSize{sb,tB},
242-
a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) where {sa, sb, sc, tA, tB}
255+
@generated function muladd_unrolled_chunks!(Sc::TSize{sc}, wrapped_c::StaticMatMulLike, ::Size{sa}, Sb::Size{sb},
256+
wrapped_a::StaticMatMulLike{<:Any,<:Any,Ta}, wrapped_b::StaticMatMulLike{<:Any,<:Any,Tb}, _add::MulAddMul) where {sa, sb, sc, Ta, Tb}
243257
if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2]
244258
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
245259
end
246260

261+
# This will not work for Symmetric and Hermitian wrappers of c
262+
lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]]
263+
247264
#vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
248265

249266
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
250-
tmp_type = SVector{sb[1], eltype(c)}
251-
vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(TSize{sa,tA}()), $(TSize{(sb[1],),tB}()),
252-
a, $(Expr(:call, tmp_type, [:($(_lind(:b, Sb, i, k2))) for i = 1:sb[1]]...)))) for k2 = 1:sb[2]]
267+
tmp_type = SVector{sb[1], eltype(wrapped_c)}
253268

254-
lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]]
255-
# exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
256-
rhs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
257-
exprs = _muladd_expr(lhs, rhs, _add)
269+
assign_expr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
270+
vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()),
271+
a, $(Expr(:call, tmp_type, [uplo_access(sb, :b, i, k2, access_b) for i = 1:sb[1]]...)), $(Val(access_a)))) for k2 = 1:sb[2]]
272+
273+
# exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
274+
rhs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
275+
exprs = _muladd_expr(lhs, rhs, _add)
258276

277+
return quote
278+
@inbounds $(Expr(:block, vect_exprs...))
279+
@inbounds $(Expr(:block, exprs...))
280+
end
281+
end
282+
259283
return quote
260284
@_inline_meta
261-
# α = _add.alpha
262-
# β = _add.beta
263285
α = alpha(_add)
264286
β = beta(_add)
265-
@inbounds $(Expr(:block, vect_exprs...))
266-
@inbounds $(Expr(:block, exprs...))
287+
c = mul_parent(wrapped_c)
288+
a = mul_parent(wrapped_a)
289+
b = mul_parent(wrapped_b)
290+
$assign_expr
267291
end
268292
end
269293

270294
# @inline partly_unrolled_multiply(Sa::Size, Sb::Size, a::StaticMatrix, b::StaticArray) where {sa, sb, Ta, Tb} =
271295
# partly_unrolled_multiply(TSize(Sa), TSize(Sb), a, b)
272-
@generated function partly_unrolled_multiply(Sa::TSize{sa}, ::TSize{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}) where {sa, sb, Ta, Tb}
296+
@generated function partly_unrolled_multiply(Sa::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}, ::Val{access_a}) where {sa, sb, Ta, Tb, access_a}
273297
if sa[2] != sb[1]
274298
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
275299
end
276300

277301
if sa[2] != 0
278-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(_lind(:a,Sa,k,j))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
302+
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
279303
else
280304
exprs = [:(zero(promote_op(matprod,Ta,Tb))) for k = 1:sa[1]]
281305
end

test/matrix_multiply.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ mul_wrappers = [
160160
for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
161161
wm = wrapper_m(mm)
162162
wn = wrapper_n(nn)
163-
if length(mm) >= 100 && (!isa(wm, StaticArray) || !isa(wn, StaticArray))
163+
if length(mm) >= 255 && (!isa(wm, StaticArray) || !isa(wn, StaticArray))
164164
continue
165165
end
166166
res_structure = StaticArrays.mul_result_structure(wm, wn)
@@ -340,10 +340,10 @@ mul_wrappers = [
340340
@test a::MMatrix{2,2,Int,4} == @MMatrix [8 14; 18 32]
341341
mul!(a, transpose(m), transpose(n))
342342
@test a::MMatrix{2,2,Int,4} == @MMatrix [11 19; 16 28]
343-
#=for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
343+
for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
344344
mul!(a, wrapper_m(m), wrapper_n(n))
345345
@test a::MMatrix{2,2,Int,4} == wrapper_m(Array(m))*wrapper_n(Array(n))
346-
end=#
346+
end
347347

348348
a2 = MArray{Tuple{2,2},Int,2,4}(undef)
349349
mul!(a2, m, n)

0 commit comments

Comments
 (0)