Skip to content

Commit fd15da0

Browse files
Reduce compile time for generic matmatmul (#52038)
This is another attempt at improving the compile time issue with generic matmatmul, hopefully improving runtime performance also. @chriselrod @jishnub There seems to be a little typo/oversight somewhere, but it shows how it could work. Locally, this reduces benchmark times from JuliaLang/julia#51812 (comment) by more than 50%. --------- Co-authored-by: Chris Elrod <elrodc@gmail.com>
1 parent ddd98c7 commit fd15da0

File tree

3 files changed

+86
-211
lines changed

3 files changed

+86
-211
lines changed

src/adjtrans.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ adjoint(A::Adjoint) = A.parent
281281
transpose(A::Transpose) = A.parent
282282
adjoint(A::Transpose{<:Real}) = A.parent
283283
transpose(A::Adjoint{<:Real}) = A.parent
284+
adjoint(A::Transpose{<:Any,<:Adjoint}) = transpose(A.parent.parent)
285+
transpose(A::Adjoint{<:Any,<:Transpose}) = adjoint(A.parent.parent)
286+
# disambiguation
287+
adjoint(A::Transpose{<:Real,<:Adjoint}) = transpose(A.parent.parent)
288+
transpose(A::Adjoint{<:Real,<:Transpose}) = A.parent
284289

285290
# printing
286291
function Base.showarg(io::IO, v::Adjoint, toplevel)
@@ -395,11 +400,16 @@ map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)...
395400
map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...))
396401
quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below
397402
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
403+
quasiparentc(x) = parent(parent(x)); quasiparentc(x::Number) = conj(x) # to handle numbers in the defs below
398404
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
399405
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
400406
# Hack to preserve behavior after #32122; this needs to be done with a broadcast style instead to support dotted fusion
401407
Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
402408
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
409+
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Transpose{<:Any,<:AdjointAbsVec}}...) =
410+
transpose(adjoint(broadcast((xs...) -> adjoint(transpose(f(conj.(xs)...))), quasiparentc.(tvs)...)))
411+
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Adjoint{<:Any,<:TransposeAbsVec}}...) =
412+
adjoint(transpose(broadcast((xs...) -> transpose(adjoint(f(conj.(xs)...))), quasiparentc.(tvs)...)))
403413
# TODO unify and allow mixed combinations with a broadcast style
404414

405415

src/matmul.jl

Lines changed: 62 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ julia> lmul!(F.Q, B)
337337
lmul!(A, B)
338338

339339
# THE one big BLAS dispatch
340-
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
340+
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
341341
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
342342
if all(in(('N', 'T', 'C')), (tA, tB))
343343
if tA == 'T' && tB == 'N' && A === B
@@ -364,16 +364,16 @@ lmul!(A, B)
364364
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
365365
end
366366
end
367-
return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
367+
return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
368368
end
369369

370370
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
371-
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
371+
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
372372
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
373373
if all(in(('N', 'T', 'C')), (tA, tB))
374374
gemm_wrapper!(C, tA, tB, A, B, _add)
375375
else
376-
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
376+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
377377
end
378378
end
379379

@@ -563,11 +563,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
563563
if all(in(('N', 'T', 'C')), (tA, tB))
564564
gemm_wrapper!(C, tA, tB, A, B)
565565
else
566-
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
566+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
567567
end
568568
end
569569

570-
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
570+
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
571571
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
572572
_add = MulAddMul()) where {T<:BlasFloat}
573573
mA, nA = lapack_size(tA, A)
@@ -604,10 +604,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
604604
stride(C, 2) >= size(C, 1))
605605
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
606606
end
607-
_generic_matmatmul!(C, tA, tB, A, B, _add)
607+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
608608
end
609609

610-
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
610+
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
611611
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
612612
_add = MulAddMul()) where {T<:BlasReal}
613613
mA, nA = lapack_size(tA, A)
@@ -647,7 +647,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
647647
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
648648
return C
649649
end
650-
_generic_matmatmul!(C, tA, tB, A, B, _add)
650+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
651651
end
652652

653653
# blas.jl defines matmul for floats; other integer and mixed precision
@@ -764,197 +764,65 @@ end
764764

765765
const tilebufsize = 10800 # Approximately 32k/3
766766

