-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][tosa] Print generic cond_if
when block arguments are present
#144859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information. Change-Id: Ia44fde857e6cd3a26dbc40c0a9187b4ddb95666b
@llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesThe generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information. Full diff: https://github.com/llvm/llvm-project/pull/144859.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a32e4ccbed594..d79a8760b5498 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::UnresolvedOperand cond;
+ // Fallback to generic IfOp parser when no immediate conditional
+ // operand is provided.
+ if (!parser.parseOptionalOperand(cond).has_value()) {
+ return parser.parseGenericOperationAfterOpName(result);
+ }
+
// Create the regions for 'then'.
result.regions.reserve(2);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand cond;
// Create a i1 tensor type for the boolean condition.
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+ if (parser.resolveOperand(cond, i1Type, result.operands))
return failure();
// Parse optional results type list.
if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
+ // The simplified syntax drops block-level arguments
+ // to the then/else regions. Fallback to the generic
+ // parser if these are found
+ Region &thenRegion = getThenGraph();
+ Region &elseRegion = getElseGraph();
+ if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
+ !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
+ p.printGenericOp(*this, false);
+ return;
+ }
+
bool printBlockTerminators = false;
p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
printBlockTerminators = true;
}
p << ' ';
- p.printRegion(getThenGraph(),
+ p.printRegion(thenRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
// Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..3bc088d02b22c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form
+// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
+// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
+// CHECK: %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: tosa.yield %[[THEN_TERM]] : tensor<f32>
+// CHECK: }, {
+// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
+// CHECK: %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: tosa.yield %[[ELSE_TERM]] : tensor<f32>
+// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: return %[[OUT]] : tensor<f32>
+func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
+// COM: No block arguments means simplified form can be printed
+func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+ %0 = tosa.cond_if(%arg2) ({
+ ^bb0():
+ tosa.yield %arg0 : tensor<f32>
+ }, {
+ ^bb0():
+ tosa.yield %arg1 : tensor<f32>
+ }) : (tensor<i1>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form_just_yield
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ tosa.yield %arg0 : tensor<f32>
+ } else {
+ tosa.yield %arg1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
|
@llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThe generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information. Full diff: https://github.com/llvm/llvm-project/pull/144859.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a32e4ccbed594..d79a8760b5498 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::UnresolvedOperand cond;
+ // Fallback to generic IfOp parser when no immediate conditional
+ // operand is provided.
+ if (!parser.parseOptionalOperand(cond).has_value()) {
+ return parser.parseGenericOperationAfterOpName(result);
+ }
+
// Create the regions for 'then'.
result.regions.reserve(2);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand cond;
// Create a i1 tensor type for the boolean condition.
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+ if (parser.resolveOperand(cond, i1Type, result.operands))
return failure();
// Parse optional results type list.
if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
+ // The simplified syntax drops block-level arguments
+ // to the then/else regions. Fallback to the generic
+ // parser if these are found
+ Region &thenRegion = getThenGraph();
+ Region &elseRegion = getElseGraph();
+ if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
+ !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
+ p.printGenericOp(*this, false);
+ return;
+ }
+
bool printBlockTerminators = false;
p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
printBlockTerminators = true;
}
p << ' ';
- p.printRegion(getThenGraph(),
+ p.printRegion(thenRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
// Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..3bc088d02b22c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form
+// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
+// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
+// CHECK: %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: tosa.yield %[[THEN_TERM]] : tensor<f32>
+// CHECK: }, {
+// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
+// CHECK: %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: tosa.yield %[[ELSE_TERM]] : tensor<f32>
+// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: return %[[OUT]] : tensor<f32>
+func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
+// COM: No block arguments means simplified form can be printed
+func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+ %0 = tosa.cond_if(%arg2) ({
+ ^bb0():
+ tosa.yield %arg0 : tensor<f32>
+ }, {
+ ^bb0():
+ tosa.yield %arg1 : tensor<f32>
+ }) : (tensor<i1>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form_just_yield
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ tosa.yield %arg0 : tensor<f32>
+ } else {
+ tosa.yield %arg1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
|
LGTM :) |
The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information.
Example:
Before
mlir-opt
, the cond_if is consideredisolatedFromAbove
. While aftermlir-opt
we've lost information about the block arguments, making the cond_if no longerisolatedFromAbove
. This is a property that's currently required for TOSA specification compliant cond_if's.