Skip to content

Commit 2c07aa7

Browse files
committed
[mlir][emitc] Support dense as init value of constantOp
1 parent 7468718 commit 2c07aa7

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,16 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
13281328
}
13291329
};
13301330

1331+
auto getDenseElementType = [&](const Type &type) {
1332+
if (auto arrayType = dyn_cast<ArrayType>(type)) {
1333+
return arrayType.getElementType();
1334+
}
1335+
if (auto tensorType = dyn_cast<TensorType>(type)) {
1336+
return tensorType.getElementType();
1337+
}
1338+
return Type();
1339+
};
1340+
13311341
// Print floating point attributes.
13321342
if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
13331343
if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
@@ -1362,17 +1372,17 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
13621372
}
13631373
}
13641374
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1365-
if (auto iType = dyn_cast<IntegerType>(
1366-
cast<TensorType>(dense.getType()).getElementType())) {
1375+
if (auto iType =
1376+
dyn_cast<IntegerType>(getDenseElementType(dense.getType()))) {
13671377
os << '{';
13681378
interleaveComma(dense, os, [&](const APInt &val) {
13691379
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
13701380
});
13711381
os << '}';
13721382
return success();
13731383
}
1374-
if (auto iType = dyn_cast<IndexType>(
1375-
cast<TensorType>(dense.getType()).getElementType())) {
1384+
if (auto iType =
1385+
dyn_cast<IndexType>(getDenseElementType(dense.getType()))) {
13761386
os << '{';
13771387
interleaveComma(dense, os,
13781388
[&](const APInt &val) { printInt(val, false); });

mlir/test/Target/Cpp/const.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ func.func @emitc_constant() {
1616
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
1717
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
1818
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
19+
%c11 = "emitc.constant"(){value = dense<[0, 1]> : !emitc.array<2xindex>} : () -> !emitc.array<2xindex>
20+
%c12 = "emitc.constant"(){value = dense<[0.0, 1.0]> : !emitc.array<2xf32>} : () -> !emitc.array<2xf32>
1921
return
2022
}
2123
// CPP-DEFAULT: void emitc_constant() {
@@ -33,6 +35,8 @@ func.func @emitc_constant() {
3335
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
3436
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
3537
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
38+
// CPP-DEFAULT-NEXT: size_t [[V11:[^ ]*]][2] = {0, 1};
39+
// CPP-DEFAULT-NEXT: float [[V12:[^ ]*]][2] = {0.0e+00f, 1.000000000e+00f};
3640

3741
// CPP-DECLTOP: void emitc_constant() {
3842
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -49,6 +53,8 @@ func.func @emitc_constant() {
4953
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
5054
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
5155
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
56+
// CPP-DECLTOP-NEXT: size_t [[V11:[^ ]*]][2];
57+
// CPP-DECLTOP-NEXT: float [[V12:[^ ]*]][2];
5258
// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX;
5359
// CPP-DECLTOP-NEXT: [[V1]] = 42;
5460
// CPP-DECLTOP-NEXT: [[V2]] = -1;
@@ -63,3 +69,5 @@ func.func @emitc_constant() {
6369
// CPP-DECLTOP-NEXT: [[V8]] = {0};
6470
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
6571
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
72+
// CPP-DECLTOP-NEXT: [[V11]] = {0, 1};
73+
// CPP-DECLTOP-NEXT: [[V12]] = {0.0e+00f, 1.000000000e+00f};

0 commit comments

Comments
 (0)