Skip to content

Commit 38d7a7b

Browse files
authored
Fix ?GEMMT
1 parent 4eac244 commit 38d7a7b

File tree

1 file changed

+50
-50
lines changed

1 file changed

+50
-50
lines changed

interface/gemmt.c

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,26 @@
3535
#include <stdio.h>
3636
#include <stdlib.h>
3737
#include "common.h"
38-
#ifdef FUNCTION_PROFILE
39-
#include "functable.h"
40-
#endif
4138

4239
#ifndef COMPLEX
4340
#define SMP_THRESHOLD_MIN 65536.0
4441
#ifdef XDOUBLE
45-
#define ERROR_NAME "QGEMT "
42+
#define ERROR_NAME "QGEMMT "
4643
#elif defined(DOUBLE)
47-
#define ERROR_NAME "DGEMT "
44+
#define ERROR_NAME "DGEMMT "
4845
#elif defined(BFLOAT16)
49-
#define ERROR_NAME "SBGEMT "
46+
#define ERROR_NAME "SBGEMMT "
5047
#else
51-
#define ERROR_NAME "SGEMT "
48+
#define ERROR_NAME "SGEMMT "
5249
#endif
5350
#else
5451
#define SMP_THRESHOLD_MIN 8192.0
5552
#ifdef XDOUBLE
56-
#define ERROR_NAME "XGEMT "
53+
#define ERROR_NAME "XGEMMT "
5754
#elif defined(DOUBLE)
58-
#define ERROR_NAME "ZGEMT "
55+
#define ERROR_NAME "ZGEMMT "
5956
#else
60-
#define ERROR_NAME "CGEMT "
57+
#define ERROR_NAME "CGEMMT "
6158
#endif
6259
#endif
6360

@@ -68,13 +65,13 @@
6865
#ifndef CBLAS
6966

