Skip to content

Commit e5952ce

Browse files
Mousiustaoye9
andcommitted
Add infrastructure for BGEMM
Setting up all the infrastructure for BGEMM support in OpenBLAS, hopefully I found all the right places. Derived mostly from the previous work done in OpenMathLib#5287 Co-authored-by: Ye Tao <ye.tao@arm.com>
1 parent 36c2589 commit e5952ce

23 files changed

+798
-57
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ test/ZBLAT3.SUMM
8282
test/ZBLAT3_3M.SUMM
8383
test/SHBLAT3.SUMM
8484
test/SBBLAT3.SUMM
85+
test/BBLAT3.SUMM
8586
test/cblat1
8687
test/cblat2
8788
test/cblat3
@@ -96,6 +97,7 @@ test/sblat3
9697
test/sblat3_3m
9798
test/test_shgemm
9899
test/test_sbgemm
100+
test/test_bgemm
99101
test/zblat1
100102
test/zblat2
101103
test/zblat3

CONTRIBUTORS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ In chronological order:
251251
* Ye Tao <ye.tao@arm.com>
252252
* [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1
253253
* [2025-02-27] Add sbgemv_n_neon kernel
254+
* [2025-05-17] Impl prototype of BGEMM inferface
254255

255256
* Abhishek Kumar <https://github.com/abhishek-iitmadras>
256-
* [2025-04-22] Optimise dot kernel for NEOVERSE V1
257+
* [2025-04-22] Optimise dot kernel for NEOVERSE V1

Makefile.system

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,14 @@ SMALL_MATRIX_OPT = 1
276276
endif
277277
ifeq ($(ARCH), arm64)
278278
GEMM_GEMV_FORWARD = 1
279-
GEMM_GEMV_FORWARD_BF16 = 1
279+
SBGEMM_GEMV_FORWARD = 1
280280
endif
281281
ifeq ($(ARCH), riscv)
282282
GEMM_GEMV_FORWARD = 1
283283
endif
284284
ifeq ($(ARCH), power)
285285
GEMM_GEMV_FORWARD = 1
286-
GEMM_GEMV_FORWARD_BF16 = 1
286+
SBGEMM_GEMV_FORWARD = 1
287287
endif
288288

289289
ifeq ($(SMALL_MATRIX_OPT), 1)
@@ -293,8 +293,8 @@ ifneq ($(ONLY_CBLAS), 1)
293293
ifeq ($(GEMM_GEMV_FORWARD), 1)
294294
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
295295
endif
296-
ifeq ($(GEMM_GEMV_FORWARD_BF16), 1)
297-
CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16
296+
ifeq ($(SBGEMM_GEMV_FORWARD), 1)
297+
CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD
298298
endif
299299
endif
300300

@@ -1905,6 +1905,8 @@ export BUILD_HFLOAT16
19051905
export NO_LSX
19061906
export NO_LASX
19071907

1908+
export BGEMM_UNROLL_M
1909+
export BGEMM_UNROLL_N
19081910
export SBGEMM_UNROLL_M
19091911
export SBGEMM_UNROLL_N
19101912
export SHGEMM_UNROLL_M

Makefile.tail

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
BBLASOBJS_P = $(BBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
12
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
23
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
34
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -12,8 +13,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
1213

1314
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
1415

15-
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
16-
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
16+
BLASOBJS = $(SHBLASOBJS) $(BBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
17+
BLASOBJS_P = $(SHBLASPBJS_P) $(BBLASOBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
1718

1819
ifdef EXPRECISION
1920
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -26,6 +27,7 @@ BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
2627
endif
2728

2829
$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
30+
$(BBLASOBJS) $(BBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -USMALL_MATRIX_OPT
2931
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
3032
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
3133
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -36,6 +38,7 @@ $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
3638
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
3739

3840
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
41+
$(BBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3942
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4043
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4144
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)

cblas.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,31 @@
1+
/***************************************************************************
2+
* Copyright (c) 2025, The OpenBLAS Project
3+
* All rights reserved.
4+
* Redistribution and use in source and binary forms, with or without
5+
* modification, are permitted provided that the following conditions are
6+
* met:
7+
* 1. Redistributions of source code must retain the above copyright
8+
* notice, this list of conditions and the following disclaimer.
9+
* 2. Redistributions in binary form must reproduce the above copyright
10+
* notice, this list of conditions and the following disclaimer in
11+
* the documentation and/or other materials provided with the
12+
* distribution.
13+
* 3. Neither the name of the OpenBLAS project nor the names of
14+
* its contributors may be used to endorse or promote products
15+
* derived from this software without specific prior written permission.
16+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26+
* POSSIBILITY OF SUCH DAMAGE.
27+
* *****************************************************************************/
28+
129
#ifndef CBLAS_H
230
#define CBLAS_H
331

@@ -441,6 +469,8 @@ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
441469
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
442470
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);
443471

472+
void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
473+
OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc);
444474
void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
445475
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
446476
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,

cmake/system.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,8 @@ endif ()
425425
if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
426426
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
427427
endif ()
428-
if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS)
429-
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16")
428+
if (SBGEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
429+
set(CCOMMON_OPT "${CCOMMON_OPT} -DSBGEMM_GEMV_FORWARD")
430430
endif ()
431431
if (SMALL_MATRIX_OPT)
432432
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")

common.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*********************************************************************/
22
/* Copyright 2009, 2010 The University of Texas at Austin. */
3+
/* Copyright 2025 The OpenBLAS Project. */
34
/* All rights reserved. */
45
/* */
56
/* Redistribution and use in source and binary forms, with or */
@@ -317,7 +318,11 @@ typedef int blasint;
317318
#elif defined(BFLOAT16)
318319
#define IFLOAT bfloat16
319320
#define XFLOAT IFLOAT
320-
#define FLOAT float
321+
#ifdef BGEMM
322+
#define FLOAT bfloat16
323+
#else
324+
#define FLOAT float
325+
#endif
321326
#define SIZE 2
322327
#define BASE_SHIFT 1
323328
#define ZBASE_SHIFT 2

common_b.h

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/***************************************************************************
2+
* Copyright (c) 2025, The OpenBLAS Project
3+
* All rights reserved.
4+
* Redistribution and use in source and binary forms, with or without
5+
* modification, are permitted provided that the following conditions are
6+
* met:
7+
* 1. Redistributions of source code must retain the above copyright
8+
* notice, this list of conditions and the following disclaimer.
9+
* 2. Redistributions in binary form must reproduce the above copyright
10+
* notice, this list of conditions and the following disclaimer in
11+
* the documentation and/or other materials provided with the
12+
* distribution.
13+
* 3. Neither the name of the OpenBLAS project nor the names of
14+
* its contributors may be used to endorse or promote products
15+
* derived from this software without specific prior written permission.
16+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26+
* POSSIBILITY OF SUCH DAMAGE.
27+
* *****************************************************************************/
28+
29+
#ifndef COMMON_B_H
30+
#define COMMON_B_H
31+
32+
#ifndef DYNAMIC_ARCH
33+
#define BGEMM_ONCOPY bgemm_oncopy
34+
#define BGEMM_OTCOPY bgemm_otcopy
35+
#define BGEMM_INCOPY bgemm_incopy
36+
#define BGEMM_ITCOPY bgemm_itcopy
37+
38+
#define BGEMM_BETA bgemm_beta
39+
#define BGEMM_KERNEL bgemm_kernel
40+
41+
#else
42+
43+
#define BGEMM_ONCOPY gotoblas->bgemm_oncopy
44+
#define BGEMM_OTCOPY gotoblas->bgemm_otcopy
45+
#define BGEMM_INCOPY gotoblas->bgemm_incopy
46+
#define BGEMM_ITCOPY gotoblas->bgemm_itcopy
47+
#define BGEMM_BETA gotoblas->bgemm_beta
48+
#define BGEMM_KERNEL gotoblas->bgemm_kernel
49+
50+
#endif
51+
52+
#define BGEMM_NN bgemm_nn
53+
#define BGEMM_CN bgemm_tn
54+
#define BGEMM_TN bgemm_tn
55+
#define BGEMM_NC bgemm_nt
56+
#define BGEMM_NT bgemm_nt
57+
#define BGEMM_CC bgemm_tt
58+
#define BGEMM_CT bgemm_tt
59+
#define BGEMM_TC bgemm_tt
60+
#define BGEMM_TT bgemm_tt
61+
#define BGEMM_NR bgemm_nn
62+
#define BGEMM_TR bgemm_tn
63+
#define BGEMM_CR bgemm_tn
64+
#define BGEMM_RN bgemm_nn
65+
#define BGEMM_RT bgemm_nt
66+
#define BGEMM_RC bgemm_nt
67+
#define BGEMM_RR bgemm_nn
68+
69+
#define BGEMM_THREAD_NN bgemm_thread_nn
70+
#define BGEMM_THREAD_CN bgemm_thread_tn
71+
#define BGEMM_THREAD_TN bgemm_thread_tn
72+
#define BGEMM_THREAD_NC bgemm_thread_nt
73+
#define BGEMM_THREAD_NT bgemm_thread_nt
74+
#define BGEMM_THREAD_CC bgemm_thread_tt
75+
#define BGEMM_THREAD_CT bgemm_thread_tt
76+
#define BGEMM_THREAD_TC bgemm_thread_tt
77+
#define BGEMM_THREAD_TT bgemm_thread_tt
78+
#define BGEMM_THREAD_NR bgemm_thread_nn
79+
#define BGEMM_THREAD_TR bgemm_thread_tn
80+
#define BGEMM_THREAD_CR bgemm_thread_tn
81+
#define BGEMM_THREAD_RN bgemm_thread_nn
82+
#define BGEMM_THREAD_RT bgemm_thread_nt
83+
#define BGEMM_THREAD_RC bgemm_thread_nt
84+
#define BGEMM_THREAD_RR bgemm_thread_nn
85+
#endif

common_interface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*********************************************************************/
22
/* Copyright 2009, 2010 The University of Texas at Austin. */
3+
/* Copyright 2025 The OpenBLAS Project. */
34
/* All rights reserved. */
45
/* */
56
/* Redistribution and use in source and binary forms, with or */
@@ -483,6 +484,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
483484

484485
void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
485486
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
487+
void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *,
488+
bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *);
486489
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
487490
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
488491
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,

common_level3.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
5656

5757
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
5858
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
59+
int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16,
60+
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
5961
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
6062
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
6163
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
@@ -83,6 +85,10 @@ int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b
8385
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
8486
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
8587
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
88+
int bgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
89+
int bgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
90+
int bgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
91+
int bgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
8692
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
8793
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
8894
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
@@ -511,6 +517,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
511517
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
512518

513519
int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
520+
int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG);
514521
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
515522
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
516523
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
@@ -668,6 +675,11 @@ int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLAS
668675
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
669676
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
670677

678+
int bgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
679+
int bgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
680+
int bgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
681+
int bgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
682+
671683
int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
672684
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
673685
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -770,6 +782,11 @@ int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16
770782
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
771783
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
772784

785+
int bgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
786+
int bgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
787+
int bgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
788+
int bgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
789+
773790
int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
774791
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
775792
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);

