@@ -4,26 +4,10 @@ import LinearAlgebra: BlasFloat, matprod, mul!
4
4
# Manage dispatch of * and mul!
5
5
# TODO Adjoint? (Inner product?)
6
6
7
- """
8
- StaticMatMulLike
9
-
10
- Static wrappers used for multiplication dispatch.
11
- """
12
- const StaticMatMulLike{s1, s2, T} = Union{
13
- StaticMatrix{s1, s2, T},
14
- Symmetric{T, <: StaticMatrix{s1, s2, T} },
15
- Hermitian{T, <: StaticMatrix{s1, s2, T} },
16
- LowerTriangular{T, <: StaticMatrix{s1, s2, T} },
17
- UpperTriangular{T, <: StaticMatrix{s1, s2, T} },
18
- UnitLowerTriangular{T, <: StaticMatrix{s1, s2, T} },
19
- UnitUpperTriangular{T, <: StaticMatrix{s1, s2, T} },
20
- UpperHessenberg{T, <: StaticMatrix{s1, s2, T} },
21
- Adjoint{T, <: StaticMatrix{s1, s2, T} },
22
- Transpose{T, <: StaticMatrix{s1, s2, T} }}
23
-
24
-
25
- @inline * (A:: StaticMatMulLike , B:: AbstractVector ) = _mul (Size (A), A, B)
7
+ # *(A::StaticMatMulLike, B::AbstractVector) causes an ambiguity with SparseArrays
8
+ @inline * (A:: StaticMatrix , B:: AbstractVector ) = _mul (Size (A), A, B)
26
9
@inline * (A:: StaticMatMulLike , B:: StaticVector ) = _mul (Size (A), Size (B), A, B)
10
+ @inline * (A:: StaticMatrix , B:: StaticVector ) = _mul (Size (A), Size (B), A, B)
27
11
@inline * (A:: StaticMatMulLike , B:: StaticMatMulLike ) = _mul (Size (A), Size (B), A, B)
28
12
@inline * (A:: StaticVector , B:: StaticMatMulLike ) = * (reshape (A, Size (Size (A)[1 ], 1 )), B)
29
13
@inline * (A:: StaticVector , B:: Transpose{<:Any, <:StaticVector} ) = _mul (Size (A), Size (B), A, B)
@@ -32,7 +16,7 @@ const StaticMatMulLike{s1, s2, T} = Union{
32
16
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
33
17
34
18
"""
35
- gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :a )
19
+ gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a )
36
20
37
21
Statically generate outer code for fully unrolled multiplication loops.
38
22
Returned code does wrapper-specific tests (for example if a symmetric matrix view is
@@ -43,10 +27,10 @@ element access.
43
27
44
28
The name of the matrix to test is indicated by `asym`.
45
29
"""
46
- function gen_by_access (expr_gen, a:: Type{<:StaticVecOrMat} , asym = :a )
30
+ function gen_by_access (expr_gen, a:: Type{<:StaticVecOrMat} , asym = :wrapped_a )
47
31
return expr_gen (:any )
48
32
end
49
- function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , asym = :a )
33
+ function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
50
34
return quote
51
35
if $ (asym). uplo == ' U'
52
36
$ (expr_gen (:up ))
@@ -55,7 +39,7 @@ function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, as
55
39
end
56
40
end
57
41
end
58
- function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , asym = :a )
42
+ function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
59
43
return quote
60
44
if $ (asym). uplo == ' U'
61
45
$ (expr_gen (:up_herm ))
@@ -64,25 +48,22 @@ function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, as
64
48
end
65
49
end
66
50
end
67
- function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , asym = :a )
51
+ function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
68
52
return expr_gen (:upper_triangular )
69
53
end
70
- function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :a )
54
+ function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
71
55
return expr_gen (:lower_triangular )
72
56
end
73
- function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , asym = :a )
57
+ function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
74
58
return expr_gen (:unit_upper_triangular )
75
59
end
76
- function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , asym = :a )
60
+ function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
77
61
return expr_gen (:unit_lower_triangular )
78
62
end
79
- function gen_by_access (expr_gen, a:: Type{<:UpperHessenberg{<:Any, <:StaticMatrix}} , asym = :a )
80
- return expr_gen (:upper_hessenberg )
81
- end
82
- function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticVecOrMat}} , asym = :a )
63
+ function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
83
64
return expr_gen (:transpose )
84
65
end
85
- function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticVecOrMat}} , asym = :a )
66
+ function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
86
67
return expr_gen (:adjoint )
87
68
end
88
69
"""
@@ -94,82 +75,75 @@ first for matrix `a` and the second for matrix `b`.
94
75
"""
95
76
function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type )
96
77
return quote
97
- return $ (gen_by_access (b, :b ) do access_b
78
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
98
79
expr_gen (:any , access_b)
99
80
end )
100
81
end
101
82
end
102
83
function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , b:: Type )
103
84
return quote
104
- if a . uplo == ' U'
105
- return $ (gen_by_access (b, :b ) do access_b
85
+ if wrapped_a . uplo == ' U'
86
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
106
87
expr_gen (:up , access_b)
107
88
end )
108
89
else
109
- return $ (gen_by_access (b, :b ) do access_b
90
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
110
91
expr_gen (:lo , access_b)
111
92
end )
112
93
end
113
94
end
114
95
end
115
96
function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , b:: Type )
116
97
return quote
117
- if a . uplo == ' U'
118
- return $ (gen_by_access (b, :b ) do access_b
98
+ if wrapped_a . uplo == ' U'
99
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
119
100
expr_gen (:up_herm , access_b)
120
101
end )
121
102
else
122
- return $ (gen_by_access (b, :b ) do access_b
103
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
123
104
expr_gen (:lo_herm , access_b)
124
105
end )
125
106
end
126
107
end
127
108
end
128
109
function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
129
110
return quote
130
- return $ (gen_by_access (b, :b ) do access_b
111
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
131
112
expr_gen (:upper_triangular , access_b)
132
113
end )
133
114
end
134
115
end
135
116
function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
136
117
return quote
137
- return $ (gen_by_access (b, :b ) do access_b
118
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
138
119
expr_gen (:lower_triangular , access_b)
139
120
end )
140
121
end
141
122
end
142
123
function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
143
124
return quote
144
- return $ (gen_by_access (b, :b ) do access_b
125
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
145
126
expr_gen (:unit_upper_triangular , access_b)
146
127
end )
147
128
end
148
129
end
149
130
function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
150
131
return quote
151
- return $ (gen_by_access (b, :b ) do access_b
132
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
152
133
expr_gen (:unit_lower_triangular , access_b)
153
134
end )
154
135
end
155
136
end
156
- function gen_by_access (expr_gen, a:: Type{<:UpperHessenberg{<:Any, <:StaticMatrix}} , b:: Type )
157
- return quote
158
- return $ (gen_by_access (b, :b ) do access_b
159
- expr_gen (:upper_hessenberg , access_b)
160
- end )
161
- end
162
- end
163
137
function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticMatrix}} , b:: Type )
164
138
return quote
165
- return $ (gen_by_access (b, :b ) do access_b
139
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
166
140
expr_gen (:transpose , access_b)
167
141
end )
168
142
end
169
143
end
170
144
function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticMatrix}} , b:: Type )
171
145
return quote
172
- return $ (gen_by_access (b, :b ) do access_b
146
+ return $ (gen_by_access (b, :wrapped_b ) do access_b
173
147
expr_gen (:adjoint , access_b)
174
148
end )
175
149
end
@@ -200,65 +174,74 @@ statically known for this function to work. `uplo` is the access pattern mode ge
200
174
by the `gen_by_access` function.
201
175
"""
202
176
function uplo_access (sa, asym, k, j, uplo)
177
+ TAsym = Symbol (" T" * string (asym))
203
178
if uplo == :any
204
179
return :($ asym[$ (LinearIndices (sa)[k, j])])
205
180
elseif uplo == :up
206
- if k <= j
181
+ if k < j
207
182
return :($ asym[$ (LinearIndices (sa)[k, j])])
183
+ elseif k == j
184
+ return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
208
185
else
209
- return :($ asym[$ (LinearIndices (sa)[j, k])])
186
+ return :(transpose ( $ asym[$ (LinearIndices (sa)[j, k])]) )
210
187
end
211
188
elseif uplo == :lo
212
- if k >= j
189
+ if k > j
213
190
return :($ asym[$ (LinearIndices (sa)[k, j])])
191
+ elseif k == j
192
+ return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
214
193
else
215
- return :($ asym[$ (LinearIndices (sa)[j, k])])
194
+ return :(transpose ( $ asym[$ (LinearIndices (sa)[j, k])]) )
216
195
end
217
196
elseif uplo == :up_herm
218
- if k <= j
197
+ if k < j
219
198
return :($ asym[$ (LinearIndices (sa)[k, j])])
199
+ elseif k == j
200
+ return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
220
201
else
221
202
return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
222
203
end
223
204
elseif uplo == :lo_herm
224
- if k >= j
205
+ if k > j
225
206
return :($ asym[$ (LinearIndices (sa)[k, j])])
207
+ elseif k == j
208
+ return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
226
209
else
227
210
return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
228
211
end
229
212
elseif uplo == :upper_triangular
230
213
if k <= j
231
214
return :($ asym[$ (LinearIndices (sa)[k, j])])
232
215
else
233
- return :(zero (T ))
216
+ return :(zero ($ TAsym ))
234
217
end
235
218
elseif uplo == :lower_triangular
236
219
if k >= j
237
220
return :($ asym[$ (LinearIndices (sa)[k, j])])
238
221
else
239
- return :(zero (T ))
222
+ return :(zero ($ TAsym ))
240
223
end
241
224
elseif uplo == :unit_upper_triangular
242
225
if k < j
243
226
return :($ asym[$ (LinearIndices (sa)[k, j])])
244
227
elseif k == j
245
- return :(oneunit (T ))
228
+ return :(oneunit ($ TAsym ))
246
229
else
247
- return :(zero (T ))
230
+ return :(zero ($ TAsym ))
248
231
end
249
232
elseif uplo == :unit_lower_triangular
250
233
if k > j
251
234
return :($ asym[$ (LinearIndices (sa)[k, j])])
252
235
elseif k == j
253
- return :(oneunit (T ))
236
+ return :(oneunit ($ TAsym ))
254
237
else
255
- return :(zero (T ))
238
+ return :(zero ($ TAsym ))
256
239
end
257
240
elseif uplo == :upper_hessenberg
258
241
if k <= j+ 1
259
242
return :($ asym[$ (LinearIndices (sa)[k, j])])
260
243
else
261
- return :(zero (T ))
244
+ return :(zero ($ TAsym ))
262
245
end
263
246
elseif uplo == :transpose
264
247
return :($ asym[$ (LinearIndices (reverse (sa))[j, k])])
@@ -273,9 +256,9 @@ function mul_smat_vec_exprs(sa, access_a)
273
256
return [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:($ (uplo_access (sa, :a , k, j, access_a))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
274
257
end
275
258
276
- @generated function _mul (:: Size{sa} , a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: AbstractVector{Tb} ) where {sa, Ta, Tb}
259
+ @generated function _mul (:: Size{sa} , wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: AbstractVector{Tb} ) where {sa, Ta, Tb}
277
260
if sa[2 ] != 0
278
- retexpr = gen_by_access (a ) do access_a
261
+ retexpr = gen_by_access (wrapped_a ) do access_a
279
262
exprs = mul_smat_vec_exprs (sa, access_a)
280
263
return :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
281
264
end
@@ -290,17 +273,18 @@ end
290
273
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $(size (b)) " ))
291
274
end
292
275
T = promote_op (matprod,Ta,Tb)
276
+ a = mul_parent (wrapped_a)
293
277
$ retexpr
294
278
end
295
279
end
296
280
297
- @generated function _mul (:: Size{sa} , :: Size{sb} , a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} ) where {sa, sb, Ta, Tb}
281
+ @generated function _mul (:: Size{sa} , :: Size{sb} , wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} ) where {sa, sb, Ta, Tb}
298
282
if sb[1 ] != sa[2 ]
299
283
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
300
284
end
301
285
302
286
if sa[2 ] != 0
303
- retexpr = gen_by_access (a ) do access_a
287
+ retexpr = gen_by_access (wrapped_a ) do access_a
304
288
exprs = mul_smat_vec_exprs (sa, access_a)
305
289
return :(@inbounds similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
306
290
end
312
296
return quote
313
297
@_inline_meta
314
298
T = promote_op (matprod,Ta,Tb)
299
+ a = mul_parent (wrapped_a)
315
300
$ retexpr
316
301
end
317
302
end
@@ -362,28 +347,30 @@ end
362
347
end
363
348
end
364
349
365
- @generated function mul_unrolled (:: Size{sa} , :: Size{sb} , a :: StaticMatMulLike{<:Any, <:Any, Ta} , b :: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
350
+ @generated function mul_unrolled (:: Size{sa} , :: Size{sb} , wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , wrapped_b :: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
366
351
if sb[1 ] != sa[2 ]
367
352
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
368
353
end
369
354
370
355
S = Size (sa[1 ], sb[2 ])
371
356
372
357
if sa[2 ] != 0
373
- retexpr = gen_by_access (a, b ) do access_a, access_b
358
+ retexpr = gen_by_access (wrapped_a, wrapped_b ) do access_a, access_b
374
359
exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
375
360
[:($ (uplo_access (sa, :a , k1, j, access_a))* $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
376
361
) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
377
- return :((mul_result_structure (a, b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
362
+ return :((mul_result_structure (wrapped_a, wrapped_b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
378
363
end
379
364
else
380
365
exprs = [:(zero (T)) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
381
- retexpr = :(return (mul_result_structure (a, b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
366
+ retexpr = :(return (mul_result_structure (wrapped_a, wrapped_b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
382
367
end
383
368
384
369
return quote
385
370
@_inline_meta
386
371
T = promote_op (matprod,Ta,Tb)
372
+ a = mul_parent (wrapped_a)
373
+ b = mul_parent (wrapped_b)
387
374
@inbounds $ retexpr
388
375
end
389
376
end
0 commit comments