7067
void NAME(char *UPLO, char *TRANSA, char *TRANSB,
71-
blasint * M, blasint * N, blasint * K,
68+
blasint * M, blasint * K,
7269
FLOAT * Alpha,
7370
IFLOAT * a, blasint * ldA,
7471
IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
7572
{
7673

77-
blasint m, n, k;
74+
blasint m, k;
7875
blasint lda, ldb, ldc;
7976
int transa, transb, uplo;
8077
blasint info;
@@ -92,7 +89,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
9289
PRINT_DEBUG_NAME;
9390

9491
m = *M;
95-
n = *N;
9692
k = *K;
9793

9894
#if defined(COMPLEX)
@@ -167,8 +163,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
167163
info = 13;
168164
if (k < 0)
169165
info = 5;
170-
if (n < 0)
171-
info = 4;
172166
if (m < 0)
173167
info = 3;
174168
if (transb < 0)
@@ -184,7 +178,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
184178

185179
void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
186180
enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M,
187-
blasint N, blasint k,
181+
blasint k,
188182
#ifndef COMPLEX
189183
FLOAT alpha,
190184
IFLOAT * A, blasint LDA,
@@ -205,7 +199,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
205199

206200
int transa, transb, uplo;
207201
blasint info;
208-
blasint m, n, lda, ldb;
202+
blasint m, lda, ldb;
209203
FLOAT *a, *b;
210204
XFLOAT *buffer;
211205

@@ -248,9 +242,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
248242
transb = 3;
249243
#endif
250244

251-
m = M;
252-
n = N;
253-
254245
a = (void *)A;
255246
b = (void *)B;
256247
lda = LDA;
@@ -262,8 +253,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
262253
info = 13;
263254
if (k < 0)
264255
info = 5;
265-
if (n < 0)
266-
info = 4;
267256
if (m < 0)
268257
info = 3;
269258
if (transb < 0)
@@ -273,8 +262,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
273262
}
274263

275264
if (order == CblasRowMajor) {
276-
m = N;
277-
n = M;
278265

279266
a = (void *)B;
280267
b = (void *)A;
@@ -319,8 +306,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
319306
info = 13;
320307
if (k < 0)
321308
info = 5;
322-
if (n < 0)
323-
info = 4;
324309
if (m < 0)
325310
info = 3;
326311
if (transb < 0)
@@ -407,37 +392,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
407392

408393
#endif
409394

410-
if ((m == 0) || (n == 0))
395+
if ((m == 0) )
411396
return;
412397

413398
IDEBUG_START;
414399

415-
FUNCTION_PROFILE_START();
416-
417400
const blasint incb = (transb == 0) ? 1 : ldb;
418401

419402
if (uplo == 1) {
420-
for (i = 0; i < n; i++) {
421-
j = n - i;
403+
for (i = 0; i < m; i++) {
404+
j = m - i;
422405

423406
l = j;
424407
#if defined(COMPLEX)
425408
aa = a + i * 2;
426409
bb = b + i * ldb * 2;
427410
if (transa) {
428-
l = k;
429411
aa = a + lda * i * 2;
430-
bb = b + i * 2;
431412
}
413+
if (transb)
414+
bb = b + i * 2;
432415
cc = c + i * 2 * ldc + i * 2;
433416
#else
434417
aa = a + i;
435418
bb = b + i * ldb;
436419
if (transa) {
437-
l = k;
438420
aa = a + lda * i;
439-
bb = b + i;
440421
}
422+
if (transb)
423+
bb = b + i;
441424
cc = c + i * ldc + i;
442425
#endif
443426

@@ -458,8 +441,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
458441

459442
IDEBUG_START;
460443

461-
FUNCTION_PROFILE_START();
462-
463444
buffer_size = j + k + 128 / sizeof(FLOAT);
464445
#ifdef WINDOWS_ABI
465446
buffer_size += 160 / sizeof(FLOAT);
@@ -479,20 +460,34 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
479460
#endif
480461

481462
#if defined(COMPLEX)
463+
if (!transa)
482464
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
483465
aa, lda, bb, incb, cc, 1,
484466
buffer);
467+
else
468+
(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i,
469+
aa, lda, bb, incb, cc, 1,
470+
buffer);
485471
#else
472+
if (!transa)
486473
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
487474
bb, incb, cc, 1, buffer);
475+
else
476+
(gemv[(int)transa]) (k, j, 0, alpha, aa, lda,
477+
bb, incb, cc, 1, buffer);
488478
#endif
489479
#ifdef SMP
490480
} else {
491-
481+
if (!transa)
492482
(gemv_thread[(int)transa]) (j, k, alpha, aa,
493483
lda, bb, incb, cc,
494484
1, buffer,
495485
nthreads);
486+
else
487+
(gemv_thread[(int)transa]) (k, j, alpha, aa,
488+
lda, bb, incb, cc,
489+
1, buffer,
490+
nthreads);
496491

497492
}
498493
#endif
@@ -501,21 +496,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
501496
}
502497
} else {
503498

504-
for (i = 0; i < n; i++) {
499+
for (i = 0; i < m; i++) {
505500
j = i + 1;
506501

507502
l = j;
508503
#if defined COMPLEX
509504
bb = b + i * ldb * 2;
510-
if (transa) {
511-
l = k;
505+
if (transb) {
512506
bb = b + i * 2;
513507
}
514508
cc = c + i * 2 * ldc;
515509
#else
516510
bb = b + i * ldb;
517-
if (transa) {
518-
l = k;
511+
if (transb) {
519512
bb = b + i;
520513
}
521514
cc = c + i * ldc;
@@ -537,8 +530,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
537530
#endif
538531
IDEBUG_START;
539532

540-
FUNCTION_PROFILE_START();
541-
542533
buffer_size = j + k + 128 / sizeof(FLOAT);
543534
#ifdef WINDOWS_ABI
544535
buffer_size += 160 / sizeof(FLOAT);
@@ -558,30 +549,39 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
558549
#endif
559550

560551
#if defined(COMPLEX)
552+
if (!transa)
561553
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
562554
a, lda, bb, incb, cc, 1,
563555
buffer);
556+
else
557+
(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i,
558+
a, lda, bb, incb, cc, 1,
559+
buffer);
564560
#else
561+
if (!transa)
565562
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
566563
incb, cc, 1, buffer);
564+
else
565+
(gemv[(int)transa]) (k, j, 0, alpha, a, lda, bb,
566+
incb, cc, 1, buffer);
567567
#endif
568568

569569
#ifdef SMP
570570
} else {
571-
571+
if (!transa)
572572
(gemv_thread[(int)transa]) (j, k, alpha, a, lda,
573573
bb, incb, cc, 1,
574574
buffer, nthreads);
575-
575+
else
576+
(gemv_thread[(int)transa]) (k, j, alpha, a, lda,
577+
bb, incb, cc, 1,
578+
buffer, nthreads);
576579
}
577580
#endif
578581

579582
STACK_FREE(buffer);
580583
}
581584
}
582-
FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE,
583-
args.m * args.k + args.k * args.n +
584-
args.m * args.n, 2 * args.m * args.n * args.k);
585585

586586
IDEBUG_END;
587587

0 commit comments

Comments
 (0)