Skip to content

Commit 75ab97d

Browse files
Fix bug in visitDivExpr and visitModExpr
Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.
1 parent 529662a commit 75ab97d

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
11771177
if (flatExprs[numDims + numSymbols + it.index()] == 0)
11781178
continue;
11791179
AffineExpr expr = it.value();
1180-
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181-
if (!binaryExpr)
1182-
continue;
1183-
1180+
// Local expression cannot be a dimension, symbol or a constant -- it
1181+
// should be a binary op expression.
1182+
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
11841183
AffineExpr lhs = binaryExpr.getLHS();
11851184
AffineExpr rhs = binaryExpr.getRHS();
11861185
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1347,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13481347
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
13491348
localExprs, context);
13501349
AffineExpr modExpr = dividendExpr % divisorExpr;
1350+
if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1351+
std::fill(lhs.begin(), lhs.end(), 0);
1352+
lhs[getConstantIndex()] = constModExpr.getValue();
1353+
return success();
1354+
}
1355+
if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
1356+
std::fill(lhs.begin(), lhs.end(), 0);
1357+
lhs[getDimStartIndex() + dimModExpr.getPosition()] = 1;
1358+
return success();
1359+
}
1360+
if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
1361+
std::fill(lhs.begin(), lhs.end(), 0);
1362+
lhs[getSymbolStartIndex() + symbolModExpr.getPosition()] = 1;
1363+
return success();
1364+
}
13511365
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
13521366
}
13531367

@@ -1482,6 +1496,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
14821496
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
14831497
localExprs, context);
14841498
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1499+
if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1500+
std::fill(lhs.begin(), lhs.end(), 0);
1501+
lhs[getConstantIndex()] = constDivExpr.getValue();
1502+
return success();
1503+
}
1504+
if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
1505+
std::fill(lhs.begin(), lhs.end(), 0);
1506+
lhs[getDimStartIndex() + dimDivExpr.getPosition()] = 1;
1507+
return success();
1508+
}
1509+
if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
1510+
std::fill(lhs.begin(), lhs.end(), 0);
1511+
lhs[getSymbolStartIndex() + symbolDivExpr.getPosition()] = 1;
1512+
return success();
1513+
}
14851514
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
14861515
}
14871516

0 commit comments

Comments
 (0)