diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir index e32ea7ad3c729..e5dbec44220b7 100644 --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -1426,12 +1426,3 @@ func.func @wrong_weights_number_in_if_then_else(%cond: i1) { } return } - -// ----- - -func.func @negative_weight_in_if_then(%cond: i1) { -// expected-error @below {{weight #0 must be non-negative}} - fir.if %cond weights([-1, 101]) { - } - return -} diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 46ab0b9ebbc6b..b8d08cc553caa 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -384,13 +384,15 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> { This interface provides weight information for branching terminator operations, i.e. terminator operations with successors. - This interface provides methods for getting/setting integer non-negative - weight of each branch. The probability of executing a branch - is computed as the ratio between the branch's weight and the total - sum of the weights (which cannot be zero). - The weights are optional. If they are provided, then their number + This interface provides methods for getting/setting integer weights of each + branch. The probability of executing a branch is computed as the ratio + between the branch's weight and the total sum of the weights (which cannot + be zero). The weights are optional. If they are provided, then their number must match the number of successors of the operation. + Note that the branch weight use an i32 representation but they are to be + interpreted as unsigned integers. + The default implementations of the methods expect the operation to have an attribute of type DenseI32ArrayAttr named branch_weights. }]; @@ -440,19 +442,21 @@ def WeightedRegionBranchOpInterface This interface provides weight information for region operations that exhibit branching behavior between held regions. - This interface provides methods for getting/setting integer non-negative - weight of each branch. The probability of executing a region is computed - as the ratio between the region branch's weight and the total sum - of the weights (which cannot be zero). - The weights are optional. If they are provided, then their number - must match the number of regions held by the operation - (including empty regions). + This interface provides methods for getting/setting integer weights of each + branch. The probability of executing a region is computed as the ratio + between the region branch's weight and the total sum of the weights (which + cannot be zero). The weights are optional. If they are provided, then their + number must match the number of regions held by the operation (including + empty regions). The weights specify the probability of branching to a particular region when first executing the operation. For example, for loop-like operations with a single region the weight specifies the probability of entering the loop. + Note that the branch weight use an i32 representation but they are to be + interpreted as unsigned integers. + The default implementations of the methods expect the operation to have an attribute of type DenseI32ArrayAttr named branch_weights. }]; diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 3a63db35eec0f..e87bb461b0329 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -99,10 +99,6 @@ static LogicalResult verifyWeights(Operation *op, << ": " << weights.size() << " vs " << expectedWeightsNum; - for (auto [index, weight] : llvm::enumerate(weights)) - if (weight < 0) - return op->emitError() << "weight #" << index << " must be non-negative"; - if (llvm::all_of(weights, [](int32_t value) { return value == 0; })) return op->emitError() << "branch weights cannot all be zero"; diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir index 1b8de22a9ff9f..0a71c62ec31af 100644 --- a/mlir/test/Dialect/ControlFlow/invalid.mlir +++ b/mlir/test/Dialect/ControlFlow/invalid.mlir @@ -82,18 +82,6 @@ func.func @wrong_weights_number(%cond: i1) { // ----- -// CHECK-LABEL: func @negative_weight -func.func @wrong_total_weight(%cond: i1) { - // expected-error@+1 {{weight #0 must be non-negative}} - cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2 - ^bb1: - return - ^bb2: - return -} - -// ----- - // CHECK-LABEL: func @zero_weights func.func @wrong_total_weight(%cond: i1) { // expected-error@+1 {{branch weights cannot all be zero}}