Skip to content

Commit f708944

Browse files
committed
Add all 4 transpose variations of SBGEMM (NN, TN, NT, TT) to compare_sgemm_sbgemm
1 parent cb15483 commit f708944

File tree

1 file changed

+58
-34
lines changed

1 file changed

+58
-34
lines changed

test/compare_sgemm_sbgemm.c

Lines changed: 58 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -86,14 +86,26 @@ main (int argc, char *argv[])
8686
{
8787
blasint m, n, k;
8888
int i, j, l;
89-
blasint x;
89+
blasint x, y;
9090
int ret = 0;
9191
int loop = 100;
9292
char transA = 'N', transB = 'N';
9393
float alpha = 1.0, beta = 0.0;
9494

9595
for (x = 0; x <= loop; x++)
96+
{
97+
for (y = 0; y < 4; y++)
9698
{
99+
if ((y == 0) || (y == 2)) {
100+
transA = 'N';
101+
} else {
102+
transA = 'T';
103+
}
104+
if ((y == 0) || (y == 1)) {
105+
transB = 'N';
106+
} else {
107+
transB = 'T';
108+
}
97109
m = k = n = x;
98110
float A[m * k];
99111
float B[k * n];
@@ -104,43 +116,55 @@ main (int argc, char *argv[])
104116
blasint one=1;
105117

106118
for (j = 0; j < m; j++)
107-
{
108-
for (i = 0; i < m; i++)
109-
{
110-
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
111-
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
112-
C[j * k + i] = 0;
113-
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
114-
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
115-
AA[j * k + i].v = atmp;
116-
BB[j * k + i].v = btmp;
117-
CC[j * k + i] = 0;
118-
DD[j * k + i] = 0;
119-
}
120-
}
119+
{
120+
for (i = 0; i < m; i++)
121+
{
122+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
123+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
124+
C[j * k + i] = 0;
125+
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
126+
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
127+
AA[j * k + i].v = atmp;
128+
BB[j * k + i].v = btmp;
129+
CC[j * k + i] = 0;
130+
DD[j * k + i] = 0;
131+
}
132+
}
121133
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
122-
&m, B, &k, &beta, C, &m);
134+
&m, B, &k, &beta, C, &m);
123135
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA,
124-
&m, (bfloat16*)BB, &k, &beta, CC, &m);
136+
&m, (bfloat16*)BB, &k, &beta, CC, &m);
137+
for (i = 0; i < n; i++)
138+
for (j = 0; j < m; j++)
139+
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
140+
ret++;
125141
for (i = 0; i < n; i++)
126-
for (j = 0; j < m; j++)
127-
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
128-
ret++;
129-
if (transA == 'N' && transB == 'N')
130-
{
131-
for (i = 0; i < n; i++)
132-
for (j = 0; j < m; j++)
133-
for (l = 0; l < k; l++)
134-
{
135-
DD[i * m + j] +=
136-
float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]);
137-
}
138-
for (i = 0; i < n; i++)
139-
for (j = 0; j < m; j++)
140-
if (CC[i * m + j] != DD[i * m + j])
141-
ret++;
142-
}
142+
for (j = 0; j < m; j++)
143+
for (l = 0; l < k; l++)
144+
if (transA == 'N' && transB == 'N')
145+
{
146+
DD[i * m + j] +=
147+
float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]);
148+
} else if (transA == 'T' && transB == 'N')
149+
{
150+
DD[i * m + j] +=
151+
float16to32 (AA[k * j + l]) * float16to32 (BB[l + k * i]);
152+
} else if (transA == 'N' && transB == 'T')
153+
{
154+
DD[i * m + j] +=
155+
float16to32 (AA[l * m + j]) * float16to32 (BB[i + l * n]);
156+
} else if (transA == 'T' && transB == 'T')
157+
{
158+
DD[i * m + j] +=
159+
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]);
160+
}
161+
for (i = 0; i < n; i++)
162+
for (j = 0; j < m; j++)
163+
if (CC[i * m + j] != DD[i * m + j])
164+
ret++;
143165
}
166+
}
167+
144168
if (ret != 0)
145169
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
146170
return ret;

0 commit comments

Comments
 (0)