@@ -81,6 +81,16 @@ float16to32 (bfloat16_bits f16)
81
81
return f32 .v ;
82
82
}
83
83
84
+ float
85
+ float32to16 (float32_bits f32 )
86
+ {
87
+ bfloat16_bits f16 ;
88
+ f16 .bits .s = f32 .bits .s ;
89
+ f16 .bits .e = f32 .bits .e ;
90
+ f16 .bits .m = (uint32_t ) f32 .bits .m >> 16 ;
91
+ return f32 .v ;
92
+ }
93
+
84
94
int
85
95
main (int argc , char * argv [])
86
96
{
@@ -108,16 +118,16 @@ main (int argc, char *argv[])
108
118
A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
109
119
B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
110
120
C [j * k + i ] = 0 ;
111
- AA [j * k + i ].v = * ( uint32_t * ) & A [j * k + i ] >> 16 ;
112
- BB [j * k + i ].v = * ( uint32_t * ) & B [j * k + i ] >> 16 ;
121
+ AA [j * k + i ].v = float32to16 ( A [j * k + i ] ) ;
122
+ BB [j * k + i ].v = float32to16 ( B [j * k + i ] ) ;
113
123
CC [j * k + i ] = 0 ;
114
124
DD [j * k + i ] = 0 ;
115
125
}
116
126
}
117
127
SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
118
128
& m , B , & k , & beta , C , & m );
119
- SBGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
120
- & m , BB , & k , & beta , CC , & m );
129
+ SBGEMM (& transA , & transB , & m , & n , & k , & alpha , ( bfloat16 * ) AA ,
130
+ & m , ( bfloat16 * ) BB , & k , & beta , CC , & m );
121
131
for (i = 0 ; i < n ; i ++ )
122
132
for (j = 0 ; j < m ; j ++ )
123
133
if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
0 commit comments