767-
function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
768-
mA, nA = lapack_size(tA, A)
769-
mB, nB = lapack_size(tB, B)
770-
mC, nC = size(C)
771-
772-
if iszero(_add.alpha)
773-
return _rmul_or_fill!(C, _add.beta)
774-
end
775-
if mA == nA == mB == nB == mC == nC == 2
776-
return matmul2x2!(C, tA, tB, A, B, _add)
777-
end
778-
if mA == nA == mB == nB == mC == nC == 3
779-
return matmul3x3!(C, tA, tB, A, B, _add)
780-
end
781-
A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
782-
B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
783-
_generic_matmatmul!(C, tA, tB, A, B, _add)
784-
end
767+
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) =
768+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
785769

786-
function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
770+
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
787771
_add::MulAddMul) where {T,S,R}
788-
@assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
789-
require_one_based_indexing(C, A, B)
790-
791-
mA, nA = lapack_size(tA, A)
792-
mB, nB = lapack_size(tB, B)
793-
if mB != nA
794-
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
795-
end
796-
if size(C,1) != mA || size(C,2) != nB
797-
throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))
798-
end
799-
800-
if iszero(_add.alpha) || isempty(A) || isempty(B)
801-
return _rmul_or_fill!(C, _add.beta)
802-
end
803-
804-
tile_size = 0
805-
if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N')
806-
tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))
807-
end
808-
@inbounds begin
809-
if tile_size > 0
810-
sz = (tile_size, tile_size)
811-
Atile = Array{T}(undef, sz)
812-
Btile = Array{S}(undef, sz)
813-
814-
z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
815-
z = convert(promote_type(typeof(z1), R), z1)
816-
817-
if mA < tile_size && nA < tile_size && nB < tile_size
818-
copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
819-
copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
820-
for j = 1:nB
821-
boff = (j-1)*tile_size
822-
for i = 1:mA
823-
aoff = (i-1)*tile_size
824-
s = z
825-
for k = 1:nA
826-
s += Atile[aoff+k] * Btile[boff+k]
827-
end
828-
_modify!(_add, s, C, (i,j))
829-
end
830-
end
831-
else
832-
Ctile = Array{R}(undef, sz)
833-
for jb = 1:tile_size:nB
834-
jlim = min(jb+tile_size-1,nB)
835-
jlen = jlim-jb+1
836-
for ib = 1:tile_size:mA
837-
ilim = min(ib+tile_size-1,mA)
838-
ilen = ilim-ib+1
839-
fill!(Ctile, z)
840-
for kb = 1:tile_size:nA
841-
klim = min(kb+tile_size-1,mB)
842-
klen = klim-kb+1
843-
copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
844-
copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
845-
for j=1:jlen
846-
bcoff = (j-1)*tile_size
847-
for i = 1:ilen
848-
aoff = (i-1)*tile_size
849-
s = z
850-
for k = 1:klen
851-
s += Atile[aoff+k] * Btile[bcoff+k]
852-
end
853-
Ctile[bcoff+i] += s
854-
end
855-
end
856-
end
857-
if isone(_add.alpha) && iszero(_add.beta)
858-
copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
859-
else
860-
C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
861-
end
862-
end
772+
AxM = axes(A, 1)
773+
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
774+
BxK = axes(B, 1)
775+
BxN = axes(B, 2)
776+
CxM = axes(C, 1)
777+
CxN = axes(C, 2)
778+
if AxM != CxM
779+
throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
780+
end
781+
if AxK != BxK
782+
throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)"))
783+
end
784+
if BxN != CxN
785+
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
786+
end
787+
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
788+
_rmul_or_fill!(C, _add.beta)
789+
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
790+
@inbounds for n in BxN, k in BxK
791+
Balpha = B[k,n]*_add.alpha
792+
@simd for m in AxM
793+
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
863794
end
864795
end
796+
elseif isbitstype(R) && sizeof(R) 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
797+
_rmul_or_fill!(C, _add.beta)
798+
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
799+
t = wrapperop(A)
800+
pB = parent(B)
801+
pA = parent(A)
802+
tmp = similar(C, CxN)
803+
ci = first(CxM)
804+
ta = t(_add.alpha)
805+
for i in AxM
806+
mul!(tmp, pB, view(pA, :, i))
807+
C[ci,:] .+= t.(ta .* tmp)
808+
ci += 1
809+
end
865810
else
866-
# Multiplication for non-plain-data uses the naive algorithm
867-
if tA == 'N'
868-
if tB == 'N'
869-
for i = 1:mA, j = 1:nB
870-
z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
871-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
872-
for k = 1:nA
873-
Ctmp += A[i, k]*B[k, j]
874-
end
875-
_modify!(_add, Ctmp, C, (i,j))
876-
end
877-
elseif tB == 'T'
878-
for i = 1:mA, j = 1:nB
879-
z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1]))
880-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
881-
for k = 1:nA
882-
Ctmp += A[i, k] * transpose(B[j, k])
883-
end
884-
_modify!(_add, Ctmp, C, (i,j))
885-
end
886-
else
887-
for i = 1:mA, j = 1:nB
888-
z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]')
889-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
890-
for k = 1:nA
891-
Ctmp += A[i, k]*B[j, k]'
892-
end
893-
_modify!(_add, Ctmp, C, (i,j))
894-
end
895-
end
896-
elseif tA == 'T'
897-
if tB == 'N'
898-
for i = 1:mA, j = 1:nB
899-
z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j])
900-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
901-
for k = 1:nA
902-
Ctmp += transpose(A[k, i]) * B[k, j]
903-
end
904-
_modify!(_add, Ctmp, C, (i,j))
905-
end
906-
elseif tB == 'T'
907-
for i = 1:mA, j = 1:nB
908-
z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1]))
909-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
910-
for k = 1:nA
911-
Ctmp += transpose(A[k, i]) * transpose(B[j, k])
912-
end
913-
_modify!(_add, Ctmp, C, (i,j))
914-
end
915-
else
916-
for i = 1:mA, j = 1:nB
917-
z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]')
918-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
919-
for k = 1:nA
920-
Ctmp += transpose(A[k, i]) * adjoint(B[j, k])
921-
end
922-
_modify!(_add, Ctmp, C, (i,j))
923-
end
924-
end
925-
else
926-
if tB == 'N'
927-
for i = 1:mA, j = 1:nB
928-
z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j])
929-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
930-
for k = 1:nA
931-
Ctmp += A[k, i]'B[k, j]
932-
end
933-
_modify!(_add, Ctmp, C, (i,j))
934-
end
935-
elseif tB == 'T'
936-
for i = 1:mA, j = 1:nB
937-
z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1]))
938-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
939-
for k = 1:nA
940-
Ctmp += adjoint(A[k, i]) * transpose(B[j, k])
941-
end
942-
_modify!(_add, Ctmp, C, (i,j))
943-
end
944-
else
945-
for i = 1:mA, j = 1:nB
946-
z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]')
947-
Ctmp = convert(promote_type(R, typeof(z2)), z2)
948-
for k = 1:nA
949-
Ctmp += A[k, i]'B[j, k]'
950-
end
951-
_modify!(_add, Ctmp, C, (i,j))
952-
end
811+
if iszero(_add.alpha) || isempty(A) || isempty(B)
812+
return _rmul_or_fill!(C, _add.beta)
813+
end
814+
a1 = first(AxK)
815+
b1 = first(BxK)
816+
@inbounds for i in AxM, j in BxN
817+
z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j])
818+
Ctmp = convert(promote_type(R, typeof(z2)), z2)
819+
@simd for k in AxK
820+
Ctmp = muladd(A[i, k], B[k, j], Ctmp)
953821
end
822+
_modify!(_add, Ctmp, C, (i,j))
954823
end
955824
end
956-
end # @inbounds
957-
C
825+
return C
958826
end
959827

960828

@@ -963,7 +831,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
963831
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
964832
end
965833

966-
function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
834+
Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
967835
_add::MulAddMul = MulAddMul())
968836
require_one_based_indexing(C, A, B)
969837
if !(size(A) == size(B) == size(C) == (2,2))
@@ -1030,7 +898,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
1030898
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
1031899
end
1032900

1033-
function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
901+
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
1034902
_add::MulAddMul = MulAddMul())
1035903
require_one_based_indexing(C, A, B)
1036904
if !(size(A) == size(B) == size(C) == (3,3))

0 commit comments

Comments
 (0)