@@ -4,23 +4,129 @@ import LinearAlgebra: BlasFloat, matprod, mul!
4
4
# Manage dispatch of * and mul!
5
5
# TODO Adjoint? (Inner product?)
6
6
7
- @inline * (A:: StaticMatrix , B:: AbstractVector ) = _mul (Size (A), A, B)
8
- @inline * (A:: StaticMatrix , B:: StaticVector ) = _mul (Size (A), Size (B), A, B)
9
- @inline * (A:: StaticMatrix , B:: StaticMatrix ) = _mul (Size (A), Size (B), A, B)
10
- @inline * (A:: StaticVector , B:: StaticMatrix ) = * (reshape (A, Size (Size (A)[1 ], 1 )), B)
7
# Operand types accepted on the matrix side of the multiplication methods in
# this file: a plain `StaticMatrix{s1, s2, T}`, or a `Symmetric` / `Hermitian`
# wrapper whose parent type is such a `StaticMatrix`.
+ const StaticMatMulLike{s1, s2, T} = Union{
8
+ StaticMatrix{s1, s2, T},
9
+ Symmetric{T, <: StaticMatrix{s1, s2, T} },
10
+ Hermitian{T, <: StaticMatrix{s1, s2, T} }}
11
+
12
# `*` entry points. Each method only selects which kernel receives the operands
# and forwards the static `Size`s that the kernels take as arguments.

# Matrix-like operand times a generic `AbstractVector`.
@inline function *(A::StaticMatMulLike, B::AbstractVector)
    return _mul(Size(A), A, B)
end

# Matrix-like operand times a static vector: both sizes are passed along.
@inline function *(A::StaticMatMulLike, B::StaticVector)
    return _mul(Size(A), Size(B), A, B)
end

# Matrix-like operand times matrix-like operand.
@inline function *(A::StaticMatMulLike, B::StaticMatMulLike)
    return _mul(Size(A), Size(B), A, B)
end

# Static vector times matrix-like operand: the vector is reshaped to an N×1
# matrix and the product is re-dispatched.
@inline function *(A::StaticVector, B::StaticMatMulLike)
    column = reshape(A, Size(Size(A)[1], 1))
    return column * B
end

# Static vector times a transposed / adjoint static vector.
@inline function *(A::StaticVector, B::Transpose{<:Any, <:StaticVector})
    return _mul(Size(A), Size(B), A, B)
end
@inline function *(A::StaticVector, B::Adjoint{<:Any, <:StaticVector})
    return _mul(Size(A), Size(B), A, B)
end

# N×1 static matrix times an adjoint / transposed static vector: the matrix is
# flattened with `vec` and the product is re-dispatched.
@inline function *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N}
    return vec(A) * B
end
@inline function *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N}
    return vec(A) * B
end
15
20
21
# One-operand code generation for a plain `StaticMatrix` type: there is no
# `uplo` to branch on, so `expr_gen` is called once with access mode `:any` and
# its result is returned as is. `asym` is accepted but not used by this method.
gen_by_access(expr_gen, a::Type{<:StaticMatrix}, asym = :a) = expr_gen(:any)
24
# One-operand code generation for a `Symmetric` wrapper around a `StaticMatrix`.
#
# `expr_gen` is called once with `:up` and once with `:lo`; the returned
# expression selects between the two results at run time by comparing the
# `uplo` field of the variable named by `asym` (default `:a`) against 'U'.
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :a)
    upper_branch = expr_gen(:up)
    lower_branch = expr_gen(:lo)
    return quote
        if $(asym).uplo == 'U'
            $upper_branch
        else
            $lower_branch
        end
    end
end
33
# One-operand code generation for a `Hermitian` wrapper around a `StaticMatrix`.
#
# `expr_gen` is called once with `:up_herm` and once with `:lo_herm`; the
# returned expression selects between the two results at run time by comparing
# the `uplo` field of the variable named by `asym` (default `:a`) against 'U'.
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :a)
    upper_branch = expr_gen(:up_herm)
    lower_branch = expr_gen(:lo_herm)
    return quote
        if $(asym).uplo == 'U'
            $upper_branch
        else
            $lower_branch
        end
    end
