@@ -7,7 +7,9 @@ import LinearAlgebra: BlasFloat, matprod, mul!
7
7
# Operand types handled by the static multiplication methods in this file: a plain
# `StaticMatrix`, or a `Symmetric` / `Hermitian` / `LowerTriangular` /
# `UpperTriangular` wrapper whose parent is a `StaticMatrix`.
# The parameters `s1`, `s2` and `T` are forwarded to `StaticMatrix{s1, s2, T}`.
const StaticMatMulLike{s1, s2, T} = Union{
    StaticMatrix{s1, s2, T},
    Symmetric{T, <:StaticMatrix{s1, s2, T}},
    Hermitian{T, <:StaticMatrix{s1, s2, T}},
    LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
    UpperTriangular{T, <:StaticMatrix{s1, s2, T}}}
11
13
12
14
# Product of a static matrix-like `A` with a general vector `B`; forwards to `_mul`
# together with the static size of `A`.
@inline function *(A::StaticMatMulLike, B::AbstractVector)
    return _mul(Size(A), A, B)
end
13
15
# Product of a static matrix-like `A` with a static vector `B`; forwards to `_mul`
# together with the static sizes of both operands.
@inline function *(A::StaticMatMulLike, B::StaticVector)
    return _mul(Size(A), Size(B), A, B)
end
@@ -39,48 +41,75 @@ function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, as
39
41
end
40
42
end
41
43
end
42
- function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type{<:StaticMatrix} )
43
- return expr_gen (:any , :any )
44
# Single-operand access generation for an `UpperTriangular` wrapper of a
# `StaticMatrix`: `expr_gen` is called directly with the access tag
# `:upper_triangular`. `asym` is accepted but not used by this method.
gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :a) =
    expr_gen(:upper_triangular)
47
# Single-operand access generation for a `LowerTriangular` wrapper of a
# `StaticMatrix`: `expr_gen` is called directly with the access tag
# `:lower_triangular`. `asym` is accepted but not used by this method.
gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a) =
    expr_gen(:lower_triangular)
45
50
# Two-operand access generation when `a` is a plain `StaticMatrix`.
# `a` always contributes the access tag `:any`; the tag for `b` is obtained by
# delegating to the single-operand `gen_by_access` with the symbol `:b`.
# Returns a quoted block that returns the expression produced for `b`.
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
    b_expr = gen_by_access(b, :b) do access_b
        expr_gen(:any, access_b)
    end
    return quote
        return $b_expr
    end
end
54
57
# Two-operand access generation when `a` is a `Symmetric` wrapper of a `StaticMatrix`.
#
# Only the type of `a` is available here, so the returned expression branches on
# `a.uplo` when it is executed. The 'U' branch carries the expression built for `b`
# with `a` tagged `:up`; the other branch carries the one with `a` tagged `:lo`.
# The tag for `b` comes from the single-operand `gen_by_access` called with `:b`.
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type)
    return quote
        # Compare against the single character 'U' (no padding inside the literal).
        if a.uplo == 'U'
            return $(gen_by_access(b, :b) do access_b
                expr_gen(:up, access_b)
            end)
        else
            return $(gen_by_access(b, :b) do access_b
                expr_gen(:lo, access_b)
            end)
        end
    end
end
69
70
# Two-operand access generation when `a` is a `Hermitian` wrapper of a `StaticMatrix`.
#
# Only the type of `a` is available here, so the returned expression branches on
# `a.uplo` when it is executed. The 'U' branch carries the expression built for `b`
# with `a` tagged `:up_herm`; the other branch carries the one tagged `:lo_herm`.
# The tag for `b` comes from the single-operand `gen_by_access` called with `:b`.
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type)
    return quote
        # Compare against the single character 'U' (no padding inside the literal).
        if a.uplo == 'U'
            return $(gen_by_access(b, :b) do access_b
                expr_gen(:up_herm, access_b)
            end)
        else
            return $(gen_by_access(b, :b) do access_b
                expr_gen(:lo_herm, access_b)
            end)
        end
    end
end
83
# Two-operand access generation when `a` is an `UpperTriangular` wrapper of a
# `StaticMatrix`. `a` contributes the access tag `:upper_triangular`; the tag for
# `b` is obtained from the single-operand `gen_by_access` called with `:b`.
# Returns a quoted block that returns the expression produced for `b`.
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
    b_expr = gen_by_access(b, :b) do access_b
        expr_gen(:upper_triangular, access_b)
    end
    return quote
        return $b_expr
    end
end
90
# Two-operand access generation when `a` is a `LowerTriangular` wrapper of a
# `StaticMatrix`. `a` contributes the access tag `:lower_triangular`; the tag for
# `b` is obtained from the single-operand `gen_by_access` called with `:b`.
# Returns a quoted block that returns the expression produced for `b`.
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
    b_expr = gen_by_access(b, :b) do access_b
        expr_gen(:lower_triangular, access_b)
    end
    return quote
        return $b_expr
    end
