@@ -64,7 +64,7 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
64
64
65
65
@inline function LinearAlgebra. mul! (dest:: StaticVecOrMatLike{TDest} , A:: StaticVecOrMatLike{TA} ,
66
66
B:: StaticVecOrMatLike{TB} ) where {TDest,TA,TB}
67
- TMul = typeof ( one (TA) * one (TB) + one (TA) * one (TB) )
67
+ TMul = promote_op (matprod, TA, TB )
68
68
return _mul! (TSize (dest), mul_parent (dest), Size (A), Size (B), A, B, NoMulAdd {TMul, TDest} ())
69
69
end
70
70
111
111
112
112
" Obtain an expression for the linear index of var[k,j], taking transposes into account"
113
113
function _lind (var:: Symbol , A:: Type{TSize{sa,tA}} , k:: Int , j:: Int ) where {sa,tA}
114
- return uplo_access (sa, var, k, j, tA)
114
+ ula = uplo_access (sa, var, k, j, tA)
115
+ if ula. head == :call && ula. args[1 ] == :transpose
116
+ # TODO : can this be properly fixed at all?
117
+ return ula. args[2 ]
118
+ end
119
+ return ula
115
120
end
116
121
117
122
126
131
127
132
if sa[2 ] != 0
128
133
assign_expr = gen_by_access (wrapped_a) do access_a
129
- lhs = [:($ (_lind (:c ,Sc,k,col))) for k = 1 : sa[1 ]]
130
- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
131
- [:($ (uplo_access (sa, :a , k, j, access_a)) * b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
134
+ lhs = [_lind (:c ,Sc,k,col) for k = 1 : sa[1 ]]
135
+ ab = [combine_products ([:($ (uplo_access (sa, :a , k, j, access_a)) * b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
132
136
exprs = _muladd_expr (lhs, ab, _add)
133
137
134
138
return :(@inbounds $ (Expr (:block , exprs... )))
@@ -221,13 +225,12 @@ end
221
225
end
222
226
223
227
if sa[2 ] != 0
224
- lhs = [:( $ ( _lind (:c , Sc, k1, k2)) ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
228
+ lhs = [_lind (:c , Sc, k1, k2) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
225
229
226
230
assign_expr = gen_by_access (wrapped_a, wrapped_b) do access_a, access_b
227
231
228
- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
229
- [:($ (uplo_access (sa, :a , k1, j, access_a)) * $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
230
- ))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
232
+ ab = [combine_products ([:($ (uplo_access (sa, :a , k1, j, access_a)) * $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
233
+ ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
231
234
232
235
exprs = _muladd_expr (lhs, ab, _add)
233
236
return :(@inbounds $ (Expr (:block , exprs... )))
246
249
c = mul_parent (wrapped_c)
247
250
a = mul_parent (wrapped_a)
248
251
b = mul_parent (wrapped_b)
252
+ T = promote_op (matprod,Ta,Tb)
249
253
$ assign_expr
250
254
return c
251
255
end
259
263
end
260
264
261
265
# This will not work for Symmetric and Hermitian wrappers of c
262
- lhs = [:( $ ( _lind (:c , Sc, k1, k2)) ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
266
+ lhs = [_lind (:c , Sc, k1, k2) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
263
267
264
268
# vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
265
269
299
303
end
300
304
301
305
if sa[2 ] != 0
302
- exprs = [reduce ((ex1,ex2) -> :( + ( $ ex1, $ ex2)), [:($ (uplo_access (sa, :a , k, j, access_a))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
306
+ exprs = [combine_products ( [:($ (uplo_access (sa, :a , k, j, access_a))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
303
307
else
304
308
exprs = [:(zero (promote_op (matprod,Ta,Tb))) for k = 1 : sa[1 ]]
305
309
end
0 commit comments