common_macro.h

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*********************************************************************/
22
/* Copyright 2009, 2010 The University of Texas at Austin. */
3+
/* Copyright 2025 The OpenBLAS Project. */
34
/* All rights reserved. */
45
/* */
56
/* Redistribution and use in source and binary forms, with or */
@@ -40,6 +41,7 @@
4041
#define COMMON_MACRO
4142

4243
#include "common_sh.h"
44+
#include "common_b.h"
4345
#include "common_sb.h"
4446
#include "common_s.h"
4547
#include "common_d.h"
@@ -702,8 +704,52 @@
702704
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
703705

704706

705-
#elif defined(BFLOAT16)
707+
#elif defined(BFLOAT16) && defined(BGEMM)
708+
#define GEMM_BETA BGEMM_BETA
709+
#define GEMM_KERNEL_N BGEMM_KERNEL
710+
#define GEMM_KERNEL_L BGEMM_KERNEL
711+
#define GEMM_KERNEL_R BGEMM_KERNEL
712+
#define GEMM_KERNEL_B BGEMM_KERNEL
713+
714+
#define GEMM_NN BGEMM_NN
715+
#define GEMM_CN BGEMM_TN
716+
#define GEMM_TN BGEMM_TN
717+
#define GEMM_NC BGEMM_NT
718+
#define GEMM_NT BGEMM_NT
719+
#define GEMM_CC BGEMM_TT
720+
#define GEMM_CT BGEMM_TT
721+
#define GEMM_TC BGEMM_TT
722+
#define GEMM_TT BGEMM_TT
723+
#define GEMM_NR BGEMM_NN
724+
#define GEMM_TR BGEMM_TN
725+
#define GEMM_CR BGEMM_TN
726+
#define GEMM_RN BGEMM_NN
727+
#define GEMM_RT BGEMM_NT
728+
#define GEMM_RC BGEMM_NT
729+
#define GEMM_RR BGEMM_NN
730+
#define GEMM_ONCOPY BGEMM_ONCOPY
731+
#define GEMM_OTCOPY BGEMM_OTCOPY
732+
#define GEMM_INCOPY BGEMM_INCOPY
733+
#define GEMM_ITCOPY BGEMM_ITCOPY
734+
735+
#define GEMM_THREAD_NN BGEMM_THREAD_NN
736+
#define GEMM_THREAD_CN BGEMM_THREAD_TN
737+
#define GEMM_THREAD_TN BGEMM_THREAD_TN
738+
#define GEMM_THREAD_NC BGEMM_THREAD_NT
739+
#define GEMM_THREAD_NT BGEMM_THREAD_NT
740+
#define GEMM_THREAD_CC BGEMM_THREAD_TT
741+
#define GEMM_THREAD_CT BGEMM_THREAD_TT
742+
#define GEMM_THREAD_TC BGEMM_THREAD_TT
743+
#define GEMM_THREAD_TT BGEMM_THREAD_TT
744+
#define GEMM_THREAD_NR BGEMM_THREAD_NN
745+
#define GEMM_THREAD_TR BGEMM_THREAD_TN
746+
#define GEMM_THREAD_CR BGEMM_THREAD_TN
747+
#define GEMM_THREAD_RN BGEMM_THREAD_NN
748+
#define GEMM_THREAD_RT BGEMM_THREAD_NT
749+
#define GEMM_THREAD_RC BGEMM_THREAD_NT
750+
#define GEMM_THREAD_RR BGEMM_THREAD_NN
706751

752+
#elif defined(BFLOAT16)
707753
#define D_TO_BF16_K SBDTOBF16_K
708754
#define D_BF16_TO_K DBF16TOD_K
709755
#define S_TO_BF16_K SBSTOBF16_K
@@ -2663,6 +2709,9 @@
26632709
|| defined(ARCH_LOONGARCH64) || defined(ARCH_E2K) || defined(ARCH_ALPHA))
26642710
extern BLASLONG gemm_offset_a;
26652711
extern BLASLONG gemm_offset_b;
2712+
extern BLASLONG bgemm_p;
2713+
extern BLASLONG bgemm_q;
2714+
extern BLASLONG bgemm_r;
26662715
extern BLASLONG sbgemm_p;
26672716
extern BLASLONG sbgemm_q;
26682717
extern BLASLONG sbgemm_r;

0 commit comments

Comments
 (0)