diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 2dd45d27157cb..8f698e4b0dffb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1303,7 +1303,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { if (trunc) { intVal = intVal.trunc(bitwidth); - } else if (unsignIn) { + } else if (unsignIn || inETy.getIntOrFloatBitWidth() == 1) { + // Casting from i1 to iX will treat it as unsigned. intVal = intVal.zext(bitwidth); } else { intVal = intVal.sext(bitwidth); diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index d9d188dd25061..e803105f719db 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -594,6 +594,16 @@ func.func @cast_int_to_int_sign() -> tensor { return %cast : tensor } + +// CHECK: func.func @cast_i1_true_to_i32 +func.func @cast_i1_true_to_i32() -> tensor { + %splat = "tosa.const"() {values = dense : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<1> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + // ----- // CHECK-LABEL: @reverse_splat