Skip to content

Commit 802cd0d

Browse files
committed
[mlir][emitc] Support dense as init value of constantOp
1 parent d196124 commit 802cd0d

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,8 +1362,9 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
13621362
}
13631363
}
13641364
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1365+
assert(isa<ShapedType>(dense.getType()) && "Expected shaped type");
13651366
if (auto iType = dyn_cast<IntegerType>(
1366-
cast<TensorType>(dense.getType()).getElementType())) {
1367+
cast<ShapedType>(dense.getType()).getElementType())) {
13671368
os << '{';
13681369
interleaveComma(dense, os, [&](const APInt &val) {
13691370
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -1372,7 +1373,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
13721373
return success();
13731374
}
13741375
if (auto iType = dyn_cast<IndexType>(
1375-
cast<TensorType>(dense.getType()).getElementType())) {
1376+
cast<ShapedType>(dense.getType()).getElementType())) {
13761377
os << '{';
13771378
interleaveComma(dense, os,
13781379
[&](const APInt &val) { printInt(val, false); });

mlir/test/Target/Cpp/const.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ func.func @emitc_constant() {
1616
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
1717
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
1818
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
19+
%c11 = "emitc.constant"(){value = dense<[0, 1]> : !emitc.array<2xindex>} : () -> !emitc.array<2xindex>
20+
%c12 = "emitc.constant"(){value = dense<[0.0, 1.0]> : !emitc.array<2xf32>} : () -> !emitc.array<2xf32>
1921
return
2022
}
2123
// CPP-DEFAULT: void emitc_constant() {
@@ -33,6 +35,8 @@ func.func @emitc_constant() {
3335
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
3436
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
3537
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
38+
// CPP-DEFAULT-NEXT: size_t [[V11:[^ ]*]][2] = {0, 1};
39+
// CPP-DEFAULT-NEXT: float [[V12:[^ ]*]][2] = {0.0e+00f, 1.000000000e+00f};
3640

3741
// CPP-DECLTOP: void emitc_constant() {
3842
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -49,6 +53,8 @@ func.func @emitc_constant() {
4953
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
5054
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
5155
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
56+
// CPP-DECLTOP-NEXT: size_t [[V11:[^ ]*]][2];
57+
// CPP-DECLTOP-NEXT: float [[V12:[^ ]*]][2];
5258
// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX;
5359
// CPP-DECLTOP-NEXT: [[V1]] = 42;
5460
// CPP-DECLTOP-NEXT: [[V2]] = -1;
@@ -63,3 +69,5 @@ func.func @emitc_constant() {
6369
// CPP-DECLTOP-NEXT: [[V8]] = {0};
6470
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
6571
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
72+
// CPP-DECLTOP-NEXT: [[V11]] = {0, 1};
73+
// CPP-DECLTOP-NEXT: [[V12]] = {0.0e+00f, 1.000000000e+00f};

0 commit comments

Comments
 (0)