Skip to content

Commit 0ec3e06

Browse files
Mousius and taoye9 committed
Use correct constants for per-target BGEMM/SBGEMM
This fixes the build and tests on the `NEOVERSEV1` target, which was failing with the `SBGEMM`-specific constants.

Co-authored-by: Ye Tao <ye.tao@arm.com>
1 parent 9dfd48c commit 0ec3e06

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

driver/level3/level3.c

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,22 @@
169169
#define STOP_RPCC(COUNTER)
170170
#endif
171171

172+
#if defined(BUILD_BFLOAT16)
173+
#if defined(DYNAMIC_ARCH)
174+
#if defined(BGEMM)
175+
#define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k
176+
#else
177+
#define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k
178+
#endif
179+
#else
180+
#if defined(BGEMM)
181+
#define BFLOAT16_ALIGN_K BGEMM_ALIGN_K
182+
#else
183+
#define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K
184+
#endif
185+
#endif
186+
#endif
187+
172188
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
173189
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
174190
BLASLONG k, lda, ldb, ldc;
@@ -305,12 +321,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
305321
}
306322

307323
BLASLONG pad_min_l = min_l;
308-
#if defined(HALF)
309-
#if defined(DYNAMIC_ARCH)
310-
pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1);
311-
#else
312-
pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);;
313-
#endif
324+
#if defined(BFLOAT16)
325+
pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1);
314326
#endif
315327

316328
/* First, we have to move data A to L2 cache */

driver/level3/level3_thread.c

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -216,6 +216,22 @@ typedef struct {
216216
#define STOP_RPCC(COUNTER)
217217
#endif
218218

219+
#if defined(BUILD_BFLOAT16)
220+
#if defined(DYNAMIC_ARCH)
221+
#if defined(BGEMM)
222+
#define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k
223+
#else
224+
#define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k
225+
#endif
226+
#else
227+
#if defined(BGEMM)
228+
#define BFLOAT16_ALIGN_K BGEMM_ALIGN_K
229+
#else
230+
#define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K
231+
#endif
232+
#endif
233+
#endif
234+
219235
static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
220236

221237
IFLOAT *buffer[DIVIDE_RATE];
@@ -324,12 +340,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
324340

325341
BLASLONG pad_min_l = min_l;
326342

327-
#if defined(HALF)
328-
#if defined(DYNAMIC_ARCH)
329-
pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1);
330-
#else
331-
pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);;
332-
#endif
343+
#if defined(BFLOAT16)
344+
pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1);
333345
#endif
334346

335347
/* Determine step size in m

0 commit comments

Comments
 (0)