|
# `SDiagonal` on the left, `LowerTriangular` wrapper of a static matrix on the right:
# the reported structure is `LowerTriangular`.
mul_result_structure(::SDiagonal, ::LowerTriangular{<:Any, <:StaticMatrix}) =
    LowerTriangular
|
# `UnitUpperTriangular` wrapper of a static matrix on the left, `SDiagonal` on the right:
# the reported structure is plain `UpperTriangular`, not the unit variant
# (presumably because scaling by a diagonal does not keep a unit diagonal).
mul_result_structure(::UnitUpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) =
    UpperTriangular
# `UnitLowerTriangular` wrapper of a static matrix on the left, `SDiagonal` on the right:
# the reported structure is plain `LowerTriangular`, not the unit variant
# (presumably because scaling by a diagonal does not keep a unit diagonal).
mul_result_structure(::UnitLowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) =
    LowerTriangular
# `SDiagonal` on the left, `UnitUpperTriangular` wrapper of a static matrix on the right:
# the reported structure is plain `UpperTriangular`, not the unit variant
# (presumably because scaling by a diagonal does not keep a unit diagonal).
mul_result_structure(::SDiagonal, ::UnitUpperTriangular{<:Any, <:StaticMatrix}) =
    UpperTriangular
# `SDiagonal` on the left, `UnitLowerTriangular` wrapper of a static matrix on the right:
# the reported structure is plain `LowerTriangular`, not the unit variant
# (presumably because scaling by a diagonal does not keep a unit diagonal).
mul_result_structure(::SDiagonal, ::UnitLowerTriangular{<:Any, <:StaticMatrix}) =
    LowerTriangular
# Both operands are `SDiagonal`: the reported structure is `Diagonal`.
mul_result_structure(::SDiagonal, ::SDiagonal) = Diagonal
|
|
319 | 331 | if sa[2] != 0
|
320 | 332 | retexpr = gen_by_access(wrapped_a) do access_a
|
321 | 333 | exprs = mul_smat_vec_exprs(sa, access_a)
|
322 |
| - return :(@inbounds similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) |
| 334 | + return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) |
323 | 335 | end
|
324 | 336 | else
|
325 | 337 | exprs = [:(zero(T)) for k = 1:sa[1]]
|
|
353 | 365 | end
|
354 | 366 |
|
# Map a static array type to the abstract array type with the same element type `T`
# and dimensionality `N`; the size parameter `S` is dropped.
function _unstatic_array(::Type{TSA}) where {S, T, N, TSA<:StaticArray{S,T,N}}
    return AbstractArray{T,N}
end
|
# For every listed wrapper type, define the `_unstatic_array` method for that wrapper
# around a static array: the wrapper and element type `T` are kept, while the wrapped
# array type is widened to any `AbstractArray{T,N}`.
for Wrapper in (
    Adjoint, Transpose, Symmetric, Hermitian,
    LowerTriangular, UpperTriangular,
    UnitUpperTriangular, UnitLowerTriangular,
    Diagonal,
)
    @eval function _unstatic_array(::Type{$Wrapper{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}}
        return $Wrapper{T,<:AbstractArray{T,N}}
    end
end
|
359 | 371 |
|
|
378 | 390 |
|
379 | 391 | @generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
|
380 | 392 | # Heuristic choice for amount of codegen
|
381 |
| - if sa[1]*sa[2]*sb[2] <= 8*8*8 || a <: Diagonal || b <: Diagonal |
| 393 | + a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 2 : 1 |
| 394 | + b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 2 : 1 |
| 395 | + ab_tri_mul = (a == 2 && b == 2) ? 2 : 1 |
| 396 | + if sa[1]*sa[2]*sb[2] <= 8*8*8*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal |
382 | 397 | return quote
|
383 | 398 | @_inline_meta
|
384 | 399 | return mul_unrolled(Sa, Sb, a, b)
|
|
0 commit comments