Skip to content

Commit 3eebd3f

Browse files
committed
[mlir][tosa] Interpret boolean values correctly in cast folder
Previously the cast folder would sign extend boolean values, leading "true" to be casted to a value of -1 instead of 1. This change ensures i1 values are zero extended, since i1 is used as a boolean value in TOSA. According to the TOSA spec, the result of a boolean cast with value "true" to another integer type should result in "1". Fixes #57951 Change-Id: I21486cae0b8ad1cf7901e7b3eb92c1ad56c3797c
1 parent b9b2661 commit 3eebd3f

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,15 +1301,17 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
13011301
}
13021302

13031303
if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1304-
auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1304+
const auto inIntType = llvm::cast<IntegerType>(inETy);
1305+
auto unsignIn = inIntType.isUnsignedInteger();
13051306
bool trunc =
13061307
inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
13071308
auto intVal = operand.getSplatValue<APInt>();
13081309
auto bitwidth = outETy.getIntOrFloatBitWidth();
13091310

13101311
if (trunc) {
13111312
intVal = intVal.trunc(bitwidth);
1312-
} else if (unsignIn) {
1313+
// i1 types are boolean in TOSA
1314+
} else if (unsignIn || inIntType.isInteger(1)) {
13131315
intVal = intVal.zext(bitwidth);
13141316
} else {
13151317
intVal = intVal.sext(bitwidth);

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,3 +1338,14 @@ func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
13381338
%3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
13391339
return %3 : tensor<i32>
13401340
}
1341+
1342+
// -----
1343+
1344+
// CHECK-LABEL: @test_fold_i1_to_i32_cast
1345+
// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
1346+
// CHECK: return %[[OUT]] : tensor<i32>
1347+
func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
1348+
%0 = "tosa.const"() <{values = dense<1> : tensor<i1>}> : () -> tensor<i1>
1349+
%1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
1350+
return %1 : tensor<i32>
1351+
}

0 commit comments

Comments
 (0)