Skip to content

Commit 0a96779

Browse files
committed
Add FP16 support for RISCV
1 parent 2996c25 commit 0a96779

12 files changed

+270
-35
lines changed

Makefile.system

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,6 +1889,7 @@ export TARGET_CORE
18891889
export NO_AVX512
18901890
export NO_AVX2
18911891
export BUILD_BFLOAT16
1892+
export BUILD_HFLOAT16
18921893
export NO_LSX
18931894
export NO_LASX
18941895

Makefile.tail

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
2+
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
23
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
34
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
45
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
1112

1213
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
1314

14-
BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
15-
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
15+
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
16+
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
1617

1718
ifdef EXPRECISION
1819
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
2425
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
2526
endif
2627

28+
$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
2729
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
2830
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
2931
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
3335
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
3436
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
3537

38+
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3639
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3740
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3841
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)

cblas.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
446446
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
447447
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
448448

449+
/*** FLOAT16 extensions */
450+
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
451+
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
452+
449453
#ifdef __cplusplus
450454
}
451455
#endif /* __cplusplus */

cmake/system.cmake

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -640,21 +640,24 @@ endif()
640640
if (BUILD_BFLOAT16)
641641
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
642642
endif()
643+
if (BUILD_HFLOAT16)
644+
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
645+
endif()
643646
if(NOT MSVC)
644647
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
645648
endif()
646649
# TODO: not sure what PFLAGS is -hpa
647650
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
648651
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
649652

650-
if ("${F_COMPILER}" STREQUAL "FLANG")
651-
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
652-
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
653-
endif ()
654-
endif ()
655-
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
656-
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
657-
endif ()
653+
if ("${F_COMPILER}" STREQUAL "FLANG")
654+
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
655+
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
656+
endif ()
657+
endif ()
658+
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
659+
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
660+
endif ()
658661
endif ()
659662

660663

common.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,11 @@ typedef uint16_t bfloat16;
266266
#define BFLOAT16CONVERSION 1
267267
#endif
268268

269+
#ifndef hfloat16
270+
#include <stdint.h>
271+
typedef uint16_t hfloat16;
272+
#endif
273+
269274
#ifdef USE64BITINT
270275
typedef BLASLONG blasint;
271276
#if defined(OS_WINDOWS) && defined(__64BIT__)
@@ -313,8 +318,8 @@ typedef int blasint;
313318
#define SIZE 2
314319
#define BASE_SHIFT 1
315320
#define ZBASE_SHIFT 2
316-
#elif defined(FLOAT16)
317-
#define IFLOAT float16
321+
#elif defined(HFLOAT16)
322+
#define IFLOAT hfloat16
318323
#define XFLOAT IFLOAT
319324
#define FLOAT float
320325
#define SIZE 2

common_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
481481

482482
/* Level 3 routines */
483483

484+
void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
485+
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
484486
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
485487
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
486488
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,

common_level3.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
5454

5555
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
5656

57-
57+
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
58+
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
5859
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
5960
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
6061
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
7879
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
7980
#endif
8081

82+
int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
83+
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
84+
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
85+
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
8186
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
8287
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
8388
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
505510
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
506511
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
507512

513+
int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
508514
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
509515
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
510516
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
657663
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
658664
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
659665

666+
int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
667+
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
668+
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
669+
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
670+
660671
int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
661672
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
662673
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
754765
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
755766
#endif
756767

768+
int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
769+
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
770+
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
771+
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
772+
757773
int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
758774
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
759775
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
19441960
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
19451961
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
19461962
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
1963+
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
19471964

19481965
#ifdef __CUDACC__
19491966
}

common_macro.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#ifndef COMMON_MACRO
4040
#define COMMON_MACRO
4141

42+
#include "common_sh.h"
4243
#include "common_sb.h"
4344
#include "common_s.h"
4445
#include "common_d.h"
@@ -656,6 +657,50 @@
656657
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT
657658
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN
658659
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT
660+
#elif defined(HFLOAT16)
661+
#define GEMM_BETA SHGEMM_BETA
662+
#define GEMM_KERNEL_N SHGEMM_KERNEL
663+
#define GEMM_KERNEL_L SHGEMM_KERNEL
664+
#define GEMM_KERNEL_R SHGEMM_KERNEL
665+
#define GEMM_KERNEL_B SHGEMM_KERNEL
666+
#define GEMM_NN SHGEMM_NN
667+
#define GEMM_CN SHGEMM_TN
668+
#define GEMM_TN SHGEMM_TN
669+
#define GEMM_NC SHGEMM_NT
670+
#define GEMM_NT SHGEMM_NT
671+
#define GEMM_CC SHGEMM_TT
672+
#define GEMM_CT SHGEMM_TT
673+
#define GEMM_TC SHGEMM_TT
674+
#define GEMM_TT SHGEMM_TT
675+
#define GEMM_NR SHGEMM_NN
676+
#define GEMM_TR SHGEMM_TN
677+
#define GEMM_CR SHGEMM_TN
678+
#define GEMM_RN SHGEMM_NN
679+
#define GEMM_RT SHGEMM_NT
680+
#define GEMM_RC SHGEMM_NT
681+
#define GEMM_RR SHGEMM_NN
682+
#define GEMM_ONCOPY SHGEMM_ONCOPY
683+
#define GEMM_OTCOPY SHGEMM_OTCOPY
684+
#define GEMM_INCOPY SHGEMM_INCOPY
685+
#define GEMM_ITCOPY SHGEMM_ITCOPY
686+
687+
#define GEMM_THREAD_NN SHGEMM_THREAD_NN
688+
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
689+
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
690+
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
691+
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
692+
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
693+
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
694+
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
695+
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
696+
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
697+
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
698+
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
699+
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
700+
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
701+
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
702+
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
703+
659704

660705
#elif defined(BFLOAT16)
661706

0 commit comments

Comments
 (0)