@@ -33,11 +33,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
33
#include <arm_neon.h>
34
34
#include "common.h"
35
35
36
- static inline float bf16_to_fp32 (bfloat16 bf16 ) {
37
- uint32_t fp32 = (uint32_t )bf16 << 16 ;
38
- return * ((float * )& fp32 );
39
- }
40
-
41
36
int CNAME (BLASLONG m , BLASLONG n , float alpha , bfloat16 * a , BLASLONG lda , bfloat16 * x , BLASLONG incx , float beta , float * y , BLASLONG incy )
42
37
{
43
38
if (m < 1 || n < 1 ) return (0 );
@@ -132,10 +127,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
132
127
}
133
128
134
129
for (; i < m ; ++ i ) {
135
- y0_ptr [iy ] += alpha * a0_ptr [i ] * x_ptr [i ];
136
- y1_ptr [iy ] += alpha * a1_ptr [i ] * x_ptr [i ];
137
- y2_ptr [iy ] += alpha * a2_ptr [i ] * x_ptr [i ];
138
- y3_ptr [iy ] += alpha * a3_ptr [i ] * x_ptr [i ];
130
+ y0_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a0_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
131
+ y1_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a1_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
132
+ y2_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a2_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
133
+ y3_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a3_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
139
134
}
140
135
141
136
iy += incy ;
@@ -177,7 +172,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
177
172
}
178
173
179
174
for (; i < m ; ++ i ) {
180
- y_ptr [iy ] += alpha * a_ptr [i ] * x_ptr [i ];
175
+ y_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
181
176
}
182
177
183
178
iy += incy ;
@@ -191,7 +186,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
191
186
temp = 0.0 ;
192
187
ix = 0 ;
193
188
for (i = 0 ; i < m ; i ++ ) {
194
- temp += bf16_to_fp32 ( a [i ]) * bf16_to_fp32 ( x [ix ]);
189
+ temp += vcvtah_f32_bf16 ( a_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [ix ]);
195
190
ix += incx ;
196
191
}
197
192
if (beta == 0.0f ) {
0 commit comments