Skip to content

Commit d96daa2

Browse files
authored
Merge pull request #5290 from Srangrang/develop
Add support for FP16 to openBLAS and shgemm on RISCV
2 parents fdc1c32 + 3b1ac29 commit d96daa2

37 files changed

+2522
-92
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ endif ()
152152
if (NOT DEFINED BUILD_BFLOAT16)
153153
set (BUILD_BFLOAT16 false)
154154
endif ()
155+
if (NOT DEFINED BUILD_HFLOAT16)
156+
set (BUILD_HFLOAT16 false)
157+
endif ()
155158
# set which float types we want to build for
156159
if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16)
157160
# if none are defined, build for all

Makefile.prebuild

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ TARGET_FLAGS = -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d
6464
endif
6565

6666
ifeq ($(TARGET), RISCV64_ZVL256B)
67-
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
67+
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
6868
endif
6969

7070
ifeq ($(TARGET), RISCV64_ZVL128B)
71-
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
71+
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
7272
endif
7373

7474
ifeq ($(TARGET), RISCV64_GENERIC)

Makefile.riscv64

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ CCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh_zvl512b -mabi=lp64d
77
FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static
88
endif
99
ifeq ($(CORE), RISCV64_ZVL256B)
10-
CCOMMON_OPT += -march=rv64imafdcv_zvl256b -mabi=lp64d
11-
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
10+
CCOMMON_OPT += -march=rv64imafdcv_zvl256b_zvfh_zfh -mabi=lp64d
11+
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
1212
endif
1313
ifeq ($(CORE), RISCV64_ZVL128B)
14-
CCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
15-
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
14+
CCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
15+
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
1616
endif
1717
ifeq ($(CORE), RISCV64_GENERIC)
1818
CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d

Makefile.rule

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ COMMON_PROF = -pg
308308
# If you want to enable the experimental BFLOAT16 support
309309
# BUILD_BFLOAT16 = 1
310310

311+
# If you want to enable the experimental HFLOAT16 support
312+
# BUILD_HFLOAT16 = 1
311313

312314
# Set the thread number threshold beyond which the job array for the threaded level3 BLAS
313315
# will be allocated on the heap rather than the stack. (This array alone requires

Makefile.system

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,9 @@ endif
15561556
ifeq ($(BUILD_BFLOAT16), 1)
15571557
CCOMMON_OPT += -DBUILD_BFLOAT16
15581558
endif
1559+
ifeq ($(BUILD_HFLOAT16), 1)
1560+
CCOMMON_OPT += -DBUILD_HFLOAT16
1561+
endif
15591562
ifeq ($(BUILD_SINGLE), 1)
15601563
CCOMMON_OPT += -DBUILD_SINGLE=1
15611564
endif
@@ -1898,11 +1901,14 @@ export TARGET_CORE
18981901
export NO_AVX512
18991902
export NO_AVX2
19001903
export BUILD_BFLOAT16
1904+
export BUILD_HFLOAT16
19011905
export NO_LSX
19021906
export NO_LASX
19031907

19041908
export SBGEMM_UNROLL_M
19051909
export SBGEMM_UNROLL_N
1910+
export SHGEMM_UNROLL_M
1911+
export SHGEMM_UNROLL_N
19061912
export SGEMM_UNROLL_M
19071913
export SGEMM_UNROLL_N
19081914
export DGEMM_UNROLL_M

Makefile.tail

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
2+
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
23
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
34
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
45
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
1112

1213
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
1314

14-
BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
15-
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
15+
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
16+
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
1617

1718
ifdef EXPRECISION
1819
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
2425
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
2526
endif
2627

28+
$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
2729
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
2830
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
2931
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
3335
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
3436
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
3537

38+
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3639
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3740
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3841
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)

benchmark/Makefile

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,15 @@ GOTO_LAPACK_TARGETS=
5656
endif
5757

5858
ifeq ($(BUILD_BFLOAT16),1)
59-
GOTO_HALF_TARGETS=sbgemm.goto
59+
GOTO_BFLOAT_TARGETS=sbgemm.goto
6060
else
61-
GOTO_HALF_TARGETS=
61+
GOTO_BFLOAT_TARGETS=
62+
endif
63+
64+
ifeq ($(BUILD_HFLOAT16),1)
65+
GOTO_HFLOAT_TARGETS=shgemm.goto
66+
else
67+
GOTO_HFLOAT_TARGETS=
6268
endif
6369

