Skip to content

Commit 5fc2915

Browse files
Fix bug in visitDivExpr, visitMulExpr and visitModExpr
Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.
1 parent 529662a commit 5fc2915

File tree

2 files changed

+76
-12
lines changed

2 files changed

+76
-12
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Support/TypeID.h"
2020
#include "llvm/ADT/STLExtras.h"
2121
#include "llvm/Support/MathExtras.h"
22+
#include "llvm/Support/raw_ostream.h"
2223
#include <numeric>
2324
#include <optional>
2425

@@ -1177,10 +1178,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
11771178
if (flatExprs[numDims + numSymbols + it.index()] == 0)
11781179
continue;
11791180
AffineExpr expr = it.value();
1180-
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181-
if (!binaryExpr)
1182-
continue;
1183-
1181+
// A Local expression cannot be a dimension, symbol or a constant -- it
1182+
// should be a binary op expression.
1183+
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
11841184
AffineExpr lhs = binaryExpr.getLHS();
11851185
AffineExpr rhs = binaryExpr.getRHS();
11861186
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1295,7 +1295,23 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
12951295
localExprs, context);
12961296
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
12971297
localExprs, context);
1298-
return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1298+
AffineExpr mulExpr = a * b;
1299+
if (auto constMulExpr = dyn_cast<AffineConstantExpr>(mulExpr)) {
1300+
std::fill(lhs.begin(), lhs.end(), 0);
1301+
lhs[getConstantIndex()] = constMulExpr.getValue();
1302+
return success();
1303+
}
1304+
if (auto dimMulExpr = dyn_cast<AffineDimExpr>(mulExpr)) {
1305+
std::fill(lhs.begin(), lhs.end(), 0);
1306+
lhs[getDimStartIndex() + dimMulExpr.getPosition()] = 1;
1307+
return success();
1308+
}
1309+
if (auto symbolMulExpr = dyn_cast<AffineSymbolExpr>(mulExpr)) {
1310+
std::fill(lhs.begin(), lhs.end(), 0);
1311+
lhs[getSymbolStartIndex() + symbolMulExpr.getPosition()] = 1;
1312+
return success();
1313+
}
1314+
return addLocalVariableSemiAffine(mulLhs, rhs, mulExpr, lhs, lhs.size());
12991315
}
13001316

13011317
// Get the RHS constant.
@@ -1348,6 +1364,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13481364
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
13491365
localExprs, context);
13501366
AffineExpr modExpr = dividendExpr % divisorExpr;
1367+
if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1368+
std::fill(lhs.begin(), lhs.end(), 0);
1369+
lhs[getConstantIndex()] = constModExpr.getValue();
1370+
return success();
1371+
}
1372+
if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
1373+
std::fill(lhs.begin(), lhs.end(), 0);
1374+
lhs[getDimStartIndex() + dimModExpr.getPosition()] = 1;
1375+
return success();
1376+
}
1377+
if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
1378+
std::fill(lhs.begin(), lhs.end(), 0);
1379+
lhs[getSymbolStartIndex() + symbolModExpr.getPosition()] = 1;
1380+
return success();
1381+
}
13511382
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
13521383
}
13531384

@@ -1482,6 +1513,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
14821513
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
14831514
localExprs, context);
14841515
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1516+
if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1517+
std::fill(lhs.begin(), lhs.end(), 0);
1518+
lhs[getConstantIndex()] = constDivExpr.getValue();
1519+
return success();
1520+
}
1521+
if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
1522+
std::fill(lhs.begin(), lhs.end(), 0);
1523+
lhs[getDimStartIndex() + dimDivExpr.getPosition()] = 1;
1524+
return success();
1525+
}
1526+
if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
1527+
std::fill(lhs.begin(), lhs.end(), 0);
1528+
lhs[getSymbolStartIndex() + symbolDivExpr.getPosition()] = 1;
1529+
return success();
1530+
}
14851531
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
14861532
}
14871533

@@ -1574,6 +1620,7 @@ int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
15741620
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
15751621
unsigned numSymbols) {
15761622
// Simplify semi-affine expressions separately.
1623+
expr.dump();
15771624
if (!expr.isPureAffine())
15781625
expr = simplifySemiAffine(expr, numDims, numSymbols);
15791626

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -595,16 +595,33 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
595595

596596
// -----
597597

598-
// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
599-
func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
598+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)>
599+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)>
600+
// CHECK-LABEL: semiaffine_simplification_local_expr_folded_into_non_binary_expr
601+
func.func @semiaffine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index, index, index) {
600602
%c0 = arith.constant 0 : index
601603
%c1 = arith.constant 1 : index
604+
%c4 = arith.constant 4 : index
602605
%c13 = arith.constant 13 : index
603-
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
606+
// CHECK: %[[DIM:.*]] = memref.dim
607+
%dim = memref.dim %arg0, %c0 : memref<?x?xf32>
608+
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
609+
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
604610
%a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
605611
%b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
606-
// CHECK: %[[C6:.*]] = arith.constant 6 : index
607-
// CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
608-
// CHECK-NEXT: return %[[C6]], %[[C7]]
609-
return %a, %b : index, index
612+
// CHECK: %0 = affine.apply #[[$MAP]]()[%[[DIM]]]
613+
%c = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 mod (s1 + (-s1 + s3) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1, %dim]
614+
%alloc = memref.alloc() : memref<1xindex>
615+
affine.for %iv = 0 to 1 {
616+
%d = affine.apply affine_map<(d0)[s1, s2] -> ((d0 - s1 + s1 * s2) * (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>(%iv)[%dim, %c1]
617+
affine.store %d, %alloc[0] : memref<1xindex>
618+
}
619+
// CHECK: affine.for %[[IV:.*]] = 0 to 1 {
620+
// CHECK-NEXT: %[[VAL:.*]] = affine.apply #[[$MAP1]](%[[IV]])
621+
// CHECK-NEXT: affine.store %[[VAL]], %{{.*}}[0] : memref<1xindex>
622+
// CHECK-NEXT: }
623+
// CHECK: %[[VAL1:.*]] = affine.load %{{.*}}[0]
624+
%d = affine.load %alloc[0] : memref<1xindex>
625+
626+
return %a, %b, %c, %d : index, index, index, index
610627
}

0 commit comments

Comments
 (0)