1
1
# import LinearAlgebra.MulAddMul
2
2
3
- abstract type MulAddMul{T } end
3
+ abstract type MulAddMul{TA,TB } end
4
4
5
- struct AlphaBeta{T} <: MulAddMul{T}
6
- α:: T
7
- β:: T
8
- function AlphaBeta {T} (α,β) where T <: Real
9
- new {T} (α,β)
10
- end
5
+ struct AlphaBeta{TA,TB} <: MulAddMul{TA,TB}
6
+ α:: TA
7
+ β:: TB
11
8
end
12
- @inline AlphaBeta (α:: A ,β:: B ) where {A,B} = AlphaBeta {promote_type(A,B)} (α,β)
13
9
@inline alpha (ab:: AlphaBeta ) = ab. α
14
10
@inline beta (ab:: AlphaBeta ) = ab. β
15
11
16
- struct NoMulAdd{T } <: MulAddMul{T } end
17
- @inline alpha (ma:: NoMulAdd{T } ) where T = one (T )
18
- @inline beta (ma:: NoMulAdd{T } ) where T = zero (T )
12
+ struct NoMulAdd{TA,TB } <: MulAddMul{TA,TB } end
13
+ @inline alpha (ma:: NoMulAdd{TA,TB } ) where {TA,TB} = one (TA )
14
+ @inline beta (ma:: NoMulAdd{TA,TB } ) where {TA,TB} = zero (TB )
19
15
20
16
"""
21
17
StaticMatMulLike
@@ -63,12 +59,14 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
63
59
# 5-argument matrix multiplication
64
60
# To avoid allocations, strip away Transpose type and store transpose info in Size
65
61
@inline LinearAlgebra. mul! (dest:: StaticVecOrMatLike , A:: StaticVecOrMatLike , B:: StaticVecOrMatLike ,
66
- α:: Real , β:: Real ) = _mul! (TSize (dest), mul_parent (dest), TSize (A), TSize (B), mul_parent (A), mul_parent (B) ,
62
+ α:: Real , β:: Real ) = _mul! (TSize (dest), mul_parent (dest), Size (A), Size (B), A, B ,
67
63
AlphaBeta (α,β))
68
64
69
- @inline LinearAlgebra. mul! (dest:: StaticVecOrMatLike , A:: StaticVecOrMatLike{T} ,
70
- B:: StaticVecOrMatLike{T} ) where T =
71
- _mul! (TSize (dest), mul_parent (dest), TSize (A), TSize (B), mul_parent (A), mul_parent (B), NoMulAdd {T} ())
65
+ @inline function LinearAlgebra. mul! (dest:: StaticVecOrMatLike{TDest} , A:: StaticVecOrMatLike{TA} ,
66
+ B:: StaticVecOrMatLike{TB} ) where {TDest,TA,TB}
67
+ TMul = typeof (one (TA)* one (TB)+ one (TA)* one (TB))
68
+ return _mul! (TSize (dest), mul_parent (dest), Size (A), Size (B), A, B, NoMulAdd {TMul, TDest} ())
69
+ end
72
70
73
71
74
72
" Calculate the product of the dimensions being multiplied. Useful as a heuristic for unrolling."
@@ -112,55 +110,58 @@ end
112
110
end
113
111
114
112
" Obtain an expression for the linear index of var[k,j], taking transposes into account"
115
- @inline _lind (A:: Type{<:TSize} , k:: Int , j:: Int ) = _lind (:a , A, k, j)
116
113
function _lind (var:: Symbol , A:: Type{TSize{sa,tA}} , k:: Int , j:: Int ) where {sa,tA}
117
114
return uplo_access (sa, var, k, j, tA)
118
115
end
119
116
120
117
121
118
122
119
# Matrix-vector multiplication
123
- @generated function _mul! (Sc:: TSize{sc} , c:: StaticVecOrMatLike , Sa:: TSize {sa} , Sb:: TSize {sb} ,
124
- a :: StaticMatrix , b:: StaticVector , _add:: MulAddMul ,
125
- :: Val{col} = Val (1 )) where {sa, sb, sc, col}
120
+ @generated function _mul! (Sc:: TSize{sc} , c:: StaticVecOrMatLike , Sa:: Size {sa} , Sb:: Size {sb} ,
121
+ wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} , _add:: MulAddMul ,
122
+ :: Val{col} = Val (1 )) where {sa, sb, sc, col, Ta, Tb }
126
123
if sa[2 ] != sb[1 ] || sc[1 ] != sa[1 ]
127
124
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
128
125
end
129
126
130
127
if sa[2 ] != 0
131
- lhs = [:($ (_lind (:c ,Sc,k,col))) for k = 1 : sa[1 ]]
132
- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
133
- [:($ (_lind (Sa,k,j))* b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
134
- exprs = _muladd_expr (lhs, ab, _add)
128
+ assign_expr = gen_by_access (wrapped_a) do access_a
129
+ lhs = [:($ (_lind (:c ,Sc,k,col))) for k = 1 : sa[1 ]]
130
+ ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
131
+ [:($ (uplo_access (sa, :a , k, j, access_a)) * b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
132
+ exprs = _muladd_expr (lhs, ab, _add)
133
+
134
+ return :(@inbounds $ (Expr (:block , exprs... )))
135
+ end
135
136
else
136
137
exprs = [:(c[$ k] = zero (eltype (c))) for k = 1 : sa[1 ]]
138
+ assign_expr = :(@inbounds $ (Expr (:block , exprs... )))
137
139
end
138
140
139
141
return quote
140
142
# @_inline_meta
141
- # α = _add.alpha
142
- # β = _add.beta
143
143
α = alpha (_add)
144
144
β = beta (_add)
145
- @inbounds $ (Expr (:block , exprs... ))
145
+ a = mul_parent (wrapped_a)
146
+ $ assign_expr
146
147
return c
147
148
end
148
149
end
149
150
150
151
# Outer product
151
- @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , :: TSize {sa,:any } , tsb:: Union{TSize{sb,:transpose},TSize{sb,:adjoint} } ,
152
- a:: StaticVector , b:: StaticVector , _add:: MulAddMul ) where {sa, sb, sc}
152
+ @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , tsa :: Size {sa} , tsb:: Size{sb } ,
153
+ a:: StaticVector , b:: Union{Transpose{<:Any, <: StaticVector}, Adjoint{<:Any, <:StaticVector}} , _add:: MulAddMul ) where {sa, sb, sc}
153
154
if sc[1 ] != sa[1 ] || sc[2 ] != sb[2 ]
154
155
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
155
156
end
156
157
157
- conjugate_b = isa (tsb, TSize{sb, :adjoint })
158
+ conjugate_b = b <: Adjoint
158
159
159
160
lhs = [:(c[$ (LinearIndices (sc)[i,j])]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
160
161
if conjugate_b
161
162
ab = [:(a[$ i] * adjoint (b[$ j])) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
162
163
else
163
- ab = [:(a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
164
+ ab = [:(a[$ i] * transpose ( b[$ j]) ) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
164
165
end
165
166
166
167
exprs = _muladd_expr (lhs, ab, _add)
175
176
end
176
177
177
178
# Matrix-matrix multiplication
178
- @generated function _mul! (Sc:: TSize{sc} , c:: StaticMatrixLike ,
179
- Sa:: TSize {sa} , Sb:: TSize {sb} ,
180
- a:: StaticMatrixLike , b:: StaticMatrixLike ,
179
+ @generated function _mul! (Sc:: TSize{sc} , c:: StaticMatMulLike ,
180
+ Sa:: Size {sa} , Sb:: Size {sb} ,
181
+ a:: StaticMatMulLike , b:: StaticMatMulLike ,
181
182
_add:: MulAddMul ) where {sa, sb, sc}
182
183
Ta,Tb,Tc = eltype (a), eltype (b), eltype (c)
183
184
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
199
200
if can_blas
200
201
return quote
201
202
@_inline_meta
202
- mul_blas! (Sc, c, Sa, Sb, a, b , _add)
203
+ mul_blas! (Sc, c, TSize (a), TSize (b), mul_parent (a), mul_parent (b) , _add)
203
204
return c
204
205
end
205
206
else
@@ -213,18 +214,27 @@ end
213
214
end
214
215
215
216
216
- @generated function muladd_unrolled_all! (Sc:: TSize{sc} , c :: StaticMatrixLike , Sa:: TSize {sa} , Sb:: TSize {sb} ,
217
- a :: StaticMatrixLike , b :: StaticMatrixLike , _add:: MulAddMul ) where {sa, sb, sc}
217
+ @generated function muladd_unrolled_all! (Sc:: TSize{sc} , wrapped_c :: StaticMatMulLike , Sa:: Size {sa} , Sb:: Size {sb} ,
218
+ wrapped_a :: StaticMatMulLike{<:Any,<:Any,Ta} , wrapped_b :: StaticMatMulLike{<:Any,<:Any,Tb} , _add:: MulAddMul ) where {sa, sb, sc, Ta, Tb }
218
219
if ! check_dims (Size (sc),Size (sa),Size (sb))
219
220
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
220
221
end
221
222
222
223
if sa[2 ] != 0
223
224
lhs = [:($ (_lind (:c , Sc, k1, k2))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
224
- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
225
- [:($ (_lind (:a , Sa, k1, j)) * $ (_lind (:b , Sb, j, k2))) for j = 1 : sa[2 ]]
226
- ))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
227
- exprs = _muladd_expr (lhs, ab, _add)
225
+
226
+ assign_expr = gen_by_access (wrapped_a, wrapped_b) do access_a, access_b
227
+
228
+ ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
229
+ [:($ (uplo_access (sa, :a , k1, j, access_a)) * $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
230
+ ))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
231
+
232
+ exprs = _muladd_expr (lhs, ab, _add)
233
+ return :(@inbounds $ (Expr (:block , exprs... )))
234
+ end
235
+ else
236
+ exprs = [:(c[$ k] = zero (eltype (c))) for k = 1 : sc[1 ]* sc[2 ]]
237
+ assign_expr = :(@inbounds $ (Expr (:block , exprs... )))
228
238
end
229
239
230
240
return quote
@@ -233,49 +243,63 @@ end
233
243
# β = _add.beta
234
244
α = alpha (_add)
235
245
β = beta (_add)
236
- @inbounds $ (Expr (:block , exprs... ))
246
+ c = mul_parent (wrapped_c)
247
+ a = mul_parent (wrapped_a)
248
+ b = mul_parent (wrapped_b)
249
+ $ assign_expr
250
+ return c
237
251
end
238
252
end
239
253
240
254
241
- @generated function muladd_unrolled_chunks! (Sc:: TSize{sc} , c :: StaticMatrix , :: TSize {sa,tA } , Sb:: TSize {sb,tB } ,
242
- a :: StaticMatrix , b :: StaticMatrix , _add:: MulAddMul ) where {sa, sb, sc, tA, tB }
255
+ @generated function muladd_unrolled_chunks! (Sc:: TSize{sc} , wrapped_c :: StaticMatMulLike , :: Size {sa} , Sb:: Size {sb} ,
256
+ wrapped_a :: StaticMatMulLike{<:Any,<:Any,Ta} , wrapped_b :: StaticMatMulLike{<:Any,<:Any,Tb} , _add:: MulAddMul ) where {sa, sb, sc, Ta, Tb }
243
257
if sb[1 ] != sa[2 ] || sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
244
258
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
245
259
end
246
260
261
+ # This will not work for Symmetric and Hermitian wrappers of c
262
+ lhs = [:($ (_lind (:c , Sc, k1, k2))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
263
+
247
264
# vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
248
265
249
266
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
250
- tmp_type = SVector{sb[1 ], eltype (c)}
251
- vect_exprs = [:($ (Symbol (" tmp_$k2 " )) = partly_unrolled_multiply ($ (TSize {sa,tA} ()), $ (TSize {(sb[1],),tB} ()),
252
- a, $ (Expr (:call , tmp_type, [:($ (_lind (:b , Sb, i, k2))) for i = 1 : sb[1 ]]. .. )))) for k2 = 1 : sb[2 ]]
267
+ tmp_type = SVector{sb[1 ], eltype (wrapped_c)}
253
268
254
- lhs = [:($ (_lind (:c , Sc, k1, k2))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
255
- # exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
256
- rhs = [:($ (Symbol (" tmp_$k2 " ))[$ k1]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
257
- exprs = _muladd_expr (lhs, rhs, _add)
269
+ assign_expr = gen_by_access (wrapped_a, wrapped_b) do access_a, access_b
270
+ vect_exprs = [:($ (Symbol (" tmp_$k2 " )) = partly_unrolled_multiply ($ (Size {sa} ()), $ (Size {(sb[1],)} ()),
271
+ a, $ (Expr (:call , tmp_type, [uplo_access (sb, :b , i, k2, access_b) for i = 1 : sb[1 ]]. .. )), $ (Val (access_a)))) for k2 = 1 : sb[2 ]]
272
+
273
+ # exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
274
+ rhs = [:($ (Symbol (" tmp_$k2 " ))[$ k1]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
275
+ exprs = _muladd_expr (lhs, rhs, _add)
258
276
277
+ return quote
278
+ @inbounds $ (Expr (:block , vect_exprs... ))
279
+ @inbounds $ (Expr (:block , exprs... ))
280
+ end
281
+ end
282
+
259
283
return quote
260
284
@_inline_meta
261
- # α = _add.alpha
262
- # β = _add.beta
263
285
α = alpha (_add)
264
286
β = beta (_add)
265
- @inbounds $ (Expr (:block , vect_exprs... ))
266
- @inbounds $ (Expr (:block , exprs... ))
287
+ c = mul_parent (wrapped_c)
288
+ a = mul_parent (wrapped_a)
289
+ b = mul_parent (wrapped_b)
290
+ $ assign_expr
267
291
end
268
292
end
269
293
270
294
# @inline partly_unrolled_multiply(Sa::Size, Sb::Size, a::StaticMatrix, b::StaticArray) where {sa, sb, Ta, Tb} =
271
295
# partly_unrolled_multiply(TSize(Sa), TSize(Sb), a, b)
272
- @generated function partly_unrolled_multiply (Sa:: TSize {sa} , :: TSize {sb} , a:: StaticMatrix {<:Any, <:Any, Ta} , b:: StaticArray{<:Tuple, Tb} ) where {sa, sb, Ta, Tb}
296
+ @generated function partly_unrolled_multiply (Sa:: Size {sa} , :: Size {sb} , a:: StaticMatMulLike {<:Any, <:Any, Ta} , b:: StaticArray{<:Tuple, Tb} , :: Val{access_a} ) where {sa, sb, Ta, Tb, access_a }
273
297
if sa[2 ] != sb[1 ]
274
298
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
275
299
end
276
300
277
301
if sa[2 ] != 0
278
- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:($ (_lind ( :a ,Sa,k,j ))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
302
+ exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:($ (uplo_access (sa, :a , k, j, access_a ))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
279
303
else
280
304
exprs = [:(zero (promote_op (matprod,Ta,Tb))) for k = 1 : sa[1 ]]
281
305
end
0 commit comments