Skip to content

Add infrastructure for BGEMM #5357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ test/ZBLAT3.SUMM
test/ZBLAT3_3M.SUMM
test/SHBLAT3.SUMM
test/SBBLAT3.SUMM
test/BBLAT3.SUMM
test/cblat1
test/cblat2
test/cblat3
Expand All @@ -96,6 +97,7 @@ test/sblat3
test/sblat3_3m
test/test_shgemm
test/test_sbgemm
test/test_bgemm
test/zblat1
test/zblat2
test/zblat3
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ In chronological order:
* Ye Tao <ye.tao@arm.com>
* [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1
* [2025-02-27] Add sbgemv_n_neon kernel
* [2025-05-17] Impl prototype of BGEMM inferface

* Abhishek Kumar <https://github.com/abhishek-iitmadras>
* [2025-04-22] Optimise dot kernel for NEOVERSE V1
Expand Down
10 changes: 6 additions & 4 deletions Makefile.system
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,14 @@ SMALL_MATRIX_OPT = 1
endif
ifeq ($(ARCH), arm64)
GEMM_GEMV_FORWARD = 1
GEMM_GEMV_FORWARD_BF16 = 1
SBGEMM_GEMV_FORWARD = 1
endif
ifeq ($(ARCH), riscv)
GEMM_GEMV_FORWARD = 1
endif
ifeq ($(ARCH), power)
GEMM_GEMV_FORWARD = 1
GEMM_GEMV_FORWARD_BF16 = 1
SBGEMM_GEMV_FORWARD = 1
endif

ifeq ($(SMALL_MATRIX_OPT), 1)
Expand All @@ -293,8 +293,8 @@ ifneq ($(ONLY_CBLAS), 1)
ifeq ($(GEMM_GEMV_FORWARD), 1)
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
endif
ifeq ($(GEMM_GEMV_FORWARD_BF16), 1)
CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16
ifeq ($(SBGEMM_GEMV_FORWARD), 1)
CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD
endif
endif

Expand Down Expand Up @@ -1905,6 +1905,8 @@ export BUILD_HFLOAT16
export NO_LSX
export NO_LASX

export BGEMM_UNROLL_M
export BGEMM_UNROLL_N
export SBGEMM_UNROLL_M
export SBGEMM_UNROLL_N
export SHGEMM_UNROLL_M
Expand Down
7 changes: 5 additions & 2 deletions Makefile.tail
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
BBLASOBJS_P = $(BBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
Expand All @@ -12,8 +13,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))

HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))

BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
BLASOBJS = $(SHBLASOBJS) $(BBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SHBLASPBJS_P) $(BBLASOBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)

ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
Expand All @@ -26,6 +27,7 @@ BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif

$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
$(BBLASOBJS) $(BBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -USMALL_MATRIX_OPT
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
Expand All @@ -36,6 +38,7 @@ $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX

$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(BBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
Expand Down
30 changes: 30 additions & 0 deletions cblas.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#ifndef CBLAS_H
#define CBLAS_H

Expand Down Expand Up @@ -441,6 +469,8 @@ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);

void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc);
void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
Expand Down
4 changes: 2 additions & 2 deletions cmake/system.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ endif ()
if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
endif ()
if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16")
if (SBGEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DSBGEMM_GEMV_FORWARD")
endif ()
if (SMALL_MATRIX_OPT)
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
Expand Down
7 changes: 6 additions & 1 deletion common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*********************************************************************/
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* Copyright 2025 The OpenBLAS Project. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
Expand Down Expand Up @@ -317,7 +318,11 @@ typedef int blasint;
#elif defined(BFLOAT16)
#define IFLOAT bfloat16
#define XFLOAT IFLOAT
#define FLOAT float
#ifdef BGEMM
#define FLOAT bfloat16
#else
#define FLOAT float
#endif
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
Expand Down
85 changes: 85 additions & 0 deletions common_b.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#ifndef COMMON_B_H
#define COMMON_B_H

#ifndef DYNAMIC_ARCH
#define BGEMM_ONCOPY bgemm_oncopy
#define BGEMM_OTCOPY bgemm_otcopy
#define BGEMM_INCOPY bgemm_incopy
#define BGEMM_ITCOPY bgemm_itcopy

#define BGEMM_BETA bgemm_beta
#define BGEMM_KERNEL bgemm_kernel

#else

#define BGEMM_ONCOPY gotoblas->bgemm_oncopy
#define BGEMM_OTCOPY gotoblas->bgemm_otcopy
#define BGEMM_INCOPY gotoblas->bgemm_incopy
#define BGEMM_ITCOPY gotoblas->bgemm_itcopy
#define BGEMM_BETA gotoblas->bgemm_beta
#define BGEMM_KERNEL gotoblas->bgemm_kernel

#endif

#define BGEMM_NN bgemm_nn
#define BGEMM_CN bgemm_tn
#define BGEMM_TN bgemm_tn
#define BGEMM_NC bgemm_nt
#define BGEMM_NT bgemm_nt
#define BGEMM_CC bgemm_tt
#define BGEMM_CT bgemm_tt
#define BGEMM_TC bgemm_tt
#define BGEMM_TT bgemm_tt
#define BGEMM_NR bgemm_nn
#define BGEMM_TR bgemm_tn
#define BGEMM_CR bgemm_tn
#define BGEMM_RN bgemm_nn
#define BGEMM_RT bgemm_nt
#define BGEMM_RC bgemm_nt
#define BGEMM_RR bgemm_nn

#define BGEMM_THREAD_NN bgemm_thread_nn
#define BGEMM_THREAD_CN bgemm_thread_tn
#define BGEMM_THREAD_TN bgemm_thread_tn
#define BGEMM_THREAD_NC bgemm_thread_nt
#define BGEMM_THREAD_NT bgemm_thread_nt
#define BGEMM_THREAD_CC bgemm_thread_tt
#define BGEMM_THREAD_CT bgemm_thread_tt
#define BGEMM_THREAD_TC bgemm_thread_tt
#define BGEMM_THREAD_TT bgemm_thread_tt
#define BGEMM_THREAD_NR bgemm_thread_nn
#define BGEMM_THREAD_TR bgemm_thread_tn
#define BGEMM_THREAD_CR bgemm_thread_tn
#define BGEMM_THREAD_RN bgemm_thread_nn
#define BGEMM_THREAD_RT bgemm_thread_nt
#define BGEMM_THREAD_RC bgemm_thread_nt
#define BGEMM_THREAD_RR bgemm_thread_nn
#endif
3 changes: 3 additions & 0 deletions common_interface.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*********************************************************************/
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* Copyright 2025 The OpenBLAS Project. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
Expand Down Expand Up @@ -483,6 +484,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint

void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *,
bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *);
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
Expand Down
17 changes: 17 additions & 0 deletions common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
Expand Down Expand Up @@ -83,6 +85,10 @@ int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int bgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int bgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int bgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int bgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
Expand Down Expand Up @@ -511,6 +517,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);

int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
Expand Down Expand Up @@ -668,6 +675,11 @@ int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLAS
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int bgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int bgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int bgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int bgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);

int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
Expand Down Expand Up @@ -770,6 +782,11 @@ int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int bgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int bgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int bgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int bgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);

int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
Expand Down
51 changes: 50 additions & 1 deletion common_macro.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*********************************************************************/
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* Copyright 2025 The OpenBLAS Project. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
Expand Down Expand Up @@ -40,6 +41,7 @@
#define COMMON_MACRO

#include "common_sh.h"
#include "common_b.h"
#include "common_sb.h"
#include "common_s.h"
#include "common_d.h"
Expand Down Expand Up @@ -702,8 +704,52 @@
#define GEMM_THREAD_RR SHGEMM_THREAD_NN


#elif defined(BFLOAT16)
#elif defined(BFLOAT16) && defined(BGEMM)
#define GEMM_BETA BGEMM_BETA
#define GEMM_KERNEL_N BGEMM_KERNEL
#define GEMM_KERNEL_L BGEMM_KERNEL
#define GEMM_KERNEL_R BGEMM_KERNEL
#define GEMM_KERNEL_B BGEMM_KERNEL

#define GEMM_NN BGEMM_NN
#define GEMM_CN BGEMM_TN
#define GEMM_TN BGEMM_TN
#define GEMM_NC BGEMM_NT
#define GEMM_NT BGEMM_NT
#define GEMM_CC BGEMM_TT
#define GEMM_CT BGEMM_TT
#define GEMM_TC BGEMM_TT
#define GEMM_TT BGEMM_TT
#define GEMM_NR BGEMM_NN
#define GEMM_TR BGEMM_TN
#define GEMM_CR BGEMM_TN
#define GEMM_RN BGEMM_NN
#define GEMM_RT BGEMM_NT
#define GEMM_RC BGEMM_NT
#define GEMM_RR BGEMM_NN
#define GEMM_ONCOPY BGEMM_ONCOPY
#define GEMM_OTCOPY BGEMM_OTCOPY
#define GEMM_INCOPY BGEMM_INCOPY
#define GEMM_ITCOPY BGEMM_ITCOPY

#define GEMM_THREAD_NN BGEMM_THREAD_NN
#define GEMM_THREAD_CN BGEMM_THREAD_TN
#define GEMM_THREAD_TN BGEMM_THREAD_TN
#define GEMM_THREAD_NC BGEMM_THREAD_NT
#define GEMM_THREAD_NT BGEMM_THREAD_NT
#define GEMM_THREAD_CC BGEMM_THREAD_TT
#define GEMM_THREAD_CT BGEMM_THREAD_TT
#define GEMM_THREAD_TC BGEMM_THREAD_TT
#define GEMM_THREAD_TT BGEMM_THREAD_TT
#define GEMM_THREAD_NR BGEMM_THREAD_NN
#define GEMM_THREAD_TR BGEMM_THREAD_TN
#define GEMM_THREAD_CR BGEMM_THREAD_TN
#define GEMM_THREAD_RN BGEMM_THREAD_NN
#define GEMM_THREAD_RT BGEMM_THREAD_NT
#define GEMM_THREAD_RC BGEMM_THREAD_NT
#define GEMM_THREAD_RR BGEMM_THREAD_NN

#elif defined(BFLOAT16)
#define D_TO_BF16_K SBDTOBF16_K
#define D_BF16_TO_K DBF16TOD_K
#define S_TO_BF16_K SBSTOBF16_K
Expand Down Expand Up @@ -2663,6 +2709,9 @@
|| defined(ARCH_LOONGARCH64) || defined(ARCH_E2K) || defined(ARCH_ALPHA))
extern BLASLONG gemm_offset_a;
extern BLASLONG gemm_offset_b;
extern BLASLONG bgemm_p;
extern BLASLONG bgemm_q;
extern BLASLONG bgemm_r;
extern BLASLONG sbgemm_p;
extern BLASLONG sbgemm_q;
extern BLASLONG sbgemm_r;
Expand Down
Loading
Loading