Skip to content

Commit 308d8b8

Browse files
Rahul Kayaithjoker-eph
authored andcommitted
[mlir][python] 8b/16b DenseIntElements access
This extends dense attribute element access to support 8b and 16b ints. Also extends the corresponding parts of the C api. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117731
1 parent 26167ca commit 308d8b8

File tree

5 files changed

+91
-0
lines changed

5 files changed

+91
-0
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
355355
MlirType shapedType, intptr_t numElements, const uint8_t *elements);
356356
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
357357
MlirType shapedType, intptr_t numElements, const int8_t *elements);
358+
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get(
359+
MlirType shapedType, intptr_t numElements, const uint16_t *elements);
360+
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get(
361+
MlirType shapedType, intptr_t numElements, const int16_t *elements);
358362
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
359363
MlirType shapedType, intptr_t numElements, const uint32_t *elements);
360364
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -416,6 +420,10 @@ MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
416420
intptr_t pos);
417421
MLIR_CAPI_EXPORTED uint8_t
418422
mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
423+
MLIR_CAPI_EXPORTED int16_t
424+
mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos);
425+
MLIR_CAPI_EXPORTED uint16_t
426+
mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos);
419427
MLIR_CAPI_EXPORTED int32_t
420428
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
421429
MLIR_CAPI_EXPORTED uint32_t

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,12 @@ class PyDenseIntElementsAttribute
673673
if (width == 1) {
674674
return mlirDenseElementsAttrGetBoolValue(*this, pos);
675675
}
676+
if (width == 8) {
677+
return mlirDenseElementsAttrGetUInt8Value(*this, pos);
678+
}
679+
if (width == 16) {
680+
return mlirDenseElementsAttrGetUInt16Value(*this, pos);
681+
}
676682
if (width == 32) {
677683
return mlirDenseElementsAttrGetUInt32Value(*this, pos);
678684
}
@@ -683,6 +689,12 @@ class PyDenseIntElementsAttribute
683689
if (width == 1) {
684690
return mlirDenseElementsAttrGetBoolValue(*this, pos);
685691
}
692+
if (width == 8) {
693+
return mlirDenseElementsAttrGetInt8Value(*this, pos);
694+
}
695+
if (width == 16) {
696+
return mlirDenseElementsAttrGetInt16Value(*this, pos);
697+
}
686698
if (width == 32) {
687699
return mlirDenseElementsAttrGetInt32Value(*this, pos);
688700
}

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,16 @@ MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
426426
const int8_t *elements) {
427427
return getDenseAttribute(shapedType, numElements, elements);
428428
}
429+
MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
430+
intptr_t numElements,
431+
const uint16_t *elements) {
432+
return getDenseAttribute(shapedType, numElements, elements);
433+
}
434+
MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
435+
intptr_t numElements,
436+
const int16_t *elements) {
437+
return getDenseAttribute(shapedType, numElements, elements);
438+
}
429439
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
430440
intptr_t numElements,
431441
const uint32_t *elements) {
@@ -530,6 +540,12 @@ int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
530540
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
531541
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
532542
}
543+
int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
544+
return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
545+
}
546+
uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
547+
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
548+
}
533549
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
534550
return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
535551
}

mlir/test/CAPI/ir.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,8 @@ int printBuiltinAttributes(MlirContext ctx) {
904904
int bools[] = {0, 1};
905905
uint8_t uints8[] = {0u, 1u};
906906
int8_t ints8[] = {0, 1};
907+
uint16_t uints16[] = {0u, 1u};
908+
int16_t ints16[] = {0, 1};
907909
uint32_t uints32[] = {0u, 1u};
908910
int32_t ints32[] = {0, 1};
909911
uint64_t uints64[] = {0u, 1u};
@@ -921,6 +923,13 @@ int printBuiltinAttributes(MlirContext ctx) {
921923
MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
922924
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
923925
2, ints8);
926+
MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
927+
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
928+
encoding),
929+
2, uints16);
930+
MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
931+
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
932+
2, ints16);
924933
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
925934
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
926935
encoding),
@@ -956,6 +965,8 @@ int printBuiltinAttributes(MlirContext ctx) {
956965
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
957966
mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
958967
mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
968+
mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
969+
mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
959970
mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
960971
mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
961972
mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||

mlir/test/python/ir/attributes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,50 @@ def testDenseIntAttr():
292292
print(ShapedType(a.type).element_type)
293293

294294

295+
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
296+
@run
297+
def testDenseIntAttrGetItem():
298+
def print_item(attr_asm):
299+
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
300+
dtype = ShapedType(attr.type).element_type
301+
try:
302+
item = attr[0]
303+
print(f"{dtype}:", item)
304+
except TypeError as e:
305+
print(f"{dtype}:", e)
306+
307+
with Context():
308+
# CHECK: i1: 1
309+
print_item("dense<true> : tensor<i1>")
310+
# CHECK: i8: 123
311+
print_item("dense<123> : tensor<i8>")
312+
# CHECK: i16: 123
313+
print_item("dense<123> : tensor<i16>")
314+
# CHECK: i32: 123
315+
print_item("dense<123> : tensor<i32>")
316+
# CHECK: i64: 123
317+
print_item("dense<123> : tensor<i64>")
318+
# CHECK: ui8: 123
319+
print_item("dense<123> : tensor<ui8>")
320+
# CHECK: ui16: 123
321+
print_item("dense<123> : tensor<ui16>")
322+
# CHECK: ui32: 123
323+
print_item("dense<123> : tensor<ui32>")
324+
# CHECK: ui64: 123
325+
print_item("dense<123> : tensor<ui64>")
326+
# CHECK: si8: -123
327+
print_item("dense<-123> : tensor<si8>")
328+
# CHECK: si16: -123
329+
print_item("dense<-123> : tensor<si16>")
330+
# CHECK: si32: -123
331+
print_item("dense<-123> : tensor<si32>")
332+
# CHECK: si64: -123
333+
print_item("dense<-123> : tensor<si64>")
334+
335+
# CHECK: i7: Unsupported integer type
336+
print_item("dense<123> : tensor<i7>")
337+
338+
295339
# CHECK-LABEL: TEST: testDenseFPAttr
296340
@run
297341
def testDenseFPAttr():

0 commit comments

Comments
 (0)