@@ -24,19 +24,19 @@ Should pair with `parent`.
24
24
"""
25
25
struct TSize{S,T}
    # Sole inner constructor. The type assertions inside `new{...}` require `S` to be
    # a tuple of `StaticDimension`s and `T` to be a `Symbol`; anything else throws.
    TSize{S,T}() where {S,T} = new{S::Tuple{Vararg{StaticDimension}},T::Symbol}()
end
30
# Build a `TSize` from a static-array-like type: the dimensions come from `size(A)`
# and the second parameter from `gen_by_access(identity, A)`.
function TSize(A::Type{<:StaticArrayLike})
    return TSize{size(A), gen_by_access(identity, A)}()
end

# Instances defer to the method that takes their type.
function TSize(A::StaticArrayLike)
    return TSize(typeof(A))
end

# Wrap a `Size`; the second parameter defaults to `:any`.
function TSize(S::Size{s}, T=:any) where s
    return TSize{s,T}()
end

# Numbers are converted to a `Size` first.
function TSize(s::Number)
    return TSize(Size(s))
end
36
# Return `true` only when the second type parameter of the `TSize` is `:transpose`.
function istranspose(::TSize{<:Any,T}) where T
    return T === :transpose
end
37
35
# Return the dimensions tuple stored in the first type parameter.
function size(::TSize{S}) where S
    return S
end

# Return a `Size` carrying the same dimensions tuple.
function Size(::TSize{S}) where S
    return Size{S}()
end
39
# Return the second type parameter of a `TSize` unchanged.
access_type(::TSize{<:Any,T}) where T = T

# Transposing reverses the dimensions tuple and toggles the second type parameter
# between `:transpose` and `:any`. Only `S` is declared in `where`: an extra, unused
# type variable would make Julia warn that the method declares a type variable it
# does not use.
Base.transpose(::TSize{S,:transpose}) where {S} = TSize{reverse(S),:any}()
Base.transpose(::TSize{S,:any}) where {S} = TSize{reverse(S),:transpose}()
40
40
41
41
# Get the parent of transposed arrays, or the array itself if it has no parent
42
42
# Different from Base.parent because we only want to get rid of Transpose and Adjoint
97
97
" Obtain an expression for the linear index of var[k,j], taking transposes into account"
98
98
@inline function _lind(A::Type{<:TSize}, k::Int, j::Int)
    # Use the variable name `a` when the caller does not supply one.
    return _lind(:a, A, k, j)
end

# Delegate to `uplo_access`, forwarding the size tuple `sa` and the second type
# parameter `tA` taken from the `TSize` type.
_lind(var::Symbol, A::Type{TSize{sa,tA}}, k::Int, j::Int) where {sa,tA} =
    uplo_access(sa, var, k, j, tA)
106
102
103
+
104
+
107
105
# Matrix-vector multiplication
108
106
@generated function _mul! (Sc:: TSize{sc} , c:: StaticVecOrMatLike , Sa:: TSize{sa} , Sb:: TSize{sb} ,
109
107
a:: StaticMatrix , b:: StaticVector , _add:: MulAddMul ,
@@ -133,14 +131,21 @@ end
133
131
end
134
132
135
133
# Outer product
136
- @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , :: TSize{sa,false } , :: TSize{sb,true } ,
134
+ @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , :: TSize{sa,:any } , tsb :: Union{ TSize{sb,:transpose},TSize{sb,:adjoint} } ,
137
135
a:: StaticVector , b:: StaticVector , _add:: MulAddMul ) where {sa, sb, sc}
138
136
if sc[1 ] != sa[1 ] || sc[2 ] != sb[2 ]
139
137
throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
140
138
end
141
139
140
+ conjugate_b = isa (tsb, TSize{sb,:adjoint })
141
+
142
142
lhs = [:(c[$ (LinearIndices (sc)[i,j])]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
143
- ab = [:(a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
143
+ if conjugate_b
144
+ ab = [:(a[$ i] * adjoint (b[$ j])) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
145
+ else
146
+ ab = [:(a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
147
+ end
148
+
144
149
exprs = _muladd_expr (lhs, ab, _add)
145
150
146
151
return quote
@@ -267,17 +272,18 @@ end
267
272
# For a `SizedArray`, return its `data` field.
@inline function _get_raw_data(A::SizedArray)
    return A.data
end

# A `StaticArray` is returned as-is.
@inline function _get_raw_data(A::StaticArray)
    return A
end
269
274
270
# Call `BLAS.gemm!` for a destination whose `TSize` has second parameter `:any`.
# Each operand's gemm flag is 'T' when `istranspose` holds for its `TSize`, else 'N'.
function mul_blas!(::TSize{<:Any,:any}, c::StaticMatrix,
        Sa::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}},
        Sb::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}},
        a::StaticMatrix, b::StaticMatrix, _add::MulAddMul)
    T = eltype(a)
    raw_a = _get_raw_data(a)
    raw_b = _get_raw_data(b)
    raw_c = _get_raw_data(c)
    flag_a = istranspose(Sa) ? 'T' : 'N'
    flag_b = istranspose(Sb) ? 'T' : 'N'
    # Both scaling factors are converted to the element type of `a` before the call.
    return BLAS.gemm!(flag_a, flag_b, T(alpha(_add)), raw_a, raw_b, T(beta(_add)), raw_c)
end
279
285
280
286
# When the destination's `TSize` is tagged `:transpose`, transpose the entire
# expression: every `TSize` is transposed and the two operands swap places.
@inline function mul_blas!(Sc::TSize{<:Any,:transpose}, c::StaticMatrix, Sa::TSize, Sb::TSize,
        a::StaticMatrix, b::StaticMatrix, _add::MulAddMul)
    return mul_blas!(transpose(Sc), c, transpose(Sb), transpose(Sa), b, a, _add)
end
0 commit comments