@@ -86,14 +86,26 @@ main (int argc, char *argv[])
86
86
{
87
87
blasint m , n , k ;
88
88
int i , j , l ;
89
- blasint x ;
89
+ blasint x , y ;
90
90
int ret = 0 ;
91
91
int loop = 100 ;
92
92
char transA = 'N' , transB = 'N' ;
93
93
float alpha = 1.0 , beta = 0.0 ;
94
94
95
95
for (x = 0 ; x <= loop ; x ++ )
96
+ {
97
+ for (y = 0 ; y < 4 ; y ++ )
96
98
{
99
+ if ((y == 0 ) || (y == 2 )) {
100
+ transA = 'N' ;
101
+ } else {
102
+ transA = 'T' ;
103
+ }
104
+ if ((y == 0 ) || (y == 1 )) {
105
+ transB = 'N' ;
106
+ } else {
107
+ transB = 'T' ;
108
+ }
97
109
m = k = n = x ;
98
110
float A [m * k ];
99
111
float B [k * n ];
@@ -104,43 +116,55 @@ main (int argc, char *argv[])
104
116
blasint one = 1 ;
105
117
106
118
for (j = 0 ; j < m ; j ++ )
107
- {
108
- for (i = 0 ; i < m ; i ++ )
109
- {
110
- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
111
- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
112
- C [j * k + i ] = 0 ;
113
- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
114
- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
115
- AA [j * k + i ].v = atmp ;
116
- BB [j * k + i ].v = btmp ;
117
- CC [j * k + i ] = 0 ;
118
- DD [j * k + i ] = 0 ;
119
- }
120
- }
119
+ {
120
+ for (i = 0 ; i < m ; i ++ )
121
+ {
122
+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
123
+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
124
+ C [j * k + i ] = 0 ;
125
+ sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
126
+ sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
127
+ AA [j * k + i ].v = atmp ;
128
+ BB [j * k + i ].v = btmp ;
129
+ CC [j * k + i ] = 0 ;
130
+ DD [j * k + i ] = 0 ;
131
+ }
132
+ }
121
133
SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
122
- & m , B , & k , & beta , C , & m );
134
+ & m , B , & k , & beta , C , & m );
123
135
SBGEMM (& transA , & transB , & m , & n , & k , & alpha , (bfloat16 * ) AA ,
124
- & m , (bfloat16 * )BB , & k , & beta , CC , & m );
136
+ & m , (bfloat16 * )BB , & k , & beta , CC , & m );
137
+ for (i = 0 ; i < n ; i ++ )
138
+ for (j = 0 ; j < m ; j ++ )
139
+ if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
140
+ ret ++ ;
125
141
for (i = 0 ; i < n ; i ++ )
126
- for (j = 0 ; j < m ; j ++ )
127
- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
128
- ret ++ ;
129
- if (transA == 'N' && transB == 'N' )
130
- {
131
- for (i = 0 ; i < n ; i ++ )
132
- for (j = 0 ; j < m ; j ++ )
133
- for (l = 0 ; l < k ; l ++ )
134
- {
135
- DD [i * m + j ] +=
136
- float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
137
- }
138
- for (i = 0 ; i < n ; i ++ )
139
- for (j = 0 ; j < m ; j ++ )
140
- if (CC [i * m + j ] != DD [i * m + j ])
141
- ret ++ ;
142
- }
142
+ for (j = 0 ; j < m ; j ++ )
143
+ for (l = 0 ; l < k ; l ++ )
144
+ if (transA == 'N' && transB == 'N' )
145
+ {
146
+ DD [i * m + j ] +=
147
+ float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
148
+ } else if (transA == 'T' && transB == 'N' )
149
+ {
150
+ DD [i * m + j ] +=
151
+ float16to32 (AA [k * j + l ]) * float16to32 (BB [l + k * i ]);
152
+ } else if (transA == 'N' && transB == 'T' )
153
+ {
154
+ DD [i * m + j ] +=
155
+ float16to32 (AA [l * m + j ]) * float16to32 (BB [i + l * n ]);
156
+ } else if (transA == 'T' && transB == 'T' )
157
+ {
158
+ DD [i * m + j ] +=
159
+ float16to32 (AA [k * j + l ]) * float16to32 (BB [i + l * n ]);
160
+ }
161
+ for (i = 0 ; i < n ; i ++ )
162
+ for (j = 0 ; j < m ; j ++ )
163
+ if (CC [i * m + j ] != DD [i * m + j ])
164
+ ret ++ ;
143
165
}
166
+ }
167
+
144
168
if (ret != 0 )
145
169
fprintf (stderr , "FATAL ERROR SBGEMM - Return code: %d\n" , ret );
146
170
return ret ;
0 commit comments