Skip to content

Commit 8f4da2c

Browse files
[mlir][affine] Fix min simplification in makeComposedAffineApply (#145376)
This patch fixes a bug discovered in the `affine::makeComposedFoldedAffineApply` function when `composeAffineMin == true`. The bug happened because the simplification assumed the symbols appearing in the `affine.apply` op corresponded to symbols in the `affine.min` op, and that's not always the case. For example: ```mlir #map = affine_map<()[s0, s1] -> (s1)> #map1 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> module { func.func @min_max_full_simplify() -> index { %0 = test.value_with_bounds {max = 64 : index, min = 32 : index} %1 = test.value_with_bounds {max = 64 : index, min = 32 : index} %2 = affine.min #map()[%0, %1] %3 = affine.apply #map1()[%2, %0] return %3 : index } } ``` This patch also introduces the test `make_composed_folded_affine_apply` transform operation to test this simplification. It also adds tests ensuring we get correct behavior. --------- Co-authored-by: Nicolas Vasilache <nico.vasilache@amd.com>
1 parent 1dc46d4 commit 8f4da2c

File tree

4 files changed

+199
-44
lines changed

4 files changed

+199
-44
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,59 +1046,81 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
10461046
map.getContext());
10471047
}
10481048

1049-
/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`.
1050-
/// Assuming that the quantity is of the form:
1051-
/// `affine_min(f(x, y), symbolic_cst)`.
1052-
/// This function checks that `0 < affine_min(f(x, y), symbolic_cst)` and
1053-
/// proceeds with replacing the patterns:
1049+
/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
1050+
/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
1051+
/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
10541052
/// ```
1055-
/// dimOrSym.ceildiv(symbolic_cst)
1056-
/// (dimOrSym + symbolic_cst - 1).floordiv(symbolic_cst)
1053+
/// dimOrSym.ceildiv(x_k)
1054+
/// (dimOrSym + x_k - 1).floordiv(x_k)
10571055
/// ```
1058-
/// by `1`.
1056+
/// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
10591057
///
1060-
/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
1061-
/// inject static information that may not be statically discoverable.
10621058
///
10631059
/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1064-
/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
1065-
static LogicalResult replaceAffineMinBoundingBoxExpression(
1066-
AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map,
1067-
bool affineMinKnownToBeNonNegative = false) {
1068-
auto affineMinMap = minOp.getAffineMap();
1069-
if (!affineMinKnownToBeNonNegative) {
1070-
ValueRange values = minOp->getOperands();
1071-
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1072-
AffineMap row = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
1073-
FailureOr<int64_t> lowerBound =
1074-
ValueBoundsConstraintSet::computeConstantBound(
1075-
presburger::BoundType::LB, {row, values},
1076-
/*stopCondition=*/nullptr,
1077-
/*closedUB=*/true);
1078-
if (failed(lowerBound) || lowerBound.value() <= 0)
1079-
return failure();
1060+
/// `minOp` is positive.
1061+
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
1062+
AffineExpr dimOrSym,
1063+
AffineMap *map,
1064+
ValueRange dims,
1065+
ValueRange syms) {
1066+
AffineMap affineMinMap = minOp.getAffineMap();
1067+
1068+
// Check the value is positive.
1069+
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1070+
// Compare each expression in the minimum against 0.
1071+
if (!ValueBoundsConstraintSet::compare(
1072+
getAsIndexOpFoldResult(minOp.getContext(), 0),
1073+
ValueBoundsConstraintSet::ComparisonOperator::LT,
1074+
ValueBoundsConstraintSet::Variable(affineMinMap.getSliceMap(i, 1),
1075+
minOp.getOperands())))
1076+
return failure();
1077+
}
1078+
1079+
/// Convert affine symbols and dimensions in minOp to symbols or dimensions in
1080+
/// the apply op affine map.
1081+
DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
1082+
SmallVector<unsigned> unmappedDims, unmappedSyms;
1083+
for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1084+
auto it = llvm::find(dims, dim);
1085+
if (it == dims.end()) {
1086+
unmappedDims.push_back(i);
1087+
continue;
10801088
}
1089+
dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
1090+
getAffineDimExpr(it.getIndex(), minOp.getContext());
1091+
}
1092+
for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1093+
auto it = llvm::find(syms, sym);
1094+
if (it == syms.end()) {
1095+
unmappedSyms.push_back(i);
1096+
continue;
1097+
}
1098+
dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
1099+
getAffineSymbolExpr(it.getIndex(), minOp.getContext());
10811100
}
10821101

1083-
AffineMap initialMap = *map;
1084-
for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
1085-
auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
1086-
AffineExpr expr = m.getResult(0);
1087-
if (!expr.isSymbolicOrConstant())
1102+
// Create the replacement map.
1103+
DenseMap<AffineExpr, AffineExpr> repl;
1104+
AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
1105+
for (AffineExpr expr : affineMinMap.getResults()) {
1106+
// If we cannot express the result in terms of the apply map symbols and
1107+
// sims then continue.
1108+
if (llvm::any_of(unmappedDims,
1109+
[&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1110+
llvm::any_of(unmappedSyms,
1111+
[&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
10881112
continue;
10891113

1090-
DenseMap<AffineExpr, AffineExpr> repl;
1114+
AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
1115+
10911116
// dimOrSym.ceilDiv(expr) -> 1
1092-
repl[dimOrSym.ceilDiv(expr)] = getAffineConstantExpr(1, minOp.getContext());
1117+
repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
10931118
// (dimOrSym + expr - 1).floorDiv(expr) -> 1
1094-
repl[(dimOrSym + expr - 1).floorDiv(expr)] =
1095-
getAffineConstantExpr(1, minOp.getContext());
1096-
auto newMap = map->replace(repl);
1097-
if (newMap == *map)
1098-
continue;
1099-
*map = newMap;
1119+
repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
11001120
}
1101-
1121+
AffineMap initialMap = *map;
1122+
*map = initialMap.replace(repl, initialMap.getNumDims(),
1123+
initialMap.getNumSymbols());
11021124
return success(*map != initialMap);
11031125
}
11041126

@@ -1127,11 +1149,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
11271149
if (!v)
11281150
return failure();
11291151

1130-
auto minOp = v.getDefiningOp<AffineMinOp>();
1131-
if (minOp && replaceAffineMin) {
1152+
if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
11321153
AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
11331154
: getAffineSymbolExpr(pos, ctx);
1134-
return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map);
1155+
return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
1156+
syms);
11351157
}
11361158

11371159
auto affineApply = v.getDefiningOp<AffineApplyOp>();
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: mlir-opt --transform-interpreter %s | FileCheck %s
2+
3+
#map = affine_map<()[s0, s1] -> (s0, s1, 128)>
4+
#map1 = affine_map<()[s0, s1] -> (s0 ceildiv 128 + s0 ceildiv s1)>
5+
#map2 = affine_map<()[s0, s1, s2] -> (s0, s1 + s2)>
6+
#map3 = affine_map<()[s0, s1, s2, s3] -> (3 * (s0 ceildiv s3) + s0 ceildiv (s1 + s2))>
7+
#map4 = affine_map<()[s0, s1] -> (s1)>
8+
#map5 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
9+
#map6 = affine_map<()[s0, s1] -> (s0, s1, -128)>
10+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 ceildiv 128 + s0 ceildiv s1)>
11+
// CHECK-DAG: #[[MAP5:.*]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
12+
13+
// These test checks the `affine::makeComposedFoldedAffineApply` function when
14+
// `composeAffineMin == true`.
15+
16+
// Check the apply gets simplified.
17+
// CHECK: @apply_simplification
18+
func.func @apply_simplification_1() -> index {
19+
%0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
20+
%1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
21+
%2 = affine.min #map()[%0, %1]
22+
// CHECK-NOT: affine.apply
23+
// CHECK: arith.constant 2 : index
24+
%3 = affine.apply #map1()[%2, %1]
25+
return %3 : index
26+
}
27+
28+
// Check the simplification can match non-trivial affine expressions like s1 + s2.
29+
func.func @apply_simplification_2() -> index {
30+
%0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
31+
%1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
32+
%2 = test.value_with_bounds {max = 64 : index, min = 32 : index}
33+
%3 = affine.min #map2()[%0, %1, %2]
34+
// CHECK-NOT: affine.apply
35+
// CHECK: arith.constant 4 : index
36+
%4 = affine.apply #map3()[%3, %1, %2, %0]
37+
return %4 : index
38+
}
39+
40+
// Check there's no simplification.
41+
// The apply cannot be simplified because `s1 = %0` doesn't appear in the input min.
42+
// CHECK: @no_simplification_0
43+
func.func @no_simplification_0() -> index {
44+
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 64 : index, min = 32 : index}
45+
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 64 : index, min = 16 : index}
46+
// CHECK: %[[V2:.*]] = affine.min #{{.*}}()[%[[V0]], %[[V1]]]
47+
// CHECK: %[[V3:.*]] = affine.apply #[[MAP5]]()[%[[V2]], %[[V0]]]
48+
// CHECK: return %[[V3]] : index
49+
%0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
50+
%1 = test.value_with_bounds {max = 64 : index, min = 16 : index}
51+
%2 = affine.min #map4()[%0, %1]
52+
%3 = affine.apply #map5()[%2, %0]
53+
return %3 : index
54+
}
55+
56+
// The apply cannot be simplified because the min cannot be proven to be greater than 0.
57+
// CHECK: @no_simplification_1
58+
func.func @no_simplification_1() -> index {
59+
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 64 : index, min = 32 : index}
60+
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 64 : index, min = 16 : index}
61+
// CHECK: %[[V2:.*]] = affine.min #{{.*}}()[%[[V0]], %[[V1]]]
62+
// CHECK: %[[V3:.*]] = affine.apply #[[MAP1]]()[%[[V2]], %[[V1]]]
63+
// CHECK: return %[[V3]] : index
64+
%0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
65+
%1 = test.value_with_bounds {max = 64 : index, min = 16 : index}
66+
%2 = affine.min #map6()[%0, %1]
67+
%3 = affine.apply #map1()[%2, %1]
68+
return %3 : index
69+
}
70+
71+
module attributes {transform.with_named_sequence} {
72+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
73+
%0 = transform.structured.match ops{["affine.apply"]} in %arg0 : (!transform.any_op) -> !transform.any_op
74+
%1 = transform.test.make_composed_folded_affine_apply %0 : (!transform.any_op) -> !transform.any_op
75+
transform.yield
76+
}
77+
}

mlir/test/lib/Transforms/TestTransformsOps.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
14+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1515
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1616
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
18+
#include "mlir/IR/OpDefinition.h"
1719
#include "mlir/Transforms/RegionUtils.h"
1820

1921
#define GET_OP_CLASSES
@@ -56,6 +58,33 @@ transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
5658
return DiagnosedSilenceableFailure::success();
5759
}
5860

61+
//===----------------------------------------------------------------------===//
62+
// Test affine functionality.
63+
//===----------------------------------------------------------------------===//
64+
DiagnosedSilenceableFailure
65+
transform::TestMakeComposedFoldedAffineApply::applyToOne(
66+
TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
67+
ApplyToEachResultList &results, TransformState &state) {
68+
Location loc = affineApplyOp.getLoc();
69+
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
70+
rewriter, loc, affineApplyOp.getAffineMap(),
71+
getAsOpFoldResult(affineApplyOp.getOperands()),
72+
/*composeAffineMin=*/true);
73+
Value result;
74+
if (auto v = dyn_cast<Value>(ofr)) {
75+
result = v;
76+
} else {
77+
result = rewriter.create<arith::ConstantIndexOp>(
78+
loc, getConstantIntValue(ofr).value());
79+
}
80+
results.push_back(result.getDefiningOp());
81+
rewriter.replaceOp(affineApplyOp, result);
82+
return DiagnosedSilenceableFailure::success();
83+
}
84+
85+
//===----------------------------------------------------------------------===//
86+
// Extension
87+
//===----------------------------------------------------------------------===//
5988
namespace {
6089

6190
class TestTransformsDialectExtension

mlir/test/lib/Transforms/TestTransformsOps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,32 @@ def TestMoveValueDefns :
5959
}];
6060
}
6161

62+
//===----------------------------------------------------------------------===//
63+
// Test affine functionality.
64+
//===----------------------------------------------------------------------===//
65+
66+
def TestMakeComposedFoldedAffineApply :
67+
Op<Transform_Dialect, "test.make_composed_folded_affine_apply",
68+
[FunctionalStyleTransformOpTrait,
69+
MemoryEffectsOpInterface,
70+
TransformOpInterface,
71+
TransformEachOpTrait,
72+
ReportTrackingListenerFailuresOpTrait]> {
73+
let description = [{
74+
Rewrite an affine_apply by using the makeComposedFoldedAffineApply API.
75+
}];
76+
let arguments = (ins TransformHandleTypeInterface:$op);
77+
let results = (outs TransformHandleTypeInterface:$composed);
78+
let assemblyFormat = [{
79+
$op attr-dict `:` functional-type(operands, results)
80+
}];
81+
let extraClassDeclaration = [{
82+
::mlir::DiagnosedSilenceableFailure applyToOne(
83+
::mlir::transform::TransformRewriter &rewriter,
84+
::mlir::affine::AffineApplyOp affineApplyOp,
85+
::mlir::transform::ApplyToEachResultList &results,
86+
::mlir::transform::TransformState &state);
87+
}];
88+
}
6289

6390
#endif // TEST_TRANSFORM_OPS

0 commit comments

Comments
 (0)