@@ -981,6 +981,7 @@ class tinyBLAS_Q0_B16_AVX {
981
981
}
982
982
983
983
#if defined(__AVX512BF16__)
984
+ // Templated functions for gemm of dimensions 4xN
984
985
template <int RN>
985
986
NOINLINE void gemm4xN (int64_t m0, int64_t m, int64_t n0, int64_t n) {
986
987
int64_t ytiles = (m - m0) / 4 ;
@@ -1005,6 +1006,7 @@ class tinyBLAS_Q0_B16_AVX {
1005
1006
__m256i avec3 = load (A + lda * (ii + 3 ) + l);
1006
1007
for (int64_t j = 0 ; j < RN; ++j) {
1007
1008
__m128bh db = m128bh (_mm_set1_epi16 (B[ldb * (jj + j) + l].d ));
1009
+ // Computation of product of delta values for four blocks
1008
1010
__m256 dvec = _mm256_castps128_ps256 (_mm_dpbf16_ps (zerovec, da, db));
1009
1011
dvec = _mm256_permute2f128_ps (dvec ,dvec, 0 );
1010
1012
Cv[j][0 ] = madd (_mm256_shuffle_ps (dvec, dvec, 0 ),
@@ -1056,7 +1058,8 @@ class tinyBLAS_Q0_B16_AVX {
1056
1058
__m256i bvec3 = load (B + ldb * (jj + 3 ) + l);
1057
1059
for (int64_t i = 0 ; i < RM; ++i) {
1058
1060
__m128bh da = m128bh (_mm_set1_epi16 ((A[lda * (ii + i) + l].d )));
1059
- __m256 dvec = _mm256_castps128_ps256 (_mm_dpbf16_ps (zerovec, da, db));
1061
+ // Computation of product of delta values for four blocks
1062
+ __m256 dvec = _mm256_castps128_ps256 (_mm_dpbf16_ps (zerovec, da, db));
1060
1063
dvec = _mm256_permute2f128_ps (dvec ,dvec, 0 );
1061
1064
Cv[0 ][i] = madd (_mm256_shuffle_ps (dvec, dvec, 0 ),
1062
1065
updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
0 commit comments