Skip to content

Commit 3e96bbc

Browse files
committed
optimized multiplication by triangular matrices
1 parent 8dd4523 commit 3e96bbc

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

src/matrix_multiply.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ end
255255
# Implementations
256256

257257
function mul_smat_vec_exprs(sa, access_a)
258-
return [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
258+
return [combine_products([:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
259259
end
260260

261261
@generated function _mul(::Size{sa}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
@@ -326,6 +326,25 @@ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTria
326326
@eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}}
327327
end
328328

329+
function combine_products(expr_list)
330+
filtered = filter(expr_list) do expr
331+
if expr.head != :call || expr.args[1] != :*
332+
error("expected call to *")
333+
end
334+
for arg in expr.args[2:end]
335+
if isa(arg, Expr) && arg.head == :call && arg.args[1] == :zero
336+
return false
337+
end
338+
end
339+
return true
340+
end
341+
if isempty(filtered)
342+
return :(zero(T))
343+
else
344+
return reduce((ex1,ex2) -> :(+($ex1,$ex2)), filtered)
345+
end
346+
end
347+
329348
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
330349
# Heuristic choice for amount of codegen
331350
if sa[1]*sa[2]*sb[2] <= 8*8*8 || !(a <: StaticMatrix) || !(b <: StaticMatrix)
@@ -362,8 +381,7 @@ end
362381

363382
if sa[2] != 0
364383
retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
365-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)),
366-
[:($(uplo_access(sa, :a, k1, j, access_a))*$(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
384+
exprs = [combine_products([:($(uplo_access(sa, :a, k1, j, access_a))*$(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
367385
) for k1 = 1:sa[1], k2 = 1:sb[2]]
368386
return :((mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...)))))
369387
end

src/matrix_multiply_add.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
6464

6565
@inline function LinearAlgebra.mul!(dest::StaticVecOrMatLike{TDest}, A::StaticVecOrMatLike{TA},
6666
B::StaticVecOrMatLike{TB}) where {TDest,TA,TB}
67-
TMul = typeof(one(TA)*one(TB)+one(TA)*one(TB))
67+
TMul = promote_op(matprod, TA, TB)
6868
return _mul!(TSize(dest), mul_parent(dest), Size(A), Size(B), A, B, NoMulAdd{TMul, TDest}())
6969
end
7070

@@ -111,7 +111,12 @@ end
111111

112112
"Obtain an expression for the linear index of var[k,j], taking transposes into account"
113113
function _lind(var::Symbol, A::Type{TSize{sa,tA}}, k::Int, j::Int) where {sa,tA}
114-
return uplo_access(sa, var, k, j, tA)
114+
ula = uplo_access(sa, var, k, j, tA)
115+
if ula.head == :call && ula.args[1] == :transpose
116+
# TODO: can this be properly fixed at all?
117+
return ula.args[2]
118+
end
119+
return ula
115120
end
116121

117122

@@ -126,9 +131,8 @@ end
126131

127132
if sa[2] != 0
128133
assign_expr = gen_by_access(wrapped_a) do access_a
129-
lhs = [:($(_lind(:c,Sc,k,col))) for k = 1:sa[1]]
130-
ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
131-
[:($(uplo_access(sa, :a, k, j, access_a)) * b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]]
134+
lhs = [_lind(:c,Sc,k,col) for k = 1:sa[1]]
135+
ab = [combine_products([:($(uplo_access(sa, :a, k, j, access_a)) * b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
132136
exprs = _muladd_expr(lhs, ab, _add)
133137

134138
return :(@inbounds $(Expr(:block, exprs...)))
@@ -221,13 +225,12 @@ end
221225
end
222226

223227
if sa[2] != 0
224-
lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]]
228+
lhs = [_lind(:c, Sc, k1, k2) for k1 = 1:sa[1], k2 = 1:sb[2]]
225229

226230
assign_expr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
227231

228-
ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
229-
[:($(uplo_access(sa, :a, k1, j, access_a)) * $(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
230-
))) for k1 = 1:sa[1], k2 = 1:sb[2]]
232+
ab = [combine_products([:($(uplo_access(sa, :a, k1, j, access_a)) * $(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
233+
) for k1 = 1:sa[1], k2 = 1:sb[2]]
231234

232235
exprs = _muladd_expr(lhs, ab, _add)
233236
return :(@inbounds $(Expr(:block, exprs...)))
@@ -246,6 +249,7 @@ end
246249
c = mul_parent(wrapped_c)
247250
a = mul_parent(wrapped_a)
248251
b = mul_parent(wrapped_b)
252+
T = promote_op(matprod,Ta,Tb)
249253
$assign_expr
250254
return c
251255
end
@@ -259,7 +263,7 @@ end
259263
end
260264

261265
# This will not work for Symmetric and Hermitian wrappers of c
262-
lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]]
266+
lhs = [_lind(:c, Sc, k1, k2) for k1 = 1:sa[1], k2 = 1:sb[2]]
263267

264268
#vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
265269

@@ -299,7 +303,7 @@ end
299303
end
300304

301305
if sa[2] != 0
302-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
306+
exprs = [combine_products([:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
303307
else
304308
exprs = [:(zero(promote_op(matprod,Ta,Tb))) for k = 1:sa[1]]
305309
end

0 commit comments

Comments
 (0)