6470
ifeq ($(OSNAME), WINNT)
@@ -104,7 +110,7 @@ goto :: slinpack.goto dlinpack.goto clinpack.goto zlinpack.goto \
104110
spotrf.goto dpotrf.goto cpotrf.goto zpotrf.goto \
105111
ssymm.goto dsymm.goto csymm.goto zsymm.goto \
106112
somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \
107-
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto $(GOTO_HALF_TARGETS)
113+
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto $(GOTO_BFLOAT_TARGETS) $(GOTO_HFLOAT_TARGETS)
108114

109115
acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \
110116
scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \
@@ -278,7 +284,7 @@ goto :: sgemm.goto dgemm.goto cgemm.goto zgemm.goto \
278284
smin.goto dmin.goto \
279285
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto \
280286
somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \
281-
snrm2.goto dnrm2.goto scnrm2.goto dznrm2.goto $(GOTO_LAPACK_TARGETS) $(GOTO_HALF_TARGETS)
287+
snrm2.goto dnrm2.goto scnrm2.goto dznrm2.goto $(GOTO_LAPACK_TARGETS) $(GOTO_BFLOAT_TARGETS) $(GOTO_HFLOAT_TARGETS)
282288

283289
acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \
284290
scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \
@@ -633,6 +639,11 @@ sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
633639
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
634640
endif
635641

642+
ifeq ($(BUILD_HFLOAT16),1)
643+
shgemm.goto : shgemm.$(SUFFIX) ../$(LIBNAME)
644+
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
645+
endif
646+
636647
sgemm.goto : sgemm.$(SUFFIX) ../$(LIBNAME)
637648
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
638649

@@ -2960,7 +2971,12 @@ zcholesky.$(SUFFIX) : cholesky.c
29602971

29612972
ifeq ($(BUILD_BFLOAT16),1)
29622973
sbgemm.$(SUFFIX) : gemm.c
2963-
$(CC) $(CFLAGS) -c -DHALF -UCOMPLEX -UDOUBLE -o $(@F) $^
2974+
$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
2975+
endif
2976+
2977+
ifeq ($(BUILD_HFLOAT16),1)
2978+
shgemm.$(SUFFIX) : gemm.c
2979+
$(CC) $(CFLAGS) -c -DHFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
29642980
endif
29652981

29662982
sgemm.$(SUFFIX) : gemm.c

benchmark/gemm.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,17 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333

3434
#ifdef DOUBLE
3535
#define GEMM BLASFUNC(dgemm)
36-
#elif defined(HALF)
36+
#elif defined(BFLOAT16)
3737
#define GEMM BLASFUNC(sbgemm)
38+
#undef IFLOAT
39+
#define IFLOAT bfloat16
40+
#elif defined(HFLOAT16)
41+
#define GEMM BLASFUNC(shgemm)
42+
#undef IFLOAT
43+
#define IFLOAT hfloat16
3844
#else
3945
#define GEMM BLASFUNC(sgemm)
46+
#define IFLOAT float
4047
#endif
4148

4249
#else

cblas.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
446446
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
447447
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
448448

449+
/*** FLOAT16 extensions ***/
450+
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
451+
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
452+
449453
#ifdef __cplusplus
450454
}
451455
#endif /* __cplusplus */

cmake/system.cmake

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -640,21 +640,24 @@ endif()
640640
if (BUILD_BFLOAT16)
641641
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
642642
endif()
643+
if (BUILD_HFLOAT16)
644+
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
645+
endif()
643646
if(NOT MSVC)
644647
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
645648
endif()
646649
# TODO: not sure what PFLAGS is -hpa
647650
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
648651
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
649652

650-
if ("${F_COMPILER}" STREQUAL "FLANG")
651-
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
652-
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
653-
endif ()
654-
endif ()
655-
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
656-
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
657-
endif ()
653+
if ("${F_COMPILER}" STREQUAL "FLANG")
654+
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
655+
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
656+
endif ()
657+
endif ()
658+
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
659+
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
660+
endif ()
658661
endif ()
659662

660663

0 commit comments

Comments
 (0)