@@ -4,12 +4,20 @@ import LinearAlgebra: BlasFloat, matprod, mul!
4
4
# Manage dispatch of * and mul!
5
5
# TODO Adjoint? (Inner product?)
6
6
7
+ """
8
+ StaticMatMulLike
9
+
10
+ Static wrappers used for multiplication dispatch.
11
+ """
7
12
const StaticMatMulLike{s1, s2, T} = Union{
8
13
StaticMatrix{s1, s2, T},
9
14
Symmetric{T, <: StaticMatrix{s1, s2, T} },
10
15
Hermitian{T, <: StaticMatrix{s1, s2, T} },
11
16
LowerTriangular{T, <: StaticMatrix{s1, s2, T} },
12
- UpperTriangular{T, <: StaticMatrix{s1, s2, T} }}
17
+ UpperTriangular{T, <: StaticMatrix{s1, s2, T} },
18
+ Adjoint{T, <: StaticMatrix{s1, s2, T} },
19
+ Transpose{T, <: StaticMatrix{s1, s2, T} }}
20
+
13
21
14
22
@inline * (A:: StaticMatMulLike , B:: AbstractVector ) = _mul (Size (A), A, B)
15
23
@inline * (A:: StaticMatMulLike , B:: StaticVector ) = _mul (Size (A), Size (B), A, B)
@@ -20,6 +28,18 @@ const StaticMatMulLike{s1, s2, T} = Union{
20
28
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Adjoint{<:Any,<:StaticVector} ) where {N} = vec (A) * B
21
29
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
22
30
31
+ """
32
+ gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :a)
33
+
34
+ Statically generate outer code for fully unrolled multiplication loops.
35
+ Returned code does wrapper-specific tests (for example if a symmetric matrix view is
36
+ `U` or `L`) and the body of the if expression is then generated by function `expr_gen`.
37
+ The function `expr_gen` receives access pattern description symbol as its argument
38
+ and this symbol is then consumed by uplo_access to generate the right code for matrix
39
+ element access.
40
+
41
+ The name of the matrix to test is indicated by `asym`.
42
+ """
23
43
function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , asym = :a )
24
44
return expr_gen (:any )
25
45
end
47
67
function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :a )
48
68
return expr_gen (:lower_triangular )
49
69
end
70
+ function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticMatrix}} , asym = :a )
71
+ return expr_gen (:transpose )
72
+ end
73
+ function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticMatrix}} , asym = :a )
74
+ return expr_gen (:adjoint )
75
+ end
76
+ """
77
+ gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
78
+
79
+ Simiar to gen_by_access with only one type argument. The difference is that tests for both
80
+ arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments,
81
+ first for matrix `a` and the second for matrix `b`.
82
+ """
50
83
function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type )
51
84
return quote
52
85
return $ (gen_by_access (b, :b ) do access_b
@@ -94,6 +127,20 @@ function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix
94
127
end )
95
128
end
96
129
end
130
+ function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticMatrix}} , b:: Type )
131
+ return quote
132
+ return $ (gen_by_access (b, :b ) do access_b
133
+ expr_gen (:transpose , access_b)
134
+ end )
135
+ end
136
+ end
137
+ function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticMatrix}} , b:: Type )
138
+ return quote
139
+ return $ (gen_by_access (b, :b ) do access_b
140
+ expr_gen (:adjoint , access_b)
141
+ end )
142
+ end
143
+ end
97
144
98
145
"""
99
146
mul_result_structure(a::Type, b::Type)
@@ -111,6 +158,14 @@ function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerT
111
158
return LowerTriangular
112
159
end
113
160
161
+ """
162
+ uplo_access(sa, asym, k, j, uplo)
163
+
164
+ Generate code for matrix element access, for a matrix of size `sa` locally referred to
165
+ as `asym` in the context where the result will be used. Both indices `k` and `j` need to be
166
+ statically known for this function to work. `uplo` is the access pattern mode generated
167
+ by the `gen_by_access` function.
168
+ """
114
169
function uplo_access (sa, asym, k, j, uplo)
115
170
if uplo == :any
116
171
return :($ asym[$ (LinearIndices (sa)[k, j])])
@@ -150,6 +205,10 @@ function uplo_access(sa, asym, k, j, uplo)
150
205
else
151
206
return :(zero (T))
152
207
end
208
+ elseif uplo == :transpose
209
+ return :($ asym[$ (LinearIndices (sa)[j, k])])
210
+ elseif uplo == :ajoint
211
+ return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
153
212
end
154
213
end
155
214
@@ -216,23 +275,35 @@ end
216
275
end
217
276
end
218
277
278
+ _unstatic_array (:: Type{TSA} ) where {S, T, N, TSA<: StaticArray{S,T,N} } = AbstractArray{T,N}
279
+ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular]
280
+ @eval _unstatic_array (:: Type{$TWR{T,TSA}} ) where {S, T, N, TSA<: StaticArray{S,T,N} } = $ TWR{T,<: AbstractArray{T,N} }
281
+ end
282
+
219
283
@generated function _mul (Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
220
284
# Heuristic choice for amount of codegen
221
285
if sa[1 ]* sa[2 ]* sb[2 ] <= 8 * 8 * 8 || ! (a <: StaticMatrix ) || ! (b <: StaticMatrix )
222
286
return quote
223
287
@_inline_meta
224
288
return mul_unrolled (Sa, Sb, a, b)
225
289
end
226
- elseif sa[1 ] <= 14 && sa[2 ] <= 14 && sb[2 ] <= 14
290
+ elseif a <: StaticMatrix && b <: StaticMatrix && sa[1 ] <= 14 && sa[2 ] <= 14 && sb[2 ] <= 14
227
291
return quote
228
292
@_inline_meta
229
293
return mul_unrolled_chunks (Sa, Sb, a, b)
230
294
end
231
- else
295
+ elseif a <: StaticMatrix && b <: StaticMatrix
232
296
return quote
233
297
@_inline_meta
234
298
return mul_loop (Sa, Sb, a, b)
235
299
end
300
+ else
301
+ # we don't have any special code for handling this case so let's fall back to
302
+ # the generic implementation of matrix multiplication
303
+ return quote
304
+ @_inline_meta
305
+ return invoke (* , Tuple{$ (_unstatic_array (a)),$ (_unstatic_array (b))}, a, b)
306
+ end
236
307
end
237
308
end
238
309
0 commit comments