@@ -337,7 +337,7 @@ julia> lmul!(F.Q, B)
337
337
lmul! (A, B)
338
338
339
339
# THE one big BLAS dispatch
340
- @inline function generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
340
+ Base . @constprop :aggressive function generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
341
341
_add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat }
342
342
if all (in ((' N' , ' T' , ' C' )), (tA, tB))
343
343
if tA == ' T' && tB == ' N' && A === B
@@ -364,16 +364,16 @@ lmul!(A, B)
364
364
return BLAS. hemm! (' R' , tB == ' H' ? ' U' : ' L' , alpha, B, A, beta, C)
365
365
end
366
366
end
367
- return _generic_matmatmul! (C, ' N ' , ' N ' , wrap (A, tA), wrap (B, tB), _add)
367
+ return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
368
368
end
369
369
370
370
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
# When both flags are among 'N', 'T', 'C' the call is forwarded to `gemm_wrapper!`;
# for any other flag the operands are wrapped via `wrap` and handed to `_generic_matmatmul!`.
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    blas_flags = ('N', 'T', 'C')
    if tA in blas_flags && tB in blas_flags
        return gemm_wrapper!(C, tA, tB, A, B, _add)
    end
    return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
379
379
@@ -563,11 +563,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
563
563
if all (in ((' N' , ' T' , ' C' )), (tA, tB))
564
564
gemm_wrapper! (C, tA, tB, A, B)
565
565
else
566
- _generic_matmatmul! (C, ' N ' , ' N ' , wrap (A, tA), wrap (B, tB), _add)
566
+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
567
567
end
568
568
end
569
569
570
- function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
570
+ Base . @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
571
571
A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
572
572
_add = MulAddMul ()) where {T<: BlasFloat }
573
573
mA, nA = lapack_size (tA, A)
@@ -604,10 +604,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
604
604
stride (C, 2 ) >= size (C, 1 ))
605
605
return BLAS. gemm! (tA, tB, alpha, A, B, beta, C)
606
606
end
607
- _generic_matmatmul! (C, tA, tB, A, B , _add)
607
+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
608
608
end
609
609
610
- function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
610
+ Base . @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
611
611
A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
612
612
_add = MulAddMul ()) where {T<: BlasReal }
613
613
mA, nA = lapack_size (tA, A)
@@ -647,7 +647,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
647
647
BLAS. gemm! (tA, tB, alpha, reinterpret (T, A), B, beta, reinterpret (T, C))
648
648
return C
649
649
end
650
- _generic_matmatmul! (C, tA, tB, A, B , _add)
650
+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
651
651
end
652
652
653
653
# blas.jl defines matmul for floats; other integer and mixed precision
@@ -764,197 +764,65 @@ end
764
764
765
765
# Buffer-size budget from which the tiled multiplication derived its tile dimensions.
# NOTE(review): the rewritten `_generic_matmatmul!` no longer reads this constant;
# confirm whether any other code still depends on it.
const tilebufsize = 10800 # Approximately 32k/3
766
766
767
# Generic fallback for arbitrary `AbstractVecOrMat` operands: wrap `A` and `B`
# according to their flags `tA` / `tB` and delegate to `_generic_matmatmul!`,
# which works on the wrapped operands and takes no flags itself.
Base.@constprop :aggressive function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
    wrappedA = wrap(A, tA)
    wrappedB = wrap(B, tB)
    return _generic_matmatmul!(C, wrappedA, wrappedB, _add)
end
785
769
786
# Generic (non-BLAS) matrix-matrix multiplication kernel.
#
# Arguments:
#   C    - destination; the product is accumulated into it.
#   A, B - operands, used exactly as given (callers in this file pass `wrap(A, tA)`
#          and `wrap(B, tB)`, so no transpose flags are needed here).
#   _add - `MulAddMul` carrying the `alpha` / `beta` coefficients.
#
# Returns `C` (the early exit of the fallback branch returns whatever
# `_rmul_or_fill!(C, _add.beta)` returns).
#
# Throws `DimensionMismatch` when the axes of `A`, `B` and `C` are not conformable.
# Axes, not sizes, are compared, and every loop iterates over those axes.
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
                             _add::MulAddMul) where {T,S,R}
    AxM = axes(A, 1)
    AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
    BxK = axes(B, 1)
    BxN = axes(B, 2)
    CxM = axes(C, 1)
    CxN = axes(C, 2)
    if AxM != CxM
        throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
    end
    if AxK != BxK
        # Report B's own column axis: `BxN == CxN` has not been verified yet here,
        # so printing `CxN` could show axes that B does not actually have.
        throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$BxN)"))
    end
    if BxN != CxN
        throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
    end
    # Both fast paths are restricted to small plain-data result element types.
    small_bits_result = isbitstype(R) && sizeof(R) <= 16
    if small_bits_result && !(A isa Adjoint || A isa Transpose)
        # Apply `beta` to C first, then accumulate column by column so that the
        # innermost loop runs over the first index of both `A` and `C`.
        _rmul_or_fill!(C, _add.beta)
        (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
        @inbounds for n in BxN, k in BxK
            Balpha = B[k,n]*_add.alpha
            @simd for m in AxM
                C[m,n] = muladd(A[m,k], Balpha, C[m,n])
            end
        end
    elseif small_bits_result && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
        # Both operands carry the same wrapper `t`: compute one row of C at a time
        # from the parents via a matrix-vector product and apply `t` elementwise.
        _rmul_or_fill!(C, _add.beta)
        (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
        t = wrapperop(A)
        pB = parent(B)
        pA = parent(A)
        tmp = similar(C, CxN)
        ci = first(CxM)
        ta = t(_add.alpha)
        for i in AxM
            mul!(tmp, pB, view(pA, :, i))
            C[ci,:] .+= t.(ta .* tmp)
            ci += 1
        end
    else
        # Naive triple loop for every remaining combination of element types and wrappers.
        if iszero(_add.alpha) || isempty(A) || isempty(B)
            return _rmul_or_fill!(C, _add.beta)
        end
        a1 = first(AxK)
        b1 = first(BxK)
        @inbounds for i in AxM, j in BxN
            # Build a zero of the accumulator type from an actual element product.
            z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j])
            Ctmp = convert(promote_type(R, typeof(z2)), z2)
            @simd for k in AxK
                Ctmp = muladd(A[i, k], B[k, j], Ctmp)
            end
            _modify!(_add, Ctmp, C, (i,j))
        end
    end
    return C
end
959
827
960
828
@@ -963,7 +831,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
963
831
matmul2x2! (similar (B, promote_op (matprod, T, S), 2 , 2 ), tA, tB, A, B)
964
832
end
965
833
966
- function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
834
+ Base . @constprop :aggressive function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
967
835
_add:: MulAddMul = MulAddMul ())
968
836
require_one_based_indexing (C, A, B)
969
837
if ! (size (A) == size (B) == size (C) == (2 ,2 ))
@@ -1030,7 +898,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
1030
898
matmul3x3! (similar (B, promote_op (matprod, T, S), 3 , 3 ), tA, tB, A, B)
1031
899
end
1032
900
1033
- function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
901
+ Base . @constprop :aggressive function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
1034
902
_add:: MulAddMul = MulAddMul ())
1035
903
require_one_based_indexing (C, A, B)
1036
904
if ! (size (A) == size (B) == size (C) == (3 ,3 ))
0 commit comments