end
97
+
98
"""
    mul_result_structure(a, b)

Return the structure wrapper that should be applied to the result of multiplying the
matrices `a` and `b` (`a * b`).

The arguments are the matrices themselves, not their types. This fallback method
returns `identity`, so the product is left unwrapped. More specific methods may return
a wrapper type instead (for example `UpperTriangular` when both operands are
upper-triangular static matrices).
"""
function mul_result_structure(a, b)
    return identity
end
107
# When both operands are upper-triangular static matrices, the product is wrapped in
# `UpperTriangular`.
mul_result_structure(::UpperTriangular{<:Any, <:StaticMatrix}, ::UpperTriangular{<:Any, <:StaticMatrix}) =
    UpperTriangular
110
# When both operands are lower-triangular static matrices, the product is wrapped in
# `LowerTriangular`.
mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerTriangular{<:Any, <:StaticMatrix}) =
    LowerTriangular
84
113
85
114
function uplo_access (sa, asym, k, j, uplo)
86
115
if uplo == :any
@@ -92,7 +121,7 @@ function uplo_access(sa, asym, k, j, uplo)
92
121
return :($ asym[$ (LinearIndices (sa)[j, k])])
93
122
end
94
123
elseif uplo == :lo
95
- if j <= k
124
+ if k >= j
96
125
return :($ asym[$ (LinearIndices (sa)[k, j])])
97
126
else
98
127
return :($ asym[$ (LinearIndices (sa)[j, k])])
@@ -104,11 +133,23 @@ function uplo_access(sa, asym, k, j, uplo)
104
133
return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
105
134
end
106
135
elseif uplo == :lo_herm
107
- if j <= k
136
+ if k >= j
108
137
return :($ asym[$ (LinearIndices (sa)[k, j])])
109
138
else
110
139
return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
111
140
end
141
+ elseif uplo == :upper_triangular
142
+ if k <= j
143
+ return :($ asym[$ (LinearIndices (sa)[k, j])])
144
+ else
145
+ return :(zero (T))
146
+ end
147
+ elseif uplo == :lower_triangular
148
+ if k >= j
149
+ return :($ asym[$ (LinearIndices (sa)[k, j])])
150
+ else
151
+ return :(zero (T))
152
+ end
112
153
end
113
154
end
114
155
147
188
if sa[2 ] != 0
148
189
retexpr = gen_by_access (a) do access_a
149
190
exprs = mul_smat_vec_exprs (sa, access_a)
150
- return :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
191
+ return :(@inbounds similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
151
192
end
152
193
else
153
194
exprs = [:(zero (T)) for k = 1 : sa[1 ]]
195
236
end
196
237
end
197
238
198
- @generated function _mul (Sa:: Size{sa} , Sb:: Size{sb} , a:: Union{SizedMatrix{T}, MMatrix{T}, MArray{T}} , b:: Union{SizedMatrix{T}, MMatrix{T}, MArray{T}} ) where {sa, sb, T <: BlasFloat }
199
- S = Size (sa[1 ], sb[2 ])
200
-
201
- # Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
202
- if sa[1 ]* sa[2 ]* sb[2 ] >= 14 * 14 * 14
203
- Sa = TSize {size(S),false} ()
204
- Sb = TSize {sa,false} ()
205
- Sc = TSize {sb,false} ()
206
- _add = MulAddMul (true ,false )
207
- return quote
208
- @_inline_meta
209
- C = similar (a, T, $ S)
210
- mul_blas! ($ Sa, C, $ Sa, $ Sb, a, b, $ _add)
211
- return C
212
- end
213
- elseif sa[1 ]* sa[2 ]* sb[2 ] < 8 * 8 * 8
214
- return quote
215
- @_inline_meta
216
- return mul_unrolled (Sa, Sb, a, b)
217
- end
218
- elseif sa[1 ] <= 14 && sa[2 ] <= 14 && sb[2 ] <= 14
219
- return quote
220
- @_inline_meta
221
- return similar_type (a, T, $ S)(mul_unrolled_chunks (Sa, Sb, a, b))
222
- end
223
- else
224
- return quote
225
- @_inline_meta
226
- return mul_loop (Sa, Sb, a, b)
227
- end
228
- end
229
- end
230
-
231
239
@generated function mul_unrolled (:: Size{sa} , :: Size{sb} , a:: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
232
240
if sb[1 ] != sa[2 ]
233
241
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
@@ -240,17 +248,17 @@ end
240
248
exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
241
249
[:($ (uplo_access (sa, :a , k1, j, access_a))* $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
242
250
) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
243
- return :(@inbounds return similar_type (a, T, $ S)(tuple ($ (exprs... ))))
251
+ return :(( mul_result_structure (a, b))( similar_type (a, T, $ S)(tuple ($ (exprs... ) ))))
244
252
end
245
253
else
246
254
exprs = [:(zero (T)) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
247
- retexpr = :(@inbounds return similar_type (a, T, $ S)(tuple ($ (exprs... ))))
255
+ retexpr = :(return ( mul_result_structure (a, b))( similar_type (a, T, $ S)(tuple ($ (exprs... ) ))))
248
256
end
249
257
250
258
return quote
251
259
@_inline_meta
252
260
T = promote_op (matprod,Ta,Tb)
253
- $ retexpr
261
+ @inbounds $ retexpr
254
262
end
255
263
end
256
264
0 commit comments