diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index d33fc902de3a1..30e85ba92494c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1193,12 +1193,23 @@ bool checkErrorIfPad(Operation *op) { return true; } -// Returns true if the operation takes no input operands, excluding attributes. -static bool isNullaryOperation(Operation *op) { - if (isa(op) || isa(op) || - isa(op) || isa(op)) - return true; - return false; +static bool isOpIsolatedFromAbove(Operation *op, Region ®ion) { + return llvm::all_of(op->getOperands(), [&](auto operand) { + Region *operandRegion = operand.getParentRegion(); + return region.isAncestor(operandRegion); + }); +} + +static bool isRegionIsolatedFromAbove(Region ®ion) { + bool noLiveInValue = true; + region.walk([&noLiveInValue, ®ion](Operation *op) { + if (!isOpIsolatedFromAbove(op, region)) { + noLiveInValue = false; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return noLiveInValue; } bool checkErrorIfCondIf(Operation *op) { @@ -1206,48 +1217,75 @@ bool checkErrorIfCondIf(Operation *op) { if (!ifOp) return true; - // Whether the types and shapes of operands between the input/output list and - // internal regions are validated by the operation verifier. However, with - // support for the simplified form - where redundant operand notations are - // omitted - is not conformant to the specification. According to the - // specification, all operands passed into an operation must be explicitly - // declared at each operation's structure. This code section verify that the - // operation's form complies with this requirement. - - // Returns true if the region uses no external input operands. - auto isNullaryRegion = [](Region ®ion) -> bool { - bool noLiveInValue = true; - region.walk([&noLiveInValue](Operation *op) { - if (!isNullaryOperation(op)) { - noLiveInValue = false; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return noLiveInValue; - }; + // Currently the dialect supports declaring cond_if operations that + // have then/else regions that reference values from outside these + // regions. According to the specification, all values used by the + // then/else regions must be explicitly declared within the regions. + // Therefore we must check that the then/else regions are + // "isolated from above", in order to be conformant to the + // specification. + // + // Note: the dialect currently supports two styles of syntax for + // declaring "cond_if" operations. We'll refer to these as follows: + // + // Generic: + // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({ + // ^bb0(%arg3, %arg4): + // tosa.yield %arg3 + // }, { + // ^bb0(%arg3, %arg4): + // tosa.yield %arg4 + // }) + // + // Simplified: + // %0 = tosa.cond_if %arg2 { + // tosa.yield %arg0 + // } else { + // tosa.yield %arg1 + // } + // + // Unfortunately, the simplified syntax does not encapsulate values + // used in then/else regions (see 'simplified' example above), so it + // must be rewritten to use the generic syntax in order to be conformant + // to the specification. + Region &thenGraph = ifOp.getThenGraph(); + Region &elseGraph = ifOp.getElseGraph(); + bool isThenGraphIsolatedRegion = isRegionIsolatedFromAbove(thenGraph); + bool isElseGraphIsolatedRegion = isRegionIsolatedFromAbove(elseGraph); + + if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) { + op->emitOpError() + << "is not conformant to the TOSA specification. It requires the " + "then/else regions are isolated from above.\n"; + return false; + } + return true; +} - mlir::Region &thenGraph = ifOp.getThenGraph(); - mlir::Region &elseGraph = ifOp.getElseGraph(); - bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph); - bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph); - bool isInputListEmpty = ifOp.getInputList().size() == 0; +bool checkErrorIfWhileLoop(Operation *op) { + auto whileOp = dyn_cast(op); + if (!whileOp) + return true; - if ((isInputListEmpty != isThenGraphNullaryRegion) || - (isInputListEmpty != isElseGraphNullaryRegion)) { + Region &condGraph = whileOp.getCondGraph(); + Region &bodyGraph = whileOp.getBodyGraph(); + bool isCondGraphIsolatedRegion = isRegionIsolatedFromAbove(condGraph); + bool isBodyGraphIsolatedRegion = isRegionIsolatedFromAbove(bodyGraph); + + if (!isCondGraphIsolatedRegion || !isBodyGraphIsolatedRegion) { op->emitOpError() - << "the current simplified form is not strictly conformant to the " - "spec, please use the generic format\n"; + << "is not conformant to the TOSA specification. It requires the " + "cond/body regions are isolated from above.\n"; return false; } - return true; } LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op) || !checkErrorIfRescale(op) || - !checkErrorIfPad(op) || !checkErrorIfCondIf(op)) + !checkErrorIfPad(op) || !checkErrorIfCondIf(op) || + !checkErrorIfWhileLoop(op)) return failure(); return success(); } diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index 1f25132d6bcf3..77830c7be2e9e 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -227,15 +227,79 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> } // ----- -// CHECK-LABEL: cond_if_simplified_form -func.func @test_cond_if_simplified_form(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // expected-error@+1 {{'tosa.cond_if' op the current simplified form is not strictly conformant to the spec, please use the generic format}} + +func.func @test_cond_if_not_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}} + %0 = "tosa.cond_if"(%arg2) ({ + ^bb0(): + tosa.yield %arg0 : tensor + }, { + ^bb0(): + tosa.yield %arg1 : tensor + }) : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}} %0 = tosa.cond_if %arg2 -> (tensor) { - %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor - tosa.yield %1 : tensor + tosa.yield %arg0 : tensor } else { - %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor - tosa.yield %1 : tensor + tosa.yield %arg1 : tensor } return %0 : tensor } + +// ----- + +// COM: Check isolated cond_if's are valid +func.func @test_cond_if_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg4 : tensor + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_while_loop_not_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the cond/body regions are isolated from above.}} + %1 = "tosa.while_loop"(%0) ({ + ^bb0(%arg3: tensor): + %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor, tensor) -> tensor + %3 = "tosa.logical_not"(%2) : (tensor) -> tensor + tosa.yield %3 : tensor + }, { + ^bb0(%arg3: tensor): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor + tosa.yield %3 : tensor + }) : (tensor) -> (tensor) + return +} + +// ----- + +// COM: Check isolated while_loops are valid +func.func @test_while_loop_isolated_from_above(%arg0: tensor, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): + %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor, tensor) -> tensor + %3 = "tosa.logical_not"(%2) : (tensor) -> tensor + "tosa.yield"(%3) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor + "tosa.yield"(%3, %arg4, %arg5) : (tensor, tensor, tensor) -> () + }) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + return +}