Skip to content

[mlir][tosa] Print generic cond_if when block arguments are present #144859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

lhutton1
Copy link
Contributor

@lhutton1 lhutton1 commented Jun 19, 2025

The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information.

Example:

func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  },  {
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

$ mlir-opt test.mlir -o out.mlir
$ cat out.mlir
module {
  func.func @test_cond_if_generic_form(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
    %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
      %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
      tosa.yield %1 : tensor<f32>
    } else {
      %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
      tosa.yield %1 : tensor<f32>
    }
    return %0 : tensor<f32>
  }
}

Before mlir-opt, the cond_if is considered isolatedFromAbove. While after mlir-opt we've lost information about the block arguments, making the cond_if no longer isolatedFromAbove. This is a property that's currently required for TOSA specification compliant cond_if's.

The generic printer/parser captures information about block arguments
for then/else regions, while the simplified version does not. Currently
the simplified printer is preferred by default, which means information
about block arguments can be lost during a parse/print round-trip. This
commit changes that behaviour so that the generic printer is
preferred when block arguments have been provided, thus avoiding loss
of information.

Change-Id: Ia44fde857e6cd3a26dbc40c0a9187b4ddb95666b
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information.


Full diff: https://github.com/llvm/llvm-project/pull/144859.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+20-5)
  • (added) mlir/test/Dialect/Tosa/controlflow.mlir (+72)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a32e4ccbed594..d79a8760b5498 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
 
 // parse and print of IfOp refer to the implementation of SCF dialect.
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand cond;
+  // Fallback to generic IfOp parser when no immediate conditional
+  // operand is provided.
+  if (!parser.parseOptionalOperand(cond).has_value()) {
+    return parser.parseGenericOperationAfterOpName(result);
+  }
+
   // Create the regions for 'then'.
   result.regions.reserve(2);
   Region *thenRegion = result.addRegion();
   Region *elseRegion = result.addRegion();
 
   auto &builder = parser.getBuilder();
-  OpAsmParser::UnresolvedOperand cond;
   // Create a i1 tensor type for the boolean condition.
   Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
+  if (parser.resolveOperand(cond, i1Type, result.operands))
     return failure();
   // Parse optional results type list.
   if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void IfOp::print(OpAsmPrinter &p) {
+  // The simplified syntax drops block-level arguments
+  // to the then/else regions. Fallback to the generic
+  // parser if these are found
+  Region &thenRegion = getThenGraph();
+  Region &elseRegion = getElseGraph();
+  if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
+      !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
+    p.printGenericOp(*this, false);
+    return;
+  }
+
   bool printBlockTerminators = false;
 
   p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
     printBlockTerminators = true;
   }
   p << ' ';
-  p.printRegion(getThenGraph(),
+  p.printRegion(thenRegion,
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/printBlockTerminators);
 
   // Print the 'else' regions if it exists and has a block.
-  auto &elseRegion = getElseGraph();
   if (!elseRegion.empty()) {
     p << " else ";
     p.printRegion(elseRegion,
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..3bc088d02b22c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form
+// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
+// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
+// CHECK:    %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:    tosa.yield %[[THEN_TERM]] : tensor<f32>
+// CHECK: }, {
+// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
+// CHECK:    %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:    tosa.yield %[[ELSE_TERM]] : tensor<f32>
+// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: return %[[OUT]] : tensor<f32>
+func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
+// COM: No block arguments means simplified form can be printed
+func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+  %0 = tosa.cond_if(%arg2) ({
+  ^bb0():
+    tosa.yield %arg0 : tensor<f32>
+  },  {
+  ^bb0():
+    tosa.yield %arg1 : tensor<f32>
+  }) : (tensor<i1>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form_just_yield
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    tosa.yield %arg0 : tensor<f32>
+  } else {
+    tosa.yield %arg1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information.


Full diff: https://github.com/llvm/llvm-project/pull/144859.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+20-5)
  • (added) mlir/test/Dialect/Tosa/controlflow.mlir (+72)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a32e4ccbed594..d79a8760b5498 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
 
 // parse and print of IfOp refer to the implementation of SCF dialect.
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand cond;
+  // Fallback to generic IfOp parser when no immediate conditional
+  // operand is provided.
+  if (!parser.parseOptionalOperand(cond).has_value()) {
+    return parser.parseGenericOperationAfterOpName(result);
+  }
+
   // Create the regions for 'then'.
   result.regions.reserve(2);
   Region *thenRegion = result.addRegion();
   Region *elseRegion = result.addRegion();
 
   auto &builder = parser.getBuilder();
-  OpAsmParser::UnresolvedOperand cond;
   // Create a i1 tensor type for the boolean condition.
   Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
+  if (parser.resolveOperand(cond, i1Type, result.operands))
     return failure();
   // Parse optional results type list.
   if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void IfOp::print(OpAsmPrinter &p) {
+  // The simplified syntax drops block-level arguments
+  // to the then/else regions. Fallback to the generic
+  // parser if these are found
+  Region &thenRegion = getThenGraph();
+  Region &elseRegion = getElseGraph();
+  if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
+      !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
+    p.printGenericOp(*this, false);
+    return;
+  }
+
   bool printBlockTerminators = false;
 
   p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
     printBlockTerminators = true;
   }
   p << ' ';
-  p.printRegion(getThenGraph(),
+  p.printRegion(thenRegion,
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/printBlockTerminators);
 
   // Print the 'else' regions if it exists and has a block.
-  auto &elseRegion = getElseGraph();
   if (!elseRegion.empty()) {
     p << " else ";
     p.printRegion(elseRegion,
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..3bc088d02b22c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form
+// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
+// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
+// CHECK:    %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:    tosa.yield %[[THEN_TERM]] : tensor<f32>
+// CHECK: }, {
+// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
+// CHECK:    %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:    tosa.yield %[[ELSE_TERM]] : tensor<f32>
+// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: return %[[OUT]] : tensor<f32>
+func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
+// COM: No block arguments means simplified form can be printed
+func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+  %0 = tosa.cond_if(%arg2) ({
+  ^bb0():
+    tosa.yield %arg0 : tensor<f32>
+  },  {
+  ^bb0():
+    tosa.yield %arg1 : tensor<f32>
+  }) : (tensor<i1>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form_just_yield
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    tosa.yield %arg0 : tensor<f32>
+  } else {
+    tosa.yield %arg1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}

@amirBish
Copy link
Contributor

LGTM :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants