Skip to content

Commit 0ff9b55

Browse files
committed
slight adjustments to matrix multiplication
1 parent e6667bf commit 0ff9b55

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/matrix_multiply.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,18 @@ end
186186
function mul_result_structure(::SDiagonal, ::LowerTriangular{<:Any, <:StaticMatrix})
187187
return LowerTriangular
188188
end
189+
function mul_result_structure(::UnitUpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal)
190+
return UpperTriangular
191+
end
192+
function mul_result_structure(::UnitLowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal)
193+
return LowerTriangular
194+
end
195+
function mul_result_structure(::SDiagonal, ::UnitUpperTriangular{<:Any, <:StaticMatrix})
196+
return UpperTriangular
197+
end
198+
function mul_result_structure(::SDiagonal, ::UnitLowerTriangular{<:Any, <:StaticMatrix})
199+
return LowerTriangular
200+
end
189201
function mul_result_structure(::SDiagonal, ::SDiagonal)
190202
return Diagonal
191203
end
@@ -319,7 +331,7 @@ end
319331
if sa[2] != 0
320332
retexpr = gen_by_access(wrapped_a) do access_a
321333
exprs = mul_smat_vec_exprs(sa, access_a)
322-
return :(@inbounds similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
334+
return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
323335
end
324336
else
325337
exprs = [:(zero(T)) for k = 1:sa[1]]
@@ -353,7 +365,7 @@ end
353365
end
354366

355367
_unstatic_array(::Type{TSA}) where {S, T, N, TSA<:StaticArray{S,T,N}} = AbstractArray{T,N}
356-
for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular]
368+
for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular, UnitUpperTriangular, UnitLowerTriangular, Diagonal]
357369
@eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}}
358370
end
359371

@@ -378,7 +390,10 @@ end
378390

379391
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
380392
# Heuristic choice for amount of codegen
381-
if sa[1]*sa[2]*sb[2] <= 8*8*8 || a <: Diagonal || b <: Diagonal
393+
a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 2 : 1
394+
b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 2 : 1
395+
ab_tri_mul = (a == 2 && b == 2) ? 2 : 1
396+
if sa[1]*sa[2]*sb[2] <= 8*8*8*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal
382397
return quote
383398
@_inline_meta
384399
return mul_unrolled(Sa, Sb, a, b)

src/matrix_multiply_add.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,10 @@ end
189189
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat && a <: Union{StaticMatrix,Transpose} && b <: Union{StaticMatrix,Transpose}
190190

191191
mult_dim = multiplied_dimension(a,b)
192-
if mult_dim < 4*4*4 || a <: Diagonal || b <: Diagonal
192+
a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 2 : 1
193+
b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 2 : 1
194+
ab_tri_mul = (a == 2 && b == 2) ? 2 : 1
195+
if mult_dim < 4*4*4*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal
193196
return quote
194197
@_inline_meta
195198
muladd_unrolled_all!(Sc, c, Sa, Sb, a, b, _add)

0 commit comments

Comments
 (0)