Skip to content

Commit 72564cf

Browse files
committed
feat: implement float support
1 parent c8e0894 commit 72564cf

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

GraphBLAS/FactoryKernels/GB_AxB__plus_times_fp32.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
#include "assign/GB_bitmap_assign_methods.h"
1818
#include "FactoryKernels/GB_AxB__include2.h"
1919

20+
// riscv intrinsics
21+
22+
#define VSETVL(x) __riscv_vsetvl_e32m8(x)
23+
#define VLE(x,y) __riscv_vle32_v_f32m8(x, y)
24+
#define VFMACC(x,y,z,w) __riscv_vfmacc_vf_f32m8(x, y, z, w)
25+
#define VSE(x,y,z) __riscv_vse32_v_f32m8(x, y, z)
26+
#define VECTORTYPE vfloat32m8_t
27+
2028
// semiring operators:
2129
#define GB_MULTADD(z,a,b,i,k,j) z += (a*b)
2230
#define GB_MULT(z,a,b,i,k,j) z = (a*b)

GraphBLAS/FactoryKernels/GB_AxB__plus_times_fp64.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
#include "assign/GB_bitmap_assign_methods.h"
1818
#include "FactoryKernels/GB_AxB__include2.h"
1919

20+
// riscv intrinsics
21+
22+
#define VSETVL(x) __riscv_vsetvl_e64m8(x)
23+
#define VLE(x,y) __riscv_vle64_v_f64m8(x, y)
24+
#define VFMACC(x,y,z,w) __riscv_vfmacc_vf_f64m8(x, y, z, w)
25+
#define VSE(x,y,z) __riscv_vse64_v_f64m8(x, y, z)
26+
#define VECTORTYPE vfloat64m8_t
27+
2028
// semiring operators:
2129
#define GB_MULTADD(z,a,b,i,k,j) z += (a*b)
2230
#define GB_MULT(z,a,b,i,k,j) z = (a*b)
@@ -289,7 +297,6 @@ GrB_Info GB (_Asaxpy4B__plus_times_fp64)
289297
//----------------------------------------------------------------------
290298
// saxpy5 method with RISC-V vectors
291299
//----------------------------------------------------------------------
292-
293300
#if GB_COMPILER_SUPPORTS_RVV1
294301

295302
GB_TARGET_RVV1 static inline void GB_AxB_saxpy5_unrolled_rvv

GraphBLAS/Source/mxm/template/GB_AxB_saxpy5_lv.c

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
const int64_t *restrict Bi = B->i;
66
const GB_A_TYPE *restrict Ax = (GB_A_TYPE *)A->x;
77
const GB_B_TYPE *restrict Bx = (GB_B_TYPE *)B->x;
8-
size_t vl = __riscv_vsetvl_e64m8(m);
8+
size_t vl = VSETVL(m);
99
GB_C_TYPE *restrict Cx = (GB_C_TYPE *)C->x;
1010

1111
#pragma omp parallel for num_threads(nthreads) schedule(dynamic, 1)
@@ -22,33 +22,31 @@
2222
const int64_t pB_end = Bp[jB + 1];
2323
for (int64_t i = 0; i < m && (m - i) >= vl; i += vl)
2424
{
25-
vfloat64m8_t vc = __riscv_vle64_v_f64m8(Cxj + i, vl);
26-
25+
VECTORTYPE vc = VLE(Cxj + i, vl);
2726
for (int64_t pB = pB_start; pB < pB_end; pB++)
2827
{
2928
const int64_t k = Bi[pB];
3029
const GB_B_TYPE bkj = Bx[pB];
31-
vfloat64m8_t va = __riscv_vle64_v_f64m8(Ax + i + k * m, vl);
32-
vc = __riscv_vfmacc_vf_f64m8(vc, bkj, va, vl);
30+
VECTORTYPE va = VLE(Ax + i + k * m, vl);
31+
vc = VFMACC(vc, bkj, va, vl);
3332
}
3433

35-
__riscv_vse64_v_f64m8(Cxj + i, vc, vl);
34+
VSE(Cxj + i, vc, vl);
3635
}
3736
int64_t remaining = m % vl;
3837
if (remaining > 0)
3938
{
4039
int64_t i = m - remaining;
41-
vfloat64m8_t vc = __riscv_vle64_v_f64m8(Cxj + i, remaining);
42-
40+
VECTORTYPE vc = VLE(Cxj + i, remaining);
4341
for (int64_t pB = pB_start; pB < pB_end; pB++)
4442
{
4543
const int64_t k = Bi[pB];
4644
const GB_B_TYPE bkj = Bx[pB];
47-
vfloat64m8_t va = __riscv_vle64_v_f64m8(Ax + i + k * m, remaining);
48-
vc = __riscv_vfmacc_vf_f64m8(vc, bkj, va, remaining);
45+
VECTORTYPE va = VLE(Ax + i + k * m, remaining);
46+
vc = VFMACC(vc, bkj, va, remaining);
4947
}
5048

51-
__riscv_vse64_v_f64m8(Cxj + i, vc, remaining);
49+
VSE(Cxj + i, vc, remaining);
5250
}
5351
}
5452
}

0 commit comments

Comments
 (0)