Skip to content

Commit fa2b08b

Browse files
authored
Merge pull request #1 from gkdddd/riscv_shgemm
Added shgemm_kernel_8x8 for RISCV64_ZVL128B and shgemm_kernel_16x8 fo…
2 parents 0a96779 + 670ec6f commit fa2b08b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+39881
-611
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ endif ()
152152
if (NOT DEFINED BUILD_BFLOAT16)
153153
set (BUILD_BFLOAT16 false)
154154
endif ()
155+
if (NOT DEFINED BUILD_HFLOAT16)
156+
set (BUILD_HFLOAT16 false)
157+
endif ()
155158
# set which float types we want to build for
156159
if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16)
157160
# if none are defined, build for all

Makefile.prebuild

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ TARGET_FLAGS = -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d
6464
endif
6565

6666
ifeq ($(TARGET), RISCV64_ZVL256B)
67-
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
67+
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
6868
endif
6969

7070
ifeq ($(TARGET), RISCV64_ZVL128B)
71-
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
71+
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
7272
endif
7373

7474
ifeq ($(TARGET), RISCV64_GENERIC)

Makefile.riscv64

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ CCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh_zvl512b -mabi=lp64d
77
FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static
88
endif
99
ifeq ($(CORE), RISCV64_ZVL256B)
10-
CCOMMON_OPT += -march=rv64imafdcv_zvl256b -mabi=lp64d
11-
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
10+
CCOMMON_OPT += -march=rv64imafdcv_zvl256b_zvfh_zfh -mabi=lp64d
11+
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
1212
endif
1313
ifeq ($(CORE), RISCV64_ZVL128B)
14-
CCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
15-
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
14+
CCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
15+
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
1616
endif
1717
ifeq ($(CORE), RISCV64_GENERIC)
1818
CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d

Makefile.rule

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ COMMON_PROF = -pg
308308
# If you want to enable the experimental BFLOAT16 support
309309
# BUILD_BFLOAT16 = 1
310310

311+
# If you want to enable the experimental HFLOAT16 support
312+
BUILD_HFLOAT16 = 1
311313

312314
# Set the thread number threshold beyond which the job array for the threaded level3 BLAS
313315
# will be allocated on the heap rather than the stack. (This array alone requires

Makefile.system

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ GEMM_GEMV_FORWARD_BF16 = 1
280280
endif
281281
ifeq ($(ARCH), riscv)
282282
GEMM_GEMV_FORWARD = 1
283+
BUILD_HFLOAT16 = 1
283284
endif
284285
ifeq ($(ARCH), power)
285286
GEMM_GEMV_FORWARD = 1
@@ -1547,6 +1548,9 @@ endif
15471548
ifeq ($(BUILD_BFLOAT16), 1)
15481549
CCOMMON_OPT += -DBUILD_BFLOAT16
15491550
endif
1551+
ifeq ($(BUILD_HFLOAT16), 1)
1552+
CCOMMON_OPT += -DBUILD_HFLOAT16
1553+
endif
15501554
ifeq ($(BUILD_SINGLE), 1)
15511555
CCOMMON_OPT += -DBUILD_SINGLE=1
15521556
endif

benchmark/gemm.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3535
#define GEMM BLASFUNC(dgemm)
3636
#elif defined(HALF)
3737
#define GEMM BLASFUNC(sbgemm)
38+
#elif defined(HFLOAT16)
39+
#define GEMM BLASFUNC(shgemm)
3840
#else
3941
#define GEMM BLASFUNC(sgemm)
4042
#endif

cblas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
446446
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
447447
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
448448

449-
/*** FLOAT16 extensions */
449+
/*** FLOAT16 extensions ***/
450450
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
451451
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
452452

common.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,12 @@ typedef uint16_t bfloat16;
266266
#define BFLOAT16CONVERSION 1
267267
#endif
268268

269-
#ifndef hfloat16
270-
#include <stdint.h>
271-
typedef uint16_t hfloat16;
269+
#ifdef BUILD_HFLOAT16
270+
#ifndef hfloat16
271+
typedef _Float16 hfloat16;
272+
#endif
273+
#else
274+
typedef uint16_t hfloat16;
272275
#endif
273276

274277
#ifdef USE64BITINT

driver/level3/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ foreach (GEMM_DEFINE ${GEMM_DEFINES})
1818
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16")
1919
endif ()
2020
endif ()
21+
if (BUILD_HFLOAT16)
22+
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE}" "gemm_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16")
23+
if (USE_THREAD AND NOT USE_SIMPLE_THREADED_LEVEL3)
24+
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16")
25+
endif ()
26+
endif ()
2127
endforeach ()
2228

2329
if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)

driver/level3/Makefile

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ ifeq ($(BUILD_BFLOAT16),1)
2323
SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX)
2424
endif
2525

26+
ifeq ($(BUILD_HFLOAT16),1)
27+
SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
28+
endif
29+
2630
SBLASOBJS += \
2731
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
2832
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
@@ -210,6 +214,9 @@ ifneq ($(USE_SIMPLE_THREADED_LEVEL3), 1)
210214
ifeq ($(BUILD_BFLOAT16),1)
211215
SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX)
212216
endif
217+
ifeq ($(BUILD_HFLOAT16),1)
218+
SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
219+
endif
213220
SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
214221
DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
215222
QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
@@ -355,6 +362,18 @@ sbgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
355362
sbgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
356363
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
357364

365+
shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
366+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
367+
368+
shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
369+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
370+
371+
shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
372+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
373+
374+
shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
375+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
376+
358377
sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
359378
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
360379

@@ -562,6 +581,18 @@ sbgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
562581
sbgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
563582
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
564583

584+
shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
585+
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
586+
587+
shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
588+
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
589+
590+
shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
591+
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
592+
593+
shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
594+
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
595+
565596
sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
566597
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
567598

@@ -2747,6 +2778,18 @@ sbgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
27472778
sbgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
27482779
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
27492780

2781+
shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
2782+
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
2783+
2784+
shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
2785+
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
2786+
2787+
shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
2788+
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
2789+
2790+
shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
2791+
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
2792+
27502793
sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
27512794
$(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
27522795

@@ -2970,6 +3013,18 @@ sbgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
29703013
sbgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
29713014
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
29723015

3016+
shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
3017+
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
3018+
3019+
shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
3020+
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
3021+
3022+
shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
3023+
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
3024+
3025+
shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
3026+
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
3027+
29733028
sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
29743029
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
29753030

0 commit comments

Comments
 (0)