@@ -15,150 +15,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
15
15
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Adjoint{<:Any,<:StaticVector} ) where {N} = vec (A) * B
16
16
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
17
17
18
- """
19
- gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a)
20
-
21
- Statically generate outer code for fully unrolled multiplication loops.
22
- Returned code does wrapper-specific tests (for example if a symmetric matrix view is
23
- `U` or `L`) and the body of the if expression is then generated by function `expr_gen`.
24
- The function `expr_gen` receives access pattern description symbol as its argument
25
- and this symbol is then consumed by uplo_access to generate the right code for matrix
26
- element access.
27
-
28
- The name of the matrix to test is indicated by `asym`.
29
- """
30
- function gen_by_access (expr_gen, a:: Type{<:StaticVecOrMat} , asym = :wrapped_a )
31
- return expr_gen (:any )
32
- end
33
- function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
34
- return quote
35
- if $ (asym). uplo == ' U'
36
- $ (expr_gen (:up ))
37
- else
38
- $ (expr_gen (:lo ))
39
- end
40
- end
41
- end
42
- function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
43
- return quote
44
- if $ (asym). uplo == ' U'
45
- $ (expr_gen (:up_herm ))
46
- else
47
- $ (expr_gen (:lo_herm ))
48
- end
49
- end
50
- end
51
- function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
52
- return expr_gen (:upper_triangular )
53
- end
54
- function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
55
- return expr_gen (:lower_triangular )
56
- end
57
- function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
58
- return expr_gen (:unit_upper_triangular )
59
- end
60
- function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
61
- return expr_gen (:unit_lower_triangular )
62
- end
63
- function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
64
- return expr_gen (:transpose )
65
- end
66
- function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
67
- return expr_gen (:adjoint )
68
- end
69
- function gen_by_access (expr_gen, a:: Type{<:SDiagonal} , asym = :wrapped_a )
70
- return expr_gen (:diagonal )
71
- end
72
- """
73
- gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
74
-
75
- Simiar to gen_by_access with only one type argument. The difference is that tests for both
76
- arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments,
77
- first for matrix `a` and the second for matrix `b`.
78
- """
79
- function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type )
80
- return quote
81
- return $ (gen_by_access (b, :wrapped_b ) do access_b
82
- expr_gen (:any , access_b)
83
- end )
84
- end
85
- end
86
- function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , b:: Type )
87
- return quote
88
- if wrapped_a. uplo == ' U'
89
- return $ (gen_by_access (b, :wrapped_b ) do access_b
90
- expr_gen (:up , access_b)
91
- end )
92
- else
93
- return $ (gen_by_access (b, :wrapped_b ) do access_b
94
- expr_gen (:lo , access_b)
95
- end )
96
- end
97
- end
98
- end
99
- function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , b:: Type )
100
- return quote
101
- if wrapped_a. uplo == ' U'
102
- return $ (gen_by_access (b, :wrapped_b ) do access_b
103
- expr_gen (:up_herm , access_b)
104
- end )
105
- else
106
- return $ (gen_by_access (b, :wrapped_b ) do access_b
107
- expr_gen (:lo_herm , access_b)
108
- end )
109
- end
110
- end
111
- end
112
- function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
113
- return quote
114
- return $ (gen_by_access (b, :wrapped_b ) do access_b
115
- expr_gen (:upper_triangular , access_b)
116
- end )
117
- end
118
- end
119
- function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
120
- return quote
121
- return $ (gen_by_access (b, :wrapped_b ) do access_b
122
- expr_gen (:lower_triangular , access_b)
123
- end )
124
- end
125
- end
126
- function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
127
- return quote
128
- return $ (gen_by_access (b, :wrapped_b ) do access_b
129
- expr_gen (:unit_upper_triangular , access_b)
130
- end )
131
- end
132
- end
133
- function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
134
- return quote
135
- return $ (gen_by_access (b, :wrapped_b ) do access_b
136
- expr_gen (:unit_lower_triangular , access_b)
137
- end )
138
- end
139
- end
140
- function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticMatrix}} , b:: Type )
141
- return quote
142
- return $ (gen_by_access (b, :wrapped_b ) do access_b
143
- expr_gen (:transpose , access_b)
144
- end )
145
- end
146
- end
147
- function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticMatrix}} , b:: Type )
148
- return quote
149
- return $ (gen_by_access (b, :wrapped_b ) do access_b
150
- expr_gen (:adjoint , access_b)
151
- end )
152
- end
153
- end
154
- function gen_by_access (expr_gen, a:: Type{<:SDiagonal} , b:: Type )
155
- return quote
156
- return $ (gen_by_access (b, :wrapped_b ) do access_b
157
- expr_gen (:diagonal , access_b)
158
- end )
159
- end
160
- end
161
-
162
18
"""
163
19
mul_result_structure(a::Type, b::Type)
164
20
@@ -202,99 +58,6 @@ function mul_result_structure(::SDiagonal, ::SDiagonal)
202
58
return Diagonal
203
59
end
204
60
205
- """
206
- uplo_access(sa, asym, k, j, uplo)
207
-
208
- Generate code for matrix element access, for a matrix of size `sa` locally referred to
209
- as `asym` in the context where the result will be used. Both indices `k` and `j` need to be
210
- statically known for this function to work. `uplo` is the access pattern mode generated
211
- by the `gen_by_access` function.
212
- """
213
- function uplo_access (sa, asym, k, j, uplo)
214
- TAsym = Symbol (" T" * string (asym))
215
- if uplo == :any
216
- return :($ asym[$ (LinearIndices (sa)[k, j])])
217
- elseif uplo == :up
218
- if k < j
219
- return :($ asym[$ (LinearIndices (sa)[k, j])])
220
- elseif k == j
221
- return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
222
- else
223
- return :(transpose ($ asym[$ (LinearIndices (sa)[j, k])]))
224
- end
225
- elseif uplo == :lo
226
- if k > j
227
- return :($ asym[$ (LinearIndices (sa)[k, j])])
228
- elseif k == j
229
- return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
230
- else
231
- return :(transpose ($ asym[$ (LinearIndices (sa)[j, k])]))
232
- end
233
- elseif uplo == :up_herm
234
- if k < j
235
- return :($ asym[$ (LinearIndices (sa)[k, j])])
236
- elseif k == j
237
- return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
238
- else
239
- return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
240
- end
241
- elseif uplo == :lo_herm
242
- if k > j
243
- return :($ asym[$ (LinearIndices (sa)[k, j])])
244
- elseif k == j
245
- return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
246
- else
247
- return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
248
- end
249
- elseif uplo == :upper_triangular
250
- if k <= j
251
- return :($ asym[$ (LinearIndices (sa)[k, j])])
252
- else
253
- return :(zero ($ TAsym))
254
- end
255
- elseif uplo == :lower_triangular
256
- if k >= j
257
- return :($ asym[$ (LinearIndices (sa)[k, j])])
258
- else
259
- return :(zero ($ TAsym))
260
- end
261
- elseif uplo == :unit_upper_triangular
262
- if k < j
263
- return :($ asym[$ (LinearIndices (sa)[k, j])])
264
- elseif k == j
265
- return :(oneunit ($ TAsym))
266
- else
267
- return :(zero ($ TAsym))
268
- end
269
- elseif uplo == :unit_lower_triangular
270
- if k > j
271
- return :($ asym[$ (LinearIndices (sa)[k, j])])
272
- elseif k == j
273
- return :(oneunit ($ TAsym))
274
- else
275
- return :(zero ($ TAsym))
276
- end
277
- elseif uplo == :upper_hessenberg
278
- if k <= j+ 1
279
- return :($ asym[$ (LinearIndices (sa)[k, j])])
280
- else
281
- return :(zero ($ TAsym))
282
- end
283
- elseif uplo == :transpose
284
- return :(transpose ($ asym[$ (LinearIndices (reverse (sa))[j, k])]))
285
- elseif uplo == :adjoint
286
- return :(adjoint ($ asym[$ (LinearIndices (reverse (sa))[j, k])]))
287
- elseif uplo == :diagonal
288
- if k == j
289
- return :($ asym[$ k])
290
- else
291
- return :(zero ($ TAsym))
292
- end
293
- else
294
- error (" Unknown uplo: $uplo " )
295
- end
296
- end
297
-
298
61
# Implementations
299
62
300
63
function mul_smat_vec_exprs (sa, access_a)
@@ -369,31 +132,6 @@ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTria
369
132
@eval _unstatic_array (:: Type{$TWR{T,TSA}} ) where {S, T, N, TSA<: StaticArray{S,T,N} } = $ TWR{T,<: AbstractArray{T,N} }
370
133
end
371
134
372
- function combine_products (expr_list)
373
- filtered = filter (expr_list) do expr
374
- if expr. head != :call || expr. args[1 ] != :*
375
- error (" expected call to *" )
376
- end
377
- for arg in expr. args[2 : end ]
378
- if isa (arg, Expr) && arg. head == :call && arg. args[1 ] == :zero
379
- return false
380
- end
381
- end
382
- return true
383
- end
384
- if isempty (filtered)
385
- return :(zero (T))
386
- else
387
- return reduce (filtered) do ex1, ex2
388
- if ex2. head != :call || ex2. args[1 ] != :*
389
- error (" expected call to *" )
390
- end
391
-
392
- return :(muladd ($ (ex2. args[2 ]), $ (ex2. args[3 ]), $ ex1))
393
- end
394
- end
395
- end
396
-
397
135
@generated function _mul (Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
398
136
S = Size (sa[1 ], sb[2 ])
399
137
# Heuristic choice for amount of codegen
0 commit comments