end
42
# Two-operand code generation when both operands are plain `StaticMatrix`
# types: no run-time branching is needed, so `expr_gen` is called once with
# access mode `:any` for each operand.
gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type{<:StaticMatrix}) = expr_gen(:any, :any)
45
# Two-operand code generation when the left operand is a plain `StaticMatrix`
# type and the right operand is any other type.
#
# The left access mode is fixed to `:any`; the right operand is resolved by the
# one-operand `gen_by_access` applied to `b`, using the variable name `:b`.
# The result is wrapped in a `return` inside a quoted block.
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
    rhs_dispatch = gen_by_access(b, :b) do access_b
        expr_gen(:any, access_b)
    end
    return quote
        return $rhs_dispatch
    end
end
54
# Two-operand code generation when the left operand is a `Symmetric` wrapper
# around a `StaticMatrix`.
#
# The returned expression branches once on `a.uplo`: the 'U' branch uses left
# access mode `:up`, the other branch uses `:lo`. In each branch the right
# operand is resolved by the one-operand `gen_by_access` applied to `b`, using
# the variable name `:b`.
#
# The branch is emitted exactly once. Routing it through the one-operand
# `gen_by_access(a)` as well would nest this `uplo` test inside a second,
# identical `uplo` test and duplicate every generated product body.
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type)
    upper_case = gen_by_access(b, :b) do access_b
        expr_gen(:up, access_b)
    end
    lower_case = gen_by_access(b, :b) do access_b
        expr_gen(:lo, access_b)
    end
    return quote
        if a.uplo == 'U'
            return $upper_case
        else
            return $lower_case
        end
    end
end
69
# Two-operand code generation when the left operand is a `Hermitian` wrapper
# around a `StaticMatrix`.
#
# The returned expression branches once on `a.uplo`: the 'U' branch uses left
# access mode `:up_herm`, the other branch uses `:lo_herm`. In each branch the
# right operand is resolved by the one-operand `gen_by_access` applied to `b`,
# using the variable name `:b`.
#
# The branch is emitted exactly once. Routing it through the one-operand
# `gen_by_access(a)` as well would nest this `uplo` test inside a second,
# identical `uplo` test and duplicate every generated product body.
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type)
    upper_case = gen_by_access(b, :b) do access_b
        expr_gen(:up_herm, access_b)
    end
    lower_case = gen_by_access(b, :b) do access_b
        expr_gen(:lo_herm, access_b)
    end
    return quote
        if a.uplo == 'U'
            return $upper_case
        else
            return $lower_case
        end
    end
end
84
+
85
# Build the expression that reads logical element (k, j) of the operand held in
# the variable named `asym`, whose size tuple is `sa`.
#
# `uplo` selects the access mode:
#   :any      - index (k, j) directly.
#   :up       - index (k, j) when k <= j; otherwise read the mirrored element
#               (j, k) and wrap it in `transpose`.
#   :lo       - index (k, j) when j <= k; otherwise read the mirrored element
#               (j, k) and wrap it in `transpose`.
#   :up_herm  - as :up, but the mirrored element is wrapped in `adjoint`.
#   :lo_herm  - as :lo, but the mirrored element is wrapped in `adjoint`.
#
# Indices are emitted as linear indices computed from `LinearIndices(sa)`.
# Throws `ArgumentError` for any other `uplo` value, instead of returning
# `nothing` and letting it be spliced into generated code.
function uplo_access(sa, asym, k, j, uplo)
    # Expression indexing `asym` at the linear position of (row, col).
    element(row, col) = :($asym[$(LinearIndices(sa)[row, col])])

    if uplo == :any
        return element(k, j)
    elseif uplo == :up || uplo == :lo
        in_referenced_triangle = uplo == :up ? k <= j : j <= k
        # The mirrored read is transposed so that non-scalar elements agree
        # with the symmetric structure; for numbers `transpose` is the identity.
        return in_referenced_triangle ? element(k, j) : :(transpose($(element(j, k))))
    elseif uplo == :up_herm || uplo == :lo_herm
        in_referenced_triangle = uplo == :up_herm ? k <= j : j <= k
        return in_referenced_triangle ? element(k, j) : :(adjoint($(element(j, k))))
    else
        throw(ArgumentError("unknown access mode $(repr(uplo)); expected :any, :up, :lo, :up_herm or :lo_herm"))
    end
