Skip to content

Commit 25c218b

Browse files
ashaysilvasean
authored andcommitted
[MLIR] Add function to create BFloat16 array attribute
This patch adds a new function `mlirDenseElementsAttrBFloat16Get()`, which accepts the shaped type, the number of BFloat16 values, and a pointer to an array of BFloat16 values, each of which is a `uint16_t` value. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D123981
1 parent 0f8c626 commit 25c218b

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloatGet(
379379
MlirType shapedType, intptr_t numElements, const float *elements);
380380
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet(
381381
MlirType shapedType, intptr_t numElements, const double *elements);
382+
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get(
383+
MlirType shapedType, intptr_t numElements, const uint16_t *elements);
382384

383385
/// Creates a dense elements attribute with the given shaped type from string
384386
/// elements.

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,13 @@ MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
474474
const double *elements) {
475475
return getDenseAttribute(shapedType, numElements, elements);
476476
}
477+
MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
478+
intptr_t numElements,
479+
const uint16_t *elements) {
480+
size_t bufferSize = numElements * 2;
481+
const void *buffer = static_cast<const void *>(elements);
482+
return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
483+
}
477484

478485
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
479486
intptr_t numElements,

mlir/test/CAPI/ir.c

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,7 @@ int printBuiltinAttributes(MlirContext ctx) {
936936
int64_t ints64[] = {0, 1};
937937
float floats[] = {0.0f, 1.0f};
938938
double doubles[] = {0.0, 1.0};
939+
uint16_t bf16s[] = {0x0, 0x3f80};
939940
MlirAttribute encoding = mlirAttributeGetNull();
940941
MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
941942
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -974,6 +975,9 @@ int printBuiltinAttributes(MlirContext ctx) {
974975
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
975976
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
976977
doubles);
978+
MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
979+
mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
980+
bf16s);
977981

978982
if (!mlirAttributeIsADenseElements(boolElements) ||
979983
!mlirAttributeIsADenseElements(uint8Elements) ||
@@ -983,7 +987,8 @@ int printBuiltinAttributes(MlirContext ctx) {
983987
!mlirAttributeIsADenseElements(uint64Elements) ||
984988
!mlirAttributeIsADenseElements(int64Elements) ||
985989
!mlirAttributeIsADenseElements(floatElements) ||
986-
!mlirAttributeIsADenseElements(doubleElements))
990+
!mlirAttributeIsADenseElements(doubleElements) ||
991+
!mlirAttributeIsADenseElements(bf16Elements))
987992
return 14;
988993

989994
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
@@ -1009,6 +1014,7 @@ int printBuiltinAttributes(MlirContext ctx) {
10091014
mlirAttributeDump(int64Elements);
10101015
mlirAttributeDump(floatElements);
10111016
mlirAttributeDump(doubleElements);
1017+
mlirAttributeDump(bf16Elements);
10121018
// CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
10131019
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
10141020
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
@@ -1018,6 +1024,7 @@ int printBuiltinAttributes(MlirContext ctx) {
10181024
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
10191025
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
10201026
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
1027+
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
10211028

10221029
MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
10231030
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -1094,12 +1101,15 @@ int printBuiltinAttributes(MlirContext ctx) {
10941101
float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
10951102
double *doubleRawData =
10961103
(double *)mlirDenseElementsAttrGetRawData(doubleElements);
1104+
uint16_t *bf16RawData =
1105+
(uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
10971106
if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
10981107
int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
10991108
int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
11001109
uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
11011110
floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
1102-
doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
1111+
doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
1112+
bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80)
11031113
return 18;
11041114

11051115
mlirAttributeDump(splatBool);
@@ -1123,8 +1133,10 @@ int printBuiltinAttributes(MlirContext ctx) {
11231133

11241134
mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
11251135
mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
1136+
mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
11261137
// CHECK: 1.000000e+00 : f32
11271138
// CHECK: 1.000000e+00 : f64
1139+
// CHECK: 1.000000e+00 : bf16
11281140

11291141
int64_t indices[] = {0, 1};
11301142
int64_t one = 1;

0 commit comments

Comments
 (0)