From d20d1666f180567f4b81575b0a974be101ffba1a Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 11 Jun 2025 10:22:06 +0000 Subject: [PATCH] [mlir][tosa] Print generic `cond_if` when block arguments are present The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information. Change-Id: Ia44fde857e6cd3a26dbc40c0a9187b4ddb95666b --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 25 +++++++-- mlir/test/Dialect/Tosa/controlflow.mlir | 72 +++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 mlir/test/Dialect/Tosa/controlflow.mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index a32e4ccbed594..d79a8760b5498 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3649,17 +3649,22 @@ std::optional> ApplyScaleOp::getShapeForUnroll() { // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand cond; + // Fallback to generic IfOp parser when no immediate conditional + // operand is provided. + if (!parser.parseOptionalOperand(cond).has_value()) { + return parser.parseGenericOperationAfterOpName(result); + } + // Create the regions for 'then'. result.regions.reserve(2); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); auto &builder = parser.getBuilder(); - OpAsmParser::UnresolvedOperand cond; // Create a i1 tensor type for the boolean condition. Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) + if (parser.resolveOperand(cond, i1Type, result.operands)) return failure(); // Parse optional results type list. if (parser.parseOptionalArrowTypeList(result.types)) @@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { + // The simplified syntax drops block-level arguments + // to the then/else regions. Fallback to the generic + // parser if these are found + Region &thenRegion = getThenGraph(); + Region &elseRegion = getElseGraph(); + if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 && + !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) { + p.printGenericOp(*this, false); + return; + } + bool printBlockTerminators = false; p << " " << getCondition(); @@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) { printBlockTerminators = true; } p << ' '; - p.printRegion(getThenGraph(), + p.printRegion(thenRegion, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/printBlockTerminators); // Print the 'else' regions if it exists and has a block. - auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; p.printRegion(elseRegion, diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir new file mode 100644 index 0000000000000..3bc088d02b22c --- /dev/null +++ b/mlir/test/Dialect/Tosa/controlflow.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + +// ----- + +// CHECK-LABEL: test_cond_if_generic_form +// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({ +// CHECK: ^bb0(%[[INA:.*]]: tensor, %[[INB:.*]]: tensor): +// CHECK: %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor, tensor) -> tensor +// CHECK: tosa.yield %[[THEN_TERM]] : tensor +// CHECK: }, { +// CHECK: ^bb0(%[[INC:.*]]: tensor, %[[IND:.*]]: tensor): +// CHECK: %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor, tensor) -> tensor +// CHECK: tosa.yield %[[ELSE_TERM]] : tensor +// CHECK: }) : (tensor, tensor, tensor) -> tensor +// CHECK: return %[[OUT]] : tensor +func.func @test_cond_if_generic_form(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments +// COM: No block arguments means simplified form can be printed +func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: tosa.cond_if %arg2 -> (tensor) + %0 = tosa.cond_if(%arg2) ({ + ^bb0(): + tosa.yield %arg0 : tensor + }, { + ^bb0(): + tosa.yield %arg1 : tensor + }) : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_cond_if_simplified_form +// CHECK: tosa.cond_if %arg2 -> (tensor) +func.func @test_cond_if_simplified_form(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_cond_if_simplified_form_just_yield +// CHECK: tosa.cond_if %arg2 -> (tensor) +func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + %0 = tosa.cond_if %arg2 -> (tensor) { + tosa.yield %arg0 : tensor + } else { + tosa.yield %arg1 : tensor + } + return %0 : tensor +}