end
16
114
17
115
# Implementations
18
116
19
- @generated function _mul (:: Size{sa} , a:: StaticMatrix{<:Any, <:Any, Ta} , b:: AbstractVector{Tb} ) where {sa, Ta, Tb}
117
# Build one expression per output row k = 1:sa[1] of a matrix-vector product.
#
# Each row expression is the `+`-reduction over j = 1:sa[2] of
# `<element (k, j) of a> * b[j]`, where the element read is produced by
# `uplo_access` with access mode `access_a`. The variable names `a` and `b` are
# fixed in the generated expressions.
function mul_smat_vec_exprs(sa, access_a)
    add_terms(lhs, rhs) = :(+($lhs, $rhs))
    row_products(k) = [:($(uplo_access(sa, :a, k, j, access_a)) * b[$j]) for j = 1:sa[2]]
    return [reduce(add_terms, row_products(k)) for k = 1:sa[1]]
end
120
+
121
+ @generated function _mul (:: Size{sa} , a:: StaticMatMulLike{<:Any, <:Any, Ta} , b:: AbstractVector{Tb} ) where {sa, Ta, Tb}
20
122
if sa[2 ] != 0
21
- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k, j])]* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
123
+ retexpr = gen_by_access (a) do access_a
124
+ exprs = mul_smat_vec_exprs (sa, access_a)
125
+ return :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
126
+ end
22
127
else
23
128
exprs = [:(zero (T)) for k = 1 : sa[1 ]]
129
+ retexpr = :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
24
130
end
25
131
26
132
return quote
@@ -29,28 +135,33 @@ import LinearAlgebra: BlasFloat, matprod, mul!
29
135
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $(size (b)) " ))
30
136
end
31
137
T = promote_op (matprod,Ta,Tb)
32
- @inbounds return similar_type (b, T, Size (sa[ 1 ]))( tuple ( $ (exprs ... )))
138
+ $ retexpr
33
139
end
34
140
end
35
141
36
- @generated function _mul (:: Size{sa} , :: Size{sb} , a:: StaticMatrix {<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} ) where {sa, sb, Ta, Tb}
142
# Generated kernel: matrix-like operand `a` (static size `sa`) times static
# vector `b` (static size `sb`).
#
# Throws `DimensionMismatch` while generating if `sb[1] != sa[2]`.
# The generated body computes `T = promote_op(matprod, Ta, Tb)` and returns
# `similar_type(b, T, Size(sa[1]))` built from a tuple of `sa[1]` entries.
# NOTE(review): lines starting with `-` / `+` below are removed / added lines of
# the diff this text was taken from; bare numbers are diff line numbers.
+ @generated function _mul (:: Size{sa} , :: Size{sb} , a:: StaticMatMulLike {<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} ) where {sa, sb, Ta, Tb}
37
143
if sb[1 ] != sa[2 ]
38
144
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
39
145
end
40
146
41
147
# Non-empty inner dimension: row expressions come from `mul_smat_vec_exprs`,
# and `gen_by_access` supplies the access mode for the type of `a`.
if sa[2 ] != 0
42
- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k, j])]* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
148
+ retexpr = gen_by_access (a) do access_a
149
+ exprs = mul_smat_vec_exprs (sa, access_a)
150
+ return :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
151
+ end
43
152
else
44
153
# Empty inner dimension: every one of the sa[1] output entries is `zero(T)`.
exprs = [:(zero (T)) for k = 1 : sa[1 ]]
154
+ retexpr = :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
45
155
end
46
156
47
157
return quote
48
158
@_inline_meta
49
159
T = promote_op (matprod,Ta,Tb)
50
- @inbounds return similar_type (b, T, Size (sa[ 1 ]))( tuple ( $ (exprs ... )))
160
+ $ retexpr
51
161
end
52
162
end
53
163
164
+
54
165
# outer product
55
166
@generated function _mul (:: Size{sa} , :: Size{sb} , a:: StaticVector{<: Any, Ta} ,
56
167
b:: Union{Transpose{Tb, <:StaticVector}, Adjoint{Tb, <:StaticVector}} ) where {sa, sb, Ta, Tb}
64
175
end
65
176
end
66
177
67
- @generated function _mul (Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatrix {<:Any, <:Any, Ta} , b:: StaticMatrix {<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
178
+ @generated function _mul (Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatMulLike {<:Any, <:Any, Ta} , b:: StaticMatMulLike {<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
68
179
# Heuristic choice for amount of codegen
69
180
if sa[1 ]* sa[2 ]* sb[2 ] <= 8 * 8 * 8
70
181
return quote
@@ -117,27 +228,32 @@ end
117
228
end
118
229
end
119
230
120
- @generated function mul_unrolled (:: Size{sa} , :: Size{sb} , a:: StaticMatrix {<:Any, <:Any, Ta} , b:: StaticMatrix {<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
231
# Generated kernel: fully unrolled product of two matrix-like operands with
# static sizes `sa` and `sb`.
#
# Throws `DimensionMismatch` while generating if `sb[1] != sa[2]`.
# The generated body computes `T = promote_op(matprod, Ta, Tb)` and returns
# `similar_type(a, T, Size(sa[1], sb[2]))` built from a tuple of
# `sa[1] * sb[2]` entries.
# NOTE(review): lines starting with `-` / `+` below are removed / added lines of
# the diff this text was taken from; bare numbers are diff line numbers.
+ @generated function mul_unrolled (:: Size{sa} , :: Size{sb} , a:: StaticMatMulLike {<:Any, <:Any, Ta} , b:: StaticMatMulLike {<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
121
232
if sb[1 ] != sa[2 ]
122
233
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
123
234
end
124
235
125
236
# Static size of the result.
S = Size (sa[1 ], sb[2 ])
126
237
127
238
# Non-empty inner dimension: each output entry (k1, k2) is the `+`-reduction
# over j of `a` element (k1, j) times `b` element (j, k2); `gen_by_access`
# supplies the access modes passed to `uplo_access` for both operands.
if sa[2 ] != 0
128
- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k1, j])]* b[$ (LinearIndices (sb)[j, k2])]) for j = 1 : sa[2 ]]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
239
+ retexpr = gen_by_access (a, b) do access_a, access_b
240
+ exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
241
+ [:($ (uplo_access (sa, :a , k1, j, access_a))* $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
242
+ ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
243
+ return :(@inbounds return similar_type (a, T, $ S)(tuple ($ (exprs... ))))
244
+ end
129
245
else
130
246
# Empty inner dimension: every output entry is `zero(T)`.
exprs = [:(zero (T)) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
247
+ retexpr = :(@inbounds return similar_type (a, T, $ S)(tuple ($ (exprs... ))))
131
248
end
132
249
133
250
return quote
134
251
@_inline_meta
135
252
T = promote_op (matprod,Ta,Tb)
136
- @inbounds return similar_type (a, T, $ S)( tuple ( $ (exprs ... )))
253
+ $ retexpr
137
254
end
138
255
end
139
256
140
-
141
257
@generated function mul_loop (:: Size{sa} , :: Size{sb} , a:: StaticMatrix{<:Any, <:Any, Ta} , b:: StaticMatrix{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
142
258
if sb[1 ] != sa[2 ]
143
259
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
0 commit comments