Skip to content

Commit 0a84e1d

Browse files
authored
Inline generic_matmatmul! branch in strided triangular matmul (#1262)
For combinations of strided matrices and strided triangular matrices, we would end up taking the methods defined in `LinearAlgebra`, so we may avoid the constant-propagation and hardcode the `_generic_matmatmul!` call. This improves TTFX, as the no-op but expensive-to-compile `wrap` call is elided. ```julia julia> using LinearAlgebra julia> A = zeros(4,4); julia> @time A * UpperTriangular(A); 0.458913 seconds (1.22 M allocations: 59.769 MiB, 51.63% gc time, 97.84% compilation time: 4% of which was recompilation) # master 0.077198 seconds (174.52 k allocations: 8.683 MiB, 92.75% compilation time) # this PR ```
1 parent a3c2681 commit 0a84e1d

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/triangular.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ const LowerOrUnitLowerTriangular{T,S<:AbstractMatrix{T}} = Union{LowerTriangular
144144
const UpperOrLowerTriangular{T,S<:AbstractMatrix{T}} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}
145145
const UnitUpperOrUnitLowerTriangular{T,S<:AbstractMatrix{T}} = Union{UnitUpperTriangular{T,S}, UnitLowerTriangular{T,S}}
146146

147+
const UpperOrLowerTriangularStrided{T,S<:StridedMatrix{T}} = UpperOrLowerTriangular{T,S}
148+
147149
uppertriangular(M) = UpperTriangular(M)
148150
lowertriangular(M) = LowerTriangular(M)
149151

@@ -1155,11 +1157,20 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix),
11551157
if isone(alpha) && iszero(beta)
11561158
return _trimul!(C, A, B)
11571159
else
1158-
return generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta)
1160+
return generic_matmatmul_NN!(C, A, B, alpha, beta)
11591161
end
11601162
end
11611163
end
11621164

1165+
generic_matmatmul_NN!(C, A, B, alpha, beta) = generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta)
1166+
# Optimization for strided matrices, where we know that _generic_matmatmul! will be taken
1167+
for (TA, TB) in ((:UpperOrLowerTriangularStrided, :StridedMatrix),
1168+
(:StridedMatrix, :UpperOrLowerTriangularStrided),
1169+
(:UpperOrLowerTriangularStrided, :UpperOrLowerTriangularStrided)
1170+
)
1171+
@eval generic_matmatmul_NN!(C, A::$TA, B::$TB, alpha, beta) = _generic_matmatmul!(C, A, B, alpha, beta)
1172+
end
1173+
11631174
ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = _ldiv!(C, A, B)
11641175
# generic fallback for AbstractTriangular, directs to 2-arg [l/r]div!
11651176
_ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) =

0 commit comments

Comments
 (0)