diff --git a/.gitignore b/.gitignore index d14d438a25..c0885d4662 100644 --- a/.gitignore +++ b/.gitignore @@ -82,6 +82,7 @@ test/ZBLAT3.SUMM test/ZBLAT3_3M.SUMM test/SHBLAT3.SUMM test/SBBLAT3.SUMM +test/BBLAT3.SUMM test/cblat1 test/cblat2 test/cblat3 @@ -96,6 +97,7 @@ test/sblat3 test/sblat3_3m test/test_shgemm test/test_sbgemm +test/test_bgemm test/zblat1 test/zblat2 test/zblat3 diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index d26c3d534e..8391862b11 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -251,6 +251,7 @@ In chronological order: * Ye Tao * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 * [2025-02-27] Add sbgemv_n_neon kernel + * [2025-05-17] Impl prototype of BGEMM interface * Abhishek Kumar * [2025-04-22] Optimise dot kernel for NEOVERSE V1 diff --git a/Makefile.system b/Makefile.system index bde3014cc4..9e20c132ce 100644 --- a/Makefile.system +++ b/Makefile.system @@ -276,14 +276,14 @@ SMALL_MATRIX_OPT = 1 endif ifeq ($(ARCH), arm64) GEMM_GEMV_FORWARD = 1 -GEMM_GEMV_FORWARD_BF16 = 1 +SBGEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), riscv) GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 -GEMM_GEMV_FORWARD_BF16 = 1 +SBGEMM_GEMV_FORWARD = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) @@ -293,8 +293,8 @@ ifneq ($(ONLY_CBLAS), 1) ifeq ($(GEMM_GEMV_FORWARD), 1) CCOMMON_OPT += -DGEMM_GEMV_FORWARD endif -ifeq ($(GEMM_GEMV_FORWARD_BF16), 1) -CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16 +ifeq ($(SBGEMM_GEMV_FORWARD), 1) +CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD endif endif @@ -1905,6 +1905,8 @@ export BUILD_HFLOAT16 export NO_LSX export NO_LASX +export BGEMM_UNROLL_M +export BGEMM_UNROLL_N export SBGEMM_UNROLL_M export SBGEMM_UNROLL_N export SHGEMM_UNROLL_M diff --git a/Makefile.tail b/Makefile.tail index ed2c0e5073..ddd74dcad6 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -1,3 +1,4 @@ +BBLASOBJS_P = $(BBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) 
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) @@ -12,8 +13,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) -BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) +BLASOBJS = $(SHBLASOBJS) $(BBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) +BLASOBJS_P = $(SHBLASPBJS_P) $(BBLASOBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) ifdef EXPRECISION BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -26,6 +27,7 @@ BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif $(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX +$(BBLASOBJS) $(BBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -USMALL_MATRIX_OPT $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX @@ -36,6 +38,7 @@ $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX $(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) +$(BBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) diff --git a/cblas.h b/cblas.h index 0364b216fc..f48d5da1e8 100644 --- a/cblas.h +++ b/cblas.h @@ -1,3 +1,31 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. 
+ * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/ + #ifndef CBLAS_H #define CBLAS_H @@ -441,6 +469,8 @@ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy); +void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, + OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, diff --git a/cmake/system.cmake b/cmake/system.cmake index 6ad73525a6..81f1a67ad3 100644 --- a/cmake/system.cmake +++ 
b/cmake/system.cmake @@ -425,8 +425,8 @@ endif () if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") endif () -if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS) - set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16") +if (SBGEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) + set(CCOMMON_OPT "${CCOMMON_OPT} -DSBGEMM_GEMV_FORWARD") endif () if (SMALL_MATRIX_OPT) set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") diff --git a/common.h b/common.h index 23a08aaa98..4984d727cd 100644 --- a/common.h +++ b/common.h @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -317,7 +318,11 @@ typedef int blasint; #elif defined(BFLOAT16) #define IFLOAT bfloat16 #define XFLOAT IFLOAT -#define FLOAT float +#ifdef BGEMM +#define FLOAT bfloat16 +#else +#define FLOAT float +#endif #define SIZE 2 #define BASE_SHIFT 1 #define ZBASE_SHIFT 2 diff --git a/common_b.h b/common_b.h new file mode 100644 index 0000000000..e03f6800da --- /dev/null +++ b/common_b.h @@ -0,0 +1,85 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. 
Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#ifndef COMMON_B_H +#define COMMON_B_H + +#ifndef DYNAMIC_ARCH +#define BGEMM_ONCOPY bgemm_oncopy +#define BGEMM_OTCOPY bgemm_otcopy +#define BGEMM_INCOPY bgemm_incopy +#define BGEMM_ITCOPY bgemm_itcopy + +#define BGEMM_BETA bgemm_beta +#define BGEMM_KERNEL bgemm_kernel + +#else + +#define BGEMM_ONCOPY gotoblas->bgemm_oncopy +#define BGEMM_OTCOPY gotoblas->bgemm_otcopy +#define BGEMM_INCOPY gotoblas->bgemm_incopy +#define BGEMM_ITCOPY gotoblas->bgemm_itcopy +#define BGEMM_BETA gotoblas->bgemm_beta +#define BGEMM_KERNEL gotoblas->bgemm_kernel + +#endif + +#define BGEMM_NN bgemm_nn +#define BGEMM_CN bgemm_tn +#define BGEMM_TN bgemm_tn +#define BGEMM_NC bgemm_nt +#define BGEMM_NT bgemm_nt +#define BGEMM_CC bgemm_tt +#define BGEMM_CT bgemm_tt +#define BGEMM_TC bgemm_tt +#define BGEMM_TT bgemm_tt +#define BGEMM_NR bgemm_nn +#define BGEMM_TR bgemm_tn +#define BGEMM_CR bgemm_tn +#define BGEMM_RN bgemm_nn 
+#define BGEMM_RT bgemm_nt +#define BGEMM_RC bgemm_nt +#define BGEMM_RR bgemm_nn + +#define BGEMM_THREAD_NN bgemm_thread_nn +#define BGEMM_THREAD_CN bgemm_thread_tn +#define BGEMM_THREAD_TN bgemm_thread_tn +#define BGEMM_THREAD_NC bgemm_thread_nt +#define BGEMM_THREAD_NT bgemm_thread_nt +#define BGEMM_THREAD_CC bgemm_thread_tt +#define BGEMM_THREAD_CT bgemm_thread_tt +#define BGEMM_THREAD_TC bgemm_thread_tt +#define BGEMM_THREAD_TT bgemm_thread_tt +#define BGEMM_THREAD_NR bgemm_thread_nn +#define BGEMM_THREAD_TR bgemm_thread_tn +#define BGEMM_THREAD_CR bgemm_thread_tn +#define BGEMM_THREAD_RN bgemm_thread_nn +#define BGEMM_THREAD_RT bgemm_thread_nt +#define BGEMM_THREAD_RC bgemm_thread_nt +#define BGEMM_THREAD_RR bgemm_thread_nn +#endif diff --git a/common_interface.h b/common_interface.h index 23d86871fc..f69baab1ca 100644 --- a/common_interface.h +++ b/common_interface.h @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. 
*/ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -483,6 +484,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *, hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *); +void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *, + bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *); void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, diff --git a/common_level3.h b/common_level3.h index 1838b4bf6a..607eaf9402 100644 --- a/common_level3.h +++ b/common_level3.h @@ -56,6 +56,8 @@ int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); +int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, @@ -83,6 +85,10 @@ int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int bgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int bgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int bgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int bgemm_otcopy(BLASLONG m, BLASLONG n, 
bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); @@ -511,6 +517,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); +int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); @@ -668,6 +675,11 @@ int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLAS int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int bgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); @@ -770,6 +782,11 @@ int shgemm_thread_nt(blas_arg_t *, 
BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int bgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index b29a9c08df..22c1e14a20 100644 --- a/common_macro.h +++ b/common_macro.h @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. 
*/ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -40,6 +41,7 @@ #define COMMON_MACRO #include "common_sh.h" +#include "common_b.h" #include "common_sb.h" #include "common_s.h" #include "common_d.h" @@ -702,8 +704,52 @@ #define GEMM_THREAD_RR SHGEMM_THREAD_NN -#elif defined(BFLOAT16) +#elif defined(BFLOAT16) && defined(BGEMM) +#define GEMM_BETA BGEMM_BETA +#define GEMM_KERNEL_N BGEMM_KERNEL +#define GEMM_KERNEL_L BGEMM_KERNEL +#define GEMM_KERNEL_R BGEMM_KERNEL +#define GEMM_KERNEL_B BGEMM_KERNEL + +#define GEMM_NN BGEMM_NN +#define GEMM_CN BGEMM_TN +#define GEMM_TN BGEMM_TN +#define GEMM_NC BGEMM_NT +#define GEMM_NT BGEMM_NT +#define GEMM_CC BGEMM_TT +#define GEMM_CT BGEMM_TT +#define GEMM_TC BGEMM_TT +#define GEMM_TT BGEMM_TT +#define GEMM_NR BGEMM_NN +#define GEMM_TR BGEMM_TN +#define GEMM_CR BGEMM_TN +#define GEMM_RN BGEMM_NN +#define GEMM_RT BGEMM_NT +#define GEMM_RC BGEMM_NT +#define GEMM_RR BGEMM_NN +#define GEMM_ONCOPY BGEMM_ONCOPY +#define GEMM_OTCOPY BGEMM_OTCOPY +#define GEMM_INCOPY BGEMM_INCOPY +#define GEMM_ITCOPY BGEMM_ITCOPY + +#define GEMM_THREAD_NN BGEMM_THREAD_NN +#define GEMM_THREAD_CN BGEMM_THREAD_TN +#define GEMM_THREAD_TN BGEMM_THREAD_TN +#define GEMM_THREAD_NC BGEMM_THREAD_NT +#define GEMM_THREAD_NT BGEMM_THREAD_NT +#define GEMM_THREAD_CC BGEMM_THREAD_TT +#define GEMM_THREAD_CT BGEMM_THREAD_TT +#define GEMM_THREAD_TC BGEMM_THREAD_TT +#define GEMM_THREAD_TT BGEMM_THREAD_TT +#define GEMM_THREAD_NR BGEMM_THREAD_NN +#define GEMM_THREAD_TR BGEMM_THREAD_TN +#define GEMM_THREAD_CR BGEMM_THREAD_TN +#define GEMM_THREAD_RN BGEMM_THREAD_NN +#define GEMM_THREAD_RT BGEMM_THREAD_NT +#define GEMM_THREAD_RC BGEMM_THREAD_NT +#define GEMM_THREAD_RR BGEMM_THREAD_NN +#elif defined(BFLOAT16) #define D_TO_BF16_K SBDTOBF16_K #define D_BF16_TO_K DBF16TOD_K #define S_TO_BF16_K SBSTOBF16_K @@ -2663,6 +2709,9 @@ || defined(ARCH_LOONGARCH64) || defined(ARCH_E2K) || defined(ARCH_ALPHA)) extern BLASLONG gemm_offset_a; extern BLASLONG 
gemm_offset_b; +extern BLASLONG bgemm_p; +extern BLASLONG bgemm_q; +extern BLASLONG bgemm_r; extern BLASLONG sbgemm_p; extern BLASLONG sbgemm_q; extern BLASLONG sbgemm_r; diff --git a/common_param.h b/common_param.h index f82b73a72b..503525dd2a 100644 --- a/common_param.h +++ b/common_param.h @@ -1,6 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ -/* Copyright 2023 The OpenBLAS Project. */ +/* Copyright 2023, 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -65,6 +65,10 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); #if BUILD_BFLOAT16 == 1 + int bgemm_p, bgemm_q, bgemm_r; + int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn; + int bgemm_align_k; + int sbgemm_p, sbgemm_q, sbgemm_r; int sbgemm_unroll_m, sbgemm_unroll_n, sbgemm_unroll_mn; int sbgemm_align_k; @@ -105,6 +109,14 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); int (*sbsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); int (*sbsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); + int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); + + int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*bgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*bgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*sbgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int (*sbgemm_beta )(BLASLONG, BLASLONG, 
BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); @@ -1254,6 +1266,13 @@ extern gotoblas_t *gotoblas; #endif #if (BUILD_BFLOAT16==1) +#define BGEMM_P gotoblas -> bgemm_p +#define BGEMM_Q gotoblas -> bgemm_q +#define BGEMM_R gotoblas -> bgemm_r +#define BGEMM_UNROLL_M gotoblas -> bgemm_unroll_m +#define BGEMM_UNROLL_N gotoblas -> bgemm_unroll_n +#define BGEMM_UNROLL_MN gotoblas -> bgemm_unroll_mn + #define SBGEMM_P gotoblas -> sbgemm_p #define SBGEMM_Q gotoblas -> sbgemm_q #define SBGEMM_R gotoblas -> sbgemm_r @@ -1395,6 +1414,17 @@ extern gotoblas_t *gotoblas; #endif #if (BUILD_BFLOAT16 == 1) +#define BGEMM_P BGEMM_DEFAULT_P +#define BGEMM_Q BGEMM_DEFAULT_Q +#define BGEMM_R BGEMM_DEFAULT_R +#define BGEMM_UNROLL_M BGEMM_DEFAULT_UNROLL_M +#define BGEMM_UNROLL_N BGEMM_DEFAULT_UNROLL_N +#ifdef BGEMM_DEFAULT_UNROLL_MN +#define BGEMM_UNROLL_MN BGEMM_DEFAULT_UNROLL_MN +#else +#define BGEMM_UNROLL_MN MAX((BGEMM_UNROLL_M), (BGEMM_UNROLL_N)) +#endif + #define SBGEMM_P SBGEMM_DEFAULT_P #define SBGEMM_Q SBGEMM_DEFAULT_Q #define SBGEMM_R SBGEMM_DEFAULT_R @@ -1555,6 +1585,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R SHGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#elif defined(BFLOAT16) && defined(BGEMM) +#define GEMM_P BGEMM_P +#define GEMM_Q BGEMM_Q +#define GEMM_R BGEMM_R +#define GEMM_UNROLL_M BGEMM_UNROLL_M +#define GEMM_UNROLL_N BGEMM_UNROLL_N +#define GEMM_UNROLL_MN BGEMM_UNROLL_MN +#define GEMM_DEFAULT_P BGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q BGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R BGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M BGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N BGEMM_DEFAULT_UNROLL_N #elif defined(BFLOAT16) #define GEMM_P SBGEMM_P #define GEMM_Q SBGEMM_Q diff --git a/driver/level3/Makefile b/driver/level3/Makefile index 132645a871..622996c3b8 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -1,3 +1,32 @@ 
+############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + TOPDIR = ../.. 
include ../../Makefile.system @@ -20,6 +49,7 @@ USE_GEMM3M = 1 endif ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += bgemm_nn.$(SUFFIX) bgemm_nt.$(SUFFIX) bgemm_tn.$(SUFFIX) bgemm_tt.$(SUFFIX) SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) endif @@ -212,6 +242,7 @@ COMMONOBJS += syrk_thread.$(SUFFIX) ifneq ($(USE_SIMPLE_THREADED_LEVEL3), 1) ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += bgemm_thread_nn.$(SUFFIX) bgemm_thread_nt.$(SUFFIX) bgemm_thread_tn.$(SUFFIX) bgemm_thread_tt.$(SUFFIX) SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX) endif ifeq ($(BUILD_HFLOAT16),1) @@ -350,6 +381,18 @@ endif all :: +bgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +bgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +bgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +bgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -569,6 +612,18 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h beta_thread.$(SUFFIX) : beta_thread.c ../../common.h $(CC) -c $(CFLAGS) $< -o $(@F) +bgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +bgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +bgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) 
$(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +bgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sbgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/level3/level3.c b/driver/level3/level3.c index 5d3438450b..78bc6aa527 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -170,6 +170,22 @@ #define STOP_RPCC(COUNTER) #endif +#if defined(BUILD_BFLOAT16) +#if defined(DYNAMIC_ARCH) + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k + #else + #define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k + #endif +#else + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K BGEMM_ALIGN_K + #else + #define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K + #endif +#endif +#endif + int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ BLASLONG k, lda, ldb, ldc; @@ -307,11 +323,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG pad_min_l = min_l; #if defined(BFLOAT16) -#if defined(DYNAMIC_ARCH) - pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1); -#else - pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);; -#endif + pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1); #endif /* First, we have to move data A to L2 cache */ diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index 5ede6153ef..cb93591ab0 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -216,6 +216,22 @@ typedef struct { #define STOP_RPCC(COUNTER) #endif +#if defined(BUILD_BFLOAT16) +#if defined(DYNAMIC_ARCH) + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k + #else + #define BFLOAT16_ALIGN_K 
gotoblas->sbgemm_align_k + #endif +#else + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K BGEMM_ALIGN_K + #else + #define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K + #endif +#endif +#endif + static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ IFLOAT *buffer[DIVIDE_RATE]; @@ -325,11 +341,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG pad_min_l = min_l; #if defined(BFLOAT16) -#if defined(DYNAMIC_ARCH) - pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1); -#else - pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);; -#endif + pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1); #endif /* Determine step size in m diff --git a/exports/gensymbol b/exports/gensymbol index 3719574ea6..17fbd2877f 100755 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -51,7 +51,7 @@ blasobjsz=" zgeadd dzsum zgemmt zgemmtr" blasobjs="lsame xerbla" -bfblasobjs="sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" +bfblasobjs="bgemm sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" hfblasobjs="shgemm" cblasobjsc=" cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv diff --git a/exports/gensymbol.pl b/exports/gensymbol.pl index 5a8423697b..01f68fbb33 100644 --- a/exports/gensymbol.pl +++ b/exports/gensymbol.pl @@ -51,7 +51,7 @@ zgeadd, dzsum, zgemmt,zgemmtr); @blasobjs = (lsame, xerbla); -@bfblasobjs = (sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); +@bfblasobjs = (bgemm, sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); @hfblasobjs = (shgemm); @cblasobjsc = ( cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, diff --git a/getarch_2nd.c b/getarch_2nd.c index 8170e9cf33..2085556bd6 100644 --- a/getarch_2nd.c +++ b/getarch_2nd.c @@ -1,3 
+1,31 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/ + #include #ifndef BUILD_KERNEL #include "config.h" @@ -17,6 +45,9 @@ typedef unsigned long BLASULONG; int main(int argc, char **argv) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { + /* Emit the BGEMM unroll factors once, as NAME=value lines. */ + printf("BGEMM_UNROLL_M=%d\n", BGEMM_DEFAULT_UNROLL_M); + printf("BGEMM_UNROLL_N=%d\n", BGEMM_DEFAULT_UNROLL_N); printf("SBGEMM_UNROLL_M=%d\n", SBGEMM_DEFAULT_UNROLL_M); printf("SBGEMM_UNROLL_N=%d\n", SBGEMM_DEFAULT_UNROLL_N); printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M); diff --git a/interface/Makefile b/interface/Makefile index dc154a3c4e..e14796cbbd 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -47,6 +47,7 @@ SBLAS3OBJS = \ sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) +BBLAS3OBJ = bgemm.$(SUFFIX) SBBLAS1OBJS = sbdot.$(SUFFIX) SBBLAS2OBJS = sbgemv.$(SUFFIX) SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) @@ -289,6 +290,7 @@ CSBLAS3OBJS = \ cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgemmtr.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) +CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX) @@ -393,6 +395,7 @@ override CFLAGS += -I. 
SBLAS1OBJS += $(CSBLAS1OBJS) SBLAS2OBJS += $(CSBLAS2OBJS) SBLAS3OBJS += $(CSBLAS3OBJS) +BBLAS3OBJ += $(CBBLAS3OBJS) SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS) @@ -412,6 +415,7 @@ SBEXTOBJS += $(CSBEXTOBJS) CBAUXOBJS += $(CXERBLAOBJ) endif +BBLASOBJS = $(BBLAS3OBJ) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) SHBLASOBJS = $(SHBLAS3OBJS) @@ -560,7 +564,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $ level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS) +level3 : $(SBBLAS3OBJS) $(BBLAS3OBJ) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ aux : $(CBAUXOBJS) @@ -1311,6 +1315,8 @@ xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c $(CC) -c $(CFLAGS) $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +bgemm.$(SUFFIX) bgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -c $(CFLAGS) $< -o $(@F) sbgemm.$(SUFFIX) sbgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -c $(CFLAGS) $< -o $(@F) sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h @@ -1979,6 +1985,8 @@ cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +cblas_bgemm.$(SUFFIX) cblas_bgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -DCBLAS -DBGEMM -c $(CFLAGS) $< -o $(@F) cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) endif diff --git a/interface/gemm.c b/interface/gemm.c index d79282e13f..c21fd988bd 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -54,8 +54,13 @@ #define ERROR_NAME "DGEMM " #define GEMV BLASFUNC(dgemv) #elif 
defined(BFLOAT16) +#ifdef BGEMM +#define ERROR_NAME "BGEMM " +#define GEMV BLASFUNC(bgemv) +#else #define ERROR_NAME "SBGEMM " #define GEMV BLASFUNC(sbgemv) +#endif #elif defined(HFLOAT16) #define ERROR_NAME "SHGEMM " #else @@ -579,7 +584,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS args.m, args.n, args.k, args.lda, args.ldb, args.ldc); #endif -#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) +#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD))) #if defined(ARCH_ARM64) // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c} // perform poorly in certain circumstances. We use the following boolean diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 6afb49a779..06f18e6be2 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -1,3 +1,30 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### USE_GEMM3M = 0 OS := $(shell uname) @@ -110,6 +137,23 @@ endif endif ifeq ($(BUILD_BFLOAT16), 1) +ifndef BGEMMKERNEL +BGEMM_BETA = ../generic/gemm_beta.c +BGEMMKERNEL = ../generic/gemmkernel_2x2.c +BGEMMINCOPY = ../generic/gemm_ncopy_2.c +BGEMMITCOPY = ../generic/gemm_tcopy_2.c +BGEMMONCOPY = ../generic/gemm_ncopy_2.c +BGEMMOTCOPY = ../generic/gemm_tcopy_2.c +BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) +BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) +BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) +BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) +endif +BKERNELOBJS += \ + bgemm_kernel$(TSUFFIX).$(SUFFIX) \ + $(BGEMMINCOPYOBJ) $(BGEMMITCOPYOBJ) \ + $(BGEMMONCOPYOBJ) $(BGEMMOTCOPYOBJ) + ifndef SBGEMMKERNEL SBGEMM_BETA = ../generic/gemm_beta.c SBGEMMKERNEL = ../generic/gemmkernel_2x2.c @@ -210,6 +254,7 @@ XKERNELOBJS += \ $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += $(BKERNELOBJS) SBBLASOBJS += $(SBKERNELOBJS) endif ifeq ($(BUILD_HFLOAT16),1) @@ -223,6 +268,7 @@ ZBLASOBJS += $(ZKERNELOBJS) XBLASOBJS += $(XKERNELOBJS) ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += 
bgemm_beta$(TSUFFIX).$(SUFFIX) SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) endif ifeq ($(BUILD_HFLOAT16),1) @@ -667,6 +713,8 @@ XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) ifeq ($(BUILD_BFLOAT16),1) +$(KDIR)bgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMM_BETA) + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif @@ -698,9 +746,22 @@ ifeq ($(ARCH), E2K) USE_TRMM = 1 endif - ifeq ($(BUILD_BFLOAT16), 1) +$(KDIR)$(BGEMMONCOPYOBJ) : $(KERNELDIR)/$(BGEMMONCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMOTCOPYOBJ) : $(KERNELDIR)/$(BGEMMOTCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ + +ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N)) +$(KDIR)$(BGEMMINCOPYOBJ) : $(KERNELDIR)/$(BGEMMINCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMITCOPYOBJ) : $(KERNELDIR)/$(BGEMMITCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)$(SBGEMMONCOPYOBJ) : $(KERNELDIR)/$(SBGEMMONCOPY) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ @@ -914,7 +975,8 @@ endif endif ifeq ($(BUILD_BFLOAT16), 1) - +$(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) + $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif @@ -2908,6 +2970,8 @@ $(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ ifeq ($(BUILD_BFLOAT16),1) +$(KDIR)bgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BGEMM_BETA) + $(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@ 
$(KDIR)sbgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c index ccb772cc7d..74e9bf9a97 100644 --- a/kernel/generic/gemm_beta.c +++ b/kernel/generic/gemm_beta.c @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -38,11 +39,41 @@ #include "common.h" +#if defined(BFLOAT16) && defined(BGEMM) && defined(BFLOAT16CONVERSION) +static float +bfloat16tof32 (bfloat16 f16) +{ + float result = 0; + unsigned short* q = (unsigned short*)(&result); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = f16; +#else + q[1] = f16; +#endif + return result; +} +static bfloat16 +f32tobfloat16(float f32) +{ + unsigned short* q = (unsigned short*)(&f32); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return q[0]; +#else + return q[1]; +#endif +} + +#define BF16TOF32(x) (bfloat16tof32(x)) +#define F32TOBF16(x) (f32tobfloat16(x)) +#else +#define BF16TOF32(x) x +#define F32TOBF16(x) x +#endif + int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ - BLASLONG i, j; BLASLONG chunk, remain; FLOAT *c_offset1, *c_offset; @@ -54,18 +85,18 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, c_offset1 = c_offset; c_offset += ldc; for(i=chunk; i>0; i--){ - *(c_offset1 + 0) = ZERO; - *(c_offset1 + 1) = ZERO; - *(c_offset1 + 2) = ZERO; - *(c_offset1 + 3) = ZERO; - *(c_offset1 + 4) = ZERO; - *(c_offset1 + 5) = ZERO; - *(c_offset1 + 6) = ZERO; - *(c_offset1 + 7) = ZERO; + *(c_offset1 + 0) = F32TOBF16(ZERO); + *(c_offset1 + 1) = F32TOBF16(ZERO); + *(c_offset1 + 2) = F32TOBF16(ZERO); + *(c_offset1 + 3) = 
F32TOBF16(ZERO); + *(c_offset1 + 4) = F32TOBF16(ZERO); + *(c_offset1 + 5) = F32TOBF16(ZERO); + *(c_offset1 + 6) = F32TOBF16(ZERO); + *(c_offset1 + 7) = F32TOBF16(ZERO); c_offset1 += 8; } for(i=remain; i>0; i--){ - *c_offset1 = ZERO; + *c_offset1 = F32TOBF16(ZERO); c_offset1 ++; } } @@ -74,18 +105,18 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, c_offset1 = c_offset; c_offset += ldc; for(i=chunk; i>0; i--){ - *(c_offset1 + 0) *= beta; - *(c_offset1 + 1) *= beta; - *(c_offset1 + 2) *= beta; - *(c_offset1 + 3) *= beta; - *(c_offset1 + 4) *= beta; - *(c_offset1 + 5) *= beta; - *(c_offset1 + 6) *= beta; - *(c_offset1 + 7) *= beta; + *(c_offset1 + 0) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[0])); + *(c_offset1 + 1) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[1])); + *(c_offset1 + 2) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[2])); + *(c_offset1 + 3) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[3])); + *(c_offset1 + 4) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[4])); + *(c_offset1 + 5) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[5])); + *(c_offset1 + 6) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[6])); + *(c_offset1 + 7) = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[7])); c_offset1 += 8; } for(i=remain; i>0; i--){ - *c_offset1 *= beta; + *c_offset1 = F32TOBF16(BF16TOF32(beta) * BF16TOF32(c_offset1[0])); c_offset1 ++; } } diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index bf1c3ae381..add84f0431 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -12,9 +12,29 @@ bfloat16tof32 (bfloat16 f16) #endif return result; } + +static bfloat16 f32tobfloat16(float f32) { + unsigned short *q = (unsigned short *)(&f32); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return q[0]; +#else + return q[1]; +#endif +} + +#ifdef BGEMM +#define ALPHA bfloat16tof32(alpha) #define BF16TOF32(x) (bfloat16tof32(x)) +#define F32TOBF16(x) (f32tobfloat16(x)) #else 
+#define ALPHA alpha +#define BF16TOF32(x) (bfloat16tof32(x)) +#define F32TOBF16(x) x +#endif +#else +#define ALPHA alpha #define BF16TOF32(x) x +#define F32TOBF16(x) x #endif int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL @@ -25,7 +45,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, BLASLONG i,j,k; FLOAT *C0,*C1; IFLOAT *ptrba,*ptrbb; +#ifdef BGEMM + float res0,res1,res2,res3; +#else FLOAT res0,res1,res2,res3; +#endif IFLOAT load0,load1,load2,load3,load4,load5,load6,load7; for (j=0; j SBBLAT3.SUMM @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM + @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat @@ -223,6 +254,8 @@ ifeq ($(USE_OPENMP), 1) ifeq ($(BUILD_BFLOAT16),1) OMP_NUM_THREADS=2 ./test_sbgemm > SBBLAT3.SUMM @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 + OMP_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM + @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 endif ifeq ($(BUILD_SINGLE),1) OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat @@ -244,6 +277,8 @@ else ifeq ($(BUILD_BFLOAT16),1) OPENBLAS_NUM_THREADS=2 ./test_sbgemm > SBBLAT3.SUMM @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 + OPENBLAS_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM + @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat @@ -367,6 +402,9 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME) endif ifeq ($(BUILD_BFLOAT16),1) +test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME) + $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) 
$(CEXTRALIB) endif @@ -387,7 +425,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_sbgemm sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c new file mode 100644 index 0000000000..2858b782d1 --- /dev/null +++ b/test/compare_sgemm_bgemm.c @@ -0,0 +1,224 @@ +/*************************************************************************** +Copyright (c) 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#include "../common.h"
+#include
+#include
+/* NOTE(review): the three bare #include directives around this comment carry
+ * no header name; the names appear to have been lost from this copy of the
+ * patch. Code in this file uses uint32_t, fabs, memset, malloc/free/rand and
+ * fprintf, so confirm the intended headers against the original file. */
+#include
+/* BLAS entry points called by main, named through BLASFUNC, and the largest
+ * square problem size that main runs. */
+#define SGEMM BLASFUNC(sgemm)
+#define BGEMM BLASFUNC(bgemm)
+#define BGEMM_LARGEST 256
+/* Overlay of a 16-bit value with sign (1 bit), exponent (8 bits) and mantissa
+ * (7 bits) bit-fields. The field order is reversed when __BYTE_ORDER__ reports
+ * big-endian, and the inner struct is declared packed on AIX. */
+typedef union
+{
+  unsigned short v;
+#if defined(_AIX)
+  struct __attribute__((packed))
+#else
+  struct
+#endif
+  {
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+    unsigned short s:1;
+    unsigned short e:8;
+    unsigned short m:7;
+#else
+    unsigned short m:7;
+    unsigned short e:8;
+    unsigned short s:1;
+#endif
+  } bits;
+} bfloat16_bits;
+/* Overlay of a float with sign (1 bit), exponent (8 bits) and mantissa
+ * (23 bits) bit-fields, using the same endianness and AIX handling as
+ * bfloat16_bits. */
+typedef union
+{
+  float v;
+#if defined(_AIX)
+  struct __attribute__((packed))
+#else
+  struct
+#endif
+  {
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+    uint32_t s:1;
+    uint32_t e:8;
+    uint32_t m:23;
+#else
+    uint32_t m:23;
+    uint32_t e:8;
+    uint32_t s:1;
+#endif
+  } bits;
+} float32_bits;
+/* Widen a bfloat16 bit pattern to float: sign and exponent are copied and the
+ * 7 mantissa bits are shifted left by 16 into the 23-bit mantissa, so the low
+ * 16 mantissa bits of the result are zero. */
+float
+float16to32 (bfloat16_bits f16)
+{
+  float32_bits f32;
+  f32.bits.s = f16.bits.s;
+  f32.bits.e = f16.bits.e;
+  f32.bits.m = (uint32_t) f16.bits.m << 16;
+  return f32.v;
+}
+/* Narrow a float to a bfloat16 bit pattern by truncation: sign and exponent
+ * are copied and only the top 7 mantissa bits are kept; the low 16 mantissa
+ * bits are dropped without rounding. Returns the raw 16-bit value. */
+bfloat16
+float32to16 (float32_bits f32)
+{
+  bfloat16_bits f16;
+  f16.bits.s = f32.bits.s;
+  f16.bits.e = f32.bits.e;
+  f16.bits.m = (f32.bits.m >> 16) & 0x7f;
+  return f16.v;
+}
+/* Round-trip a float through float32to16 and float16to32, i.e. return the
+ * value with its low mantissa bits cleared.
+ * NOTE(review): both casts are casts to a union type, which looks like the
+ * GNU C cast-to-union extension; confirm the supported compilers accept it. */
+static float truncate_float(float value) {
+  bfloat16_bits f16 = (bfloat16_bits)float32to16((float32_bits)value);
+  return float16to32(f16);
+}
+/* malloc wrapper that requests one byte when size is 0, so a zero-sized
+ * request is never passed to malloc; otherwise forwards size unchanged.
+ * Returns whatever malloc returns. */
+void *malloc_safe(size_t size) {
+  if (size == 0)
+    return malloc(1);
+  else
+    return malloc(size);
+}
+
+/*
+ * Test driver for BGEMM. For every square size m = n = k = x it fills A and B
+ * with random values in [0.5, 1.5], converts them to bfloat16, and runs SGEMM
+ * and BGEMM for all four transA/transB combinations with alpha = 1, beta = 0.
+ * Each BGEMM result element is compared with the truncated SGEMM result
+ * (tolerance 2.0) and with a naive product of the bfloat16 inputs
+ * (tolerance 1.0); every comparison outside its tolerance bumps the failure
+ * count.
+ *
+ * Returns 0 when nothing mismatched and non-zero on allocation failure or
+ * mismatch (the failure count, clamped to 255). argc and argv are unused.
+ */
+int
+main (int argc, char *argv[])
+{
+  blasint m, n, k;
+  int i, j, l;
+  blasint x, y;
+  int ret = 0;
+  char transA, transB;
+  float alpha = 1.0, beta = 0.0;
+  bfloat16 alpha_bf16 = float32to16((float32_bits)alpha);
+  bfloat16 beta_bf16 = float32to16((float32_bits)beta);
+
+  for (x = 0; x <= BGEMM_LARGEST; x++)
+  {
+    /* Sizes 0..100 are all exercised; above that only the largest one. */
+    if ((x > 100) && (x != BGEMM_LARGEST)) continue;
+    m = k = n = x;
+
+    /* Element counts are widened to size_t before multiplying by the size. */
+    float *A = malloc_safe((size_t)m * k * sizeof *A);
+    float *B = malloc_safe((size_t)k * n * sizeof *B);
+    float *C = malloc_safe((size_t)m * n * sizeof *C);
+    bfloat16_bits *AA = malloc_safe((size_t)m * k * sizeof *AA);
+    bfloat16_bits *BB = malloc_safe((size_t)k * n * sizeof *BB);
+    /* CC holds the m x n BGEMM result, so it is sized like C and DD. */
+    bfloat16_bits *CC = malloc_safe((size_t)m * n * sizeof *CC);
+    FLOAT *DD = malloc_safe((size_t)m * n * sizeof *DD);
+    if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
+        (DD == NULL) || (CC == NULL))
+    {
+      /* free(NULL) is a no-op, so release whatever was obtained. */
+      free(A);
+      free(B);
+      free(C);
+      free(AA);
+      free(BB);
+      free(CC);
+      free(DD);
+      return 1;
+    }
+    bfloat16 atmp, btmp;
+    blasint one = 1;
+
+    /* Fill A and B and keep a bfloat16 copy of every element. */
+    for (j = 0; j < m; j++)
+    {
+      for (i = 0; i < k; i++)
+      {
+        A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
+        sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
+        AA[j * k + i].v = atmp;
+      }
+    }
+    for (j = 0; j < n; j++)
+    {
+      for (i = 0; i < k; i++)
+      {
+        B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
+        sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
+        BB[j * k + i].v = btmp;
+      }
+    }
+
+    /* y = 0..3 walks the combinations NN, TN, NT, TT. */
+    for (y = 0; y < 4; y++)
+    {
+      transA = (y & 1) ? 'T' : 'N';
+      transB = (y & 2) ? 'T' : 'N';
+
+      memset(CC, 0, (size_t)m * n * sizeof *CC);
+      memset(DD, 0, (size_t)m * n * sizeof *DD);
+      memset(C, 0, (size_t)m * n * sizeof *C);
+
+      SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
+        &m, B, &k, &beta, C, &m);
+      BGEMM (&transA, &transB, &m, &n, &k, &alpha_bf16, (bfloat16*) AA,
+        &m, (bfloat16*)BB, &k, &beta_bf16, (bfloat16*)CC, &m);
+
+      for (i = 0; i < n; i++)
+      {
+        for (j = 0; j < m; j++)
+        {
+          /* Naive reference product accumulated in DD from the bf16 inputs. */
+          for (l = 0; l < k; l++)
+          {
+            bfloat16_bits a = (transA == 'N') ? AA[l * m + j] : AA[k * j + l];
+            bfloat16_bits b = (transB == 'N') ? BB[l + k * i] : BB[i + l * n];
+            DD[i * m + j] += float16to32 (a) * float16to32 (b);
+          }
+          if (fabs(float16to32(CC[i * m + j]) - truncate_float(C[i * m + j])) > 2.0) {
+            ret++;
+          }
+          if (fabs(float16to32(CC[i * m + j]) - truncate_float(DD[i * m + j])) > 1.0) {
+            ret++;
+          }
+        }
+      }
+    }
+    free(A);
+    free(B);
+    free(C);
+    free(AA);
+    free(BB);
+    free(CC);
+    free(DD);
+  }
+
+  if (ret != 0) {
+    fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret);
+    /* Only the low 8 bits of an exit status reach the parent on POSIX, so a
+     * count that is a multiple of 256 would read as success; clamp it. */
+    return ret > 255 ? 255 : ret;
+  }
+  return 0;
+}