Skip to content

Commit ddff376

Browse files
committed
[MLIR] Simplify affine maps + operands exploiting IV info
Simplify affine expressions and maps while exploiting simple range and step info of any IVs that are operands. This simplification is local, O(1) and practically useful in several scenarios. Accesses with floordiv's and mod's where the LHS is non-negative and bounded or is a known multiple of a constant can often be simplified. This is implemented as a canonicalization for all affine ops in a generic way: all affine.load/store, vector_load/store, affine.apply, affine.min/max, etc. ops. Eg: For tiled loop nests accessing buffers this way: affine.for %i = 0 to 1024 step 32 { affine.for %ii = 0 to 32 { affine.load [(%i + %ii) floordiv 32, (%i + %ii) mod 32] } } // Note that %i is a multiple of 32 and %ii < 32, hence: (%i + %ii) floordiv 32 is the same as %i floordiv 32 (%i + %ii) mod 32 is the same as %ii mod 32. The simplification leads to simpler index/subscript arithmetic for multi-dimensional arrays and also in turn enables detection of spatial locality (for vectorization for eg.), temporal locality or loop invariance for hoisting or scalar replacement. Differential Revision: https://reviews.llvm.org/D135085
1 parent 82cac65 commit ddff376

File tree

4 files changed

+225
-4
lines changed

4 files changed

+225
-4
lines changed

mlir/include/mlir/IR/AffineMap.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ class AffineMap {
327327
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
328328
AffineMap getMinorSubMap(unsigned numResults) const;
329329

330+
/// Get the largest known divisor of all map expressions.
331+
/// For eg: for (d0, d1) -> (8*d0 + 4, 4*d1 + 2), the result is 2.
332+
/// In the case of maps with no expressions or all zero constant expressions,
333+
/// the largest known divisor is trivially the max uint64_t value.
334+
uint64_t getLargestKnownDivisorOfMapExprs();
335+
330336
friend ::llvm::hash_code hash_value(AffineMap arg);
331337

332338
/// Methods supporting C API.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/SmallBitVector.h"
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/Support/Debug.h"
24+
#include <numeric>
2425

2526
using namespace mlir;
2627

@@ -578,6 +579,169 @@ OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
578579
return result[0];
579580
}
580581

582+
/// Returns the largest known divisor of `e`. Exploits information from the
583+
/// values in `operands`.
584+
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
585+
// This method isn't aware of `operands`.
586+
int64_t div = e.getLargestKnownDivisor();
587+
588+
// We now make use of operands for the case `e` is a dim expression.
589+
// TODO: More powerful simplification would have to modify
590+
// getLargestKnownDivisor to take `operands` and exploit that information as
591+
// well for dim/sym expressions, but in that case, getLargestKnownDivisor
592+
// can't be part of the IR library but of the `Analysis` library. The IR
593+
// library can only really depend on simple O(1) checks.
594+
auto dimExpr = e.dyn_cast<AffineDimExpr>();
595+
// If it's not a dim expr, `div` is the best we have.
596+
if (!dimExpr)
597+
return div;
598+
599+
// We simply exploit information from loop IVs.
600+
// We don't need to use mlir::getLargestKnownDivisorOfValue since the other
601+
// desired simplifications are expected to be part of other
602+
// canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
603+
// LoopAnalysis library.
604+
Value operand = operands[dimExpr.getPosition()];
605+
int64_t operandDivisor = 1;
606+
// TODO: With the right accessors, this can be extended to
607+
// LoopLikeOpInterface.
608+
if (AffineForOp forOp = getForInductionVarOwner(operand)) {
609+
if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
610+
operandDivisor = forOp.getStep();
611+
} else {
612+
uint64_t lbLargestKnownDivisor =
613+
forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
614+
operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStep());
615+
}
616+
}
617+
return operandDivisor;
618+
}
619+
620+
/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
621+
/// being an affine dim expression or a constant.
622+
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
623+
int64_t k) {
624+
if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
625+
int64_t constVal = constExpr.getValue();
626+
return constVal >= 0 && constVal < k;
627+
}
628+
auto dimExpr = e.dyn_cast<AffineDimExpr>();
629+
if (!dimExpr)
630+
return false;
631+
Value operand = operands[dimExpr.getPosition()];
632+
// TODO: With the right accessors, this can be extended to
633+
// LoopLikeOpInterface.
634+
if (AffineForOp forOp = getForInductionVarOwner(operand)) {
635+
if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
636+
forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
637+
return true;
638+
}
639+
}
640+
641+
// We don't consider other cases like `operand` being defined by a constant or
642+
// an affine.apply op since such cases will already be handled by other
643+
// patterns and propagation of loop IVs or constant would happen.
644+
return false;
645+
}
646+
647+
/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
648+
/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
649+
/// expression is in that form.
650+
static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
651+
AffineExpr &quotientTimesDiv, AffineExpr &rem) {
652+
auto bin = e.dyn_cast<AffineBinaryOpExpr>();
653+
if (!bin || bin.getKind() != AffineExprKind::Add)
654+
return false;
655+
656+
AffineExpr llhs = bin.getLHS();
657+
AffineExpr rlhs = bin.getRHS();
658+
div = getLargestKnownDivisor(llhs, operands);
659+
if (isNonNegativeBoundedBy(rlhs, operands, div)) {
660+
quotientTimesDiv = llhs;
661+
rem = rlhs;
662+
return true;
663+
}
664+
div = getLargestKnownDivisor(rlhs, operands);
665+
if (isNonNegativeBoundedBy(llhs, operands, div)) {
666+
quotientTimesDiv = rlhs;
667+
rem = llhs;
668+
return true;
669+
}
670+
return false;
671+
}
672+
673+
/// Simplify `expr` while exploiting information from the values in `operands`.
674+
static void simplifyExprAndOperands(AffineExpr &expr,
675+
ArrayRef<Value> operands) {
676+
// We do this only for certain floordiv/mod expressions.
677+
auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
678+
if (!binExpr)
679+
return;
680+
681+
// Simplify the child expressions first.
682+
auto lhs = binExpr.getLHS();
683+
auto rhs = binExpr.getRHS();
684+
simplifyExprAndOperands(lhs, operands);
685+
simplifyExprAndOperands(rhs, operands);
686+
expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
687+
688+
binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
689+
if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
690+
binExpr.getKind() != AffineExprKind::Mod)) {
691+
return;
692+
}
693+
694+
auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
695+
if (!rhsConst)
696+
return;
697+
698+
int64_t rhsConstVal = rhsConst.getValue();
699+
AffineExpr quotientTimesDiv, rem;
700+
int64_t divisor;
701+
702+
// Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
703+
// mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
704+
// `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
705+
// And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
706+
if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
707+
if (rhsConstVal % divisor == 0 &&
708+
binExpr.getKind() == AffineExprKind::FloorDiv) {
709+
expr = quotientTimesDiv.floorDiv(rhsConst);
710+
} else if (divisor % rhsConstVal == 0 &&
711+
binExpr.getKind() == AffineExprKind::Mod) {
712+
expr = rem % rhsConst;
713+
}
714+
return;
715+
}
716+
717+
// Handle the simple case when the LHS expression can be either upper
718+
// bounded or is a known multiple of RHS constant.
719+
// lhs floordiv c -> 0 if 0 <= lhs < c,
720+
// lhs mod c -> 0 if lhs % c = 0.
721+
if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
722+
binExpr.getKind() == AffineExprKind::FloorDiv) ||
723+
(getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
724+
binExpr.getKind() == AffineExprKind::Mod)) {
725+
expr = getAffineConstantExpr(0, expr.getContext());
726+
}
727+
}
728+
729+
/// Simplify the map while exploiting information on the values in `operands`.
730+
// Use "unused attribute" marker to silence warning stemming from the inability
731+
// to see through the template expansion.
732+
static void LLVM_ATTRIBUTE_UNUSED
733+
simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
734+
assert(map.getNumInputs() == operands.size() && "invalid operands for map");
735+
SmallVector<AffineExpr> newResults;
736+
newResults.reserve(map.getNumResults());
737+
for (AffineExpr expr : map.getResults()) {
738+
simplifyExprAndOperands(expr, operands);
739+
newResults.push_back(expr);
740+
}
741+
map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
742+
map.getContext());
743+
}
744+
581745
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
582746
/// defining AffineApplyOp expression and operands.
583747
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1095,6 +1259,7 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
10951259
SmallVector<Value, 8> resultOperands(oldOperands);
10961260
composeAffineMapAndOperands(&map, &resultOperands);
10971261
canonicalizeMapAndOperands(&map, &resultOperands);
1262+
simplifyMapWithOperands(map, resultOperands);
10981263
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
10991264
resultOperands.begin()))
11001265
return failure();

mlir/lib/IR/AffineMap.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/SmallSet.h"
1717
#include "llvm/ADT/StringRef.h"
1818
#include "llvm/Support/raw_ostream.h"
19+
#include <numeric>
1920

2021
using namespace mlir;
2122

@@ -241,6 +242,17 @@ AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
241242
return ::inferFromExprList(exprsList);
242243
}
243244

245+
uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() {
246+
uint64_t gcd = 0;
247+
for (AffineExpr resultExpr : getResults()) {
248+
uint64_t thisGcd = resultExpr.getLargestKnownDivisor();
249+
gcd = std::gcd(gcd, thisGcd);
250+
}
251+
if (gcd == 0)
252+
gcd = std::numeric_limits<uint64_t>::max();
253+
return gcd;
254+
}
255+
244256
AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
245257
MLIRContext *context) {
246258
SmallVector<AffineExpr, 4> dimExprs;

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ func.func @compose_affine_maps_2d_tile(%0: memref<16x32xf32>, %1: memref<16x32xf
9898
%c4 = arith.constant 4 : index
9999
%c8 = arith.constant 8 : index
100100

101-
affine.for %i0 = 0 to 3 {
101+
affine.for %i0 = 0 to 16 {
102102
%x0 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)> (%i0)[%c4]
103-
affine.for %i1 = 0 to 3 {
103+
affine.for %i1 = 0 to 16 {
104104
%x1 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)> (%i1)[%c8]
105-
affine.for %i2 = 0 to 3 {
105+
affine.for %i2 = 0 to 16 {
106106
%x2 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)> (%i2)[%c4]
107-
affine.for %i3 = 0 to 3 {
107+
affine.for %i3 = 0 to 16 {
108108
%x3 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)> (%i3)[%c8]
109109

110110
%x40 = affine.apply affine_map<(d0, d1, d2, d3)[s0, s1] ->
@@ -1150,3 +1150,41 @@ module {
11501150
return %s: memref<32x64xf32>
11511151
}
11521152
}
1153+
1154+
// -----
1155+
1156+
// Simplification of maps exploiting operand info.
1157+
1158+
// CHECK-LABEL: func @simplify_with_operands
1159+
func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
1160+
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to %{{.*}}
1161+
affine.for %i = 0 to %N step 32 {
1162+
// CHECK-NEXT: affine.for %[[II:.*]] = 0 to 32
1163+
affine.for %ii = 0 to 32 {
1164+
// %ii is less than 32 and %i divides 32.
1165+
// CHECK: affine.load %{{.*}}[0, 0]
1166+
%x = affine.load %A[%ii floordiv 32, %i mod 32] : memref<?x32xf32>
1167+
"test.foo"(%x) : (f32) -> ()
1168+
1169+
// %i is aligned at 32 boundary and %ii < 32.
1170+
// CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
1171+
%a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x32xf32>
1172+
"test.foo"(%a) : (f32) -> ()
1173+
// CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
1174+
%b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x32xf32>
1175+
"test.foo"(%b) : (f32) -> ()
1176+
// CHECK: affine.load %{{.*}}[(%[[I]] + %[[II]]) floordiv 16, %[[II]] mod 16]
1177+
%c = affine.load %A[(%i + %ii) floordiv 16, (%i + %ii) mod 16] : memref<?x32xf32>
1178+
"test.foo"(%c) : (f32) -> ()
1179+
}
1180+
}
1181+
1182+
// Should not simplify.
1183+
affine.for %i = -1 to 32 {
1184+
// CHECK: affine.load %{{.*}}[%{{.*}} floordiv {{.*}}, %{{.*}} mod {{.*}}] :
1185+
%x = affine.load %A[%i floordiv 32, %i mod 32] : memref<?x32xf32>
1186+
"test.foo"(%x) : (f32) -> ()
1187+
}
1188+
1189+
return
1190+
}

0 commit comments

Comments
 (0)