diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 1d21096e8920b..4bb69d4ec4820 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1301,7 +1301,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { } if (llvm::isa(inETy) && llvm::isa(outETy)) { - auto unsignIn = llvm::cast(inETy).isUnsignedInteger(); + const auto inIntType = llvm::cast(inETy); + auto unsignIn = inIntType.isUnsignedInteger(); bool trunc = inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth(); auto intVal = operand.getSplatValue(); @@ -1309,7 +1310,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { if (trunc) { intVal = intVal.trunc(bitwidth); - } else if (unsignIn) { + // i1 types are boolean in TOSA + } else if (unsignIn || inIntType.isInteger(1)) { intVal = intVal.zext(bitwidth); } else { intVal = intVal.sext(bitwidth); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 27280807b0282..11c8d54fda055 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1338,3 +1338,14 @@ func.func @no_fold_mul_result_exceeds_i32() -> tensor { %3 = tosa.mul %0, %1, %2 : (tensor, tensor, tensor<1xi8>) -> tensor return %3 : tensor } + +// ----- + +// CHECK-LABEL: @test_fold_i1_to_i32_cast +// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor +// CHECK: return %[[OUT]] : tensor +func.func @test_fold_i1_to_i32_cast() -> tensor { + %0 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor + return %1 : tensor +}