@@ -936,6 +936,7 @@ int printBuiltinAttributes(MlirContext ctx) {
936
936
int64_t ints64 [] = {0 , 1 };
937
937
float floats [] = {0.0f , 1.0f };
938
938
double doubles [] = {0.0 , 1.0 };
939
+ uint16_t bf16s [] = {0x0 , 0x3f80 };
939
940
MlirAttribute encoding = mlirAttributeGetNull ();
940
941
MlirAttribute boolElements = mlirDenseElementsAttrBoolGet (
941
942
mlirRankedTensorTypeGet (2 , shape , mlirIntegerTypeGet (ctx , 1 ), encoding ),
@@ -974,6 +975,9 @@ int printBuiltinAttributes(MlirContext ctx) {
974
975
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet (
975
976
mlirRankedTensorTypeGet (2 , shape , mlirF64TypeGet (ctx ), encoding ), 2 ,
976
977
doubles );
978
+ MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get (
979
+ mlirRankedTensorTypeGet (2 , shape , mlirBF16TypeGet (ctx ), encoding ), 2 ,
980
+ bf16s );
977
981
978
982
if (!mlirAttributeIsADenseElements (boolElements ) ||
979
983
!mlirAttributeIsADenseElements (uint8Elements ) ||
@@ -983,7 +987,8 @@ int printBuiltinAttributes(MlirContext ctx) {
983
987
!mlirAttributeIsADenseElements (uint64Elements ) ||
984
988
!mlirAttributeIsADenseElements (int64Elements ) ||
985
989
!mlirAttributeIsADenseElements (floatElements ) ||
986
- !mlirAttributeIsADenseElements (doubleElements ))
990
+ !mlirAttributeIsADenseElements (doubleElements ) ||
991
+ !mlirAttributeIsADenseElements (bf16Elements ))
987
992
return 14 ;
988
993
989
994
if (mlirDenseElementsAttrGetBoolValue (boolElements , 1 ) != 1 ||
@@ -1009,6 +1014,7 @@ int printBuiltinAttributes(MlirContext ctx) {
1009
1014
mlirAttributeDump (int64Elements );
1010
1015
mlirAttributeDump (floatElements );
1011
1016
mlirAttributeDump (doubleElements );
1017
+ mlirAttributeDump (bf16Elements );
1012
1018
// CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
1013
1019
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
1014
1020
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
@@ -1018,6 +1024,7 @@ int printBuiltinAttributes(MlirContext ctx) {
1018
1024
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
1019
1025
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
1020
1026
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
1027
+ // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
1021
1028
1022
1029
MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet (
1023
1030
mlirRankedTensorTypeGet (2 , shape , mlirIntegerTypeGet (ctx , 1 ), encoding ),
@@ -1094,12 +1101,15 @@ int printBuiltinAttributes(MlirContext ctx) {
1094
1101
float * floatRawData = (float * )mlirDenseElementsAttrGetRawData (floatElements );
1095
1102
double * doubleRawData =
1096
1103
(double * )mlirDenseElementsAttrGetRawData (doubleElements );
1104
+ uint16_t * bf16RawData =
1105
+ (uint16_t * )mlirDenseElementsAttrGetRawData (bf16Elements );
1097
1106
if (uint8RawData [0 ] != 0u || uint8RawData [1 ] != 1u || int8RawData [0 ] != 0 ||
1098
1107
int8RawData [1 ] != 1 || uint32RawData [0 ] != 0u || uint32RawData [1 ] != 1u ||
1099
1108
int32RawData [0 ] != 0 || int32RawData [1 ] != 1 || uint64RawData [0 ] != 0u ||
1100
1109
uint64RawData [1 ] != 1u || int64RawData [0 ] != 0 || int64RawData [1 ] != 1 ||
1101
1110
floatRawData [0 ] != 0.0f || floatRawData [1 ] != 1.0f ||
1102
- doubleRawData [0 ] != 0.0 || doubleRawData [1 ] != 1.0 )
1111
+ doubleRawData [0 ] != 0.0 || doubleRawData [1 ] != 1.0 ||
1112
+ bf16RawData [0 ] != 0 || bf16RawData [1 ] != 0x3f80 )
1103
1113
return 18 ;
1104
1114
1105
1115
mlirAttributeDump (splatBool );
@@ -1123,8 +1133,10 @@ int printBuiltinAttributes(MlirContext ctx) {
1123
1133
1124
1134
mlirAttributeDump (mlirElementsAttrGetValue (floatElements , 2 , uints64 ));
1125
1135
mlirAttributeDump (mlirElementsAttrGetValue (doubleElements , 2 , uints64 ));
1136
+ mlirAttributeDump (mlirElementsAttrGetValue (bf16Elements , 2 , uints64 ));
1126
1137
// CHECK: 1.000000e+00 : f32
1127
1138
// CHECK: 1.000000e+00 : f64
1139
+ // CHECK: 1.000000e+00 : bf16
1128
1140
1129
1141
int64_t indices [] = {0 , 1 };
1130
1142
int64_t one = 1 ;
0 commit comments