|
19 | 19 | #include "mlir/Support/TypeID.h"
|
20 | 20 | #include "llvm/ADT/STLExtras.h"
|
21 | 21 | #include "llvm/Support/MathExtras.h"
|
| 22 | +#include "llvm/Support/raw_ostream.h" |
22 | 23 | #include <numeric>
|
23 | 24 | #include <optional>
|
24 | 25 |
|
@@ -1177,10 +1178,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
|
1177 | 1178 | if (flatExprs[numDims + numSymbols + it.index()] == 0)
|
1178 | 1179 | continue;
|
1179 | 1180 | AffineExpr expr = it.value();
|
1180 |
| - auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr); |
1181 |
| - if (!binaryExpr) |
1182 |
| - continue; |
1183 |
| - |
| 1181 | + // A Local expression cannot be a dimension, symbol or a constant -- it |
| 1182 | + // should be a binary op expression. |
| 1183 | + auto binaryExpr = cast<AffineBinaryOpExpr>(expr); |
1184 | 1184 | AffineExpr lhs = binaryExpr.getLHS();
|
1185 | 1185 | AffineExpr rhs = binaryExpr.getRHS();
|
1186 | 1186 | if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
|
@@ -1295,7 +1295,23 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
|
1295 | 1295 | localExprs, context);
|
1296 | 1296 | AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
1297 | 1297 | localExprs, context);
|
1298 |
| - return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size()); |
| 1298 | + AffineExpr mulExpr = a * b; |
| 1299 | + if (auto constMulExpr = dyn_cast<AffineConstantExpr>(mulExpr)) { |
| 1300 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1301 | + lhs[getConstantIndex()] = constMulExpr.getValue(); |
| 1302 | + return success(); |
| 1303 | + } |
| 1304 | + if (auto dimMulExpr = dyn_cast<AffineDimExpr>(mulExpr)) { |
| 1305 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1306 | + lhs[getDimStartIndex() + dimMulExpr.getPosition()] = 1; |
| 1307 | + return success(); |
| 1308 | + } |
| 1309 | + if (auto symbolMulExpr = dyn_cast<AffineSymbolExpr>(mulExpr)) { |
| 1310 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1311 | + lhs[getSymbolStartIndex() + symbolMulExpr.getPosition()] = 1; |
| 1312 | + return success(); |
| 1313 | + } |
| 1314 | + return addLocalVariableSemiAffine(mulLhs, rhs, mulExpr, lhs, lhs.size()); |
1299 | 1315 | }
|
1300 | 1316 |
|
1301 | 1317 | // Get the RHS constant.
|
@@ -1348,6 +1364,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
1348 | 1364 | AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
1349 | 1365 | localExprs, context);
|
1350 | 1366 | AffineExpr modExpr = dividendExpr % divisorExpr;
|
| 1367 | + if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) { |
| 1368 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1369 | + lhs[getConstantIndex()] = constModExpr.getValue(); |
| 1370 | + return success(); |
| 1371 | + } |
| 1372 | + if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) { |
| 1373 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1374 | + lhs[getDimStartIndex() + dimModExpr.getPosition()] = 1; |
| 1375 | + return success(); |
| 1376 | + } |
| 1377 | + if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) { |
| 1378 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1379 | + lhs[getSymbolStartIndex() + symbolModExpr.getPosition()] = 1; |
| 1380 | + return success(); |
| 1381 | + } |
1351 | 1382 | return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
|
1352 | 1383 | }
|
1353 | 1384 |
|
@@ -1482,6 +1513,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
1482 | 1513 | AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
1483 | 1514 | localExprs, context);
|
1484 | 1515 | AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
|
| 1516 | + if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) { |
| 1517 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1518 | + lhs[getConstantIndex()] = constDivExpr.getValue(); |
| 1519 | + return success(); |
| 1520 | + } |
| 1521 | + if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) { |
| 1522 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1523 | + lhs[getDimStartIndex() + dimDivExpr.getPosition()] = 1; |
| 1524 | + return success(); |
| 1525 | + } |
| 1526 | + if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) { |
| 1527 | + std::fill(lhs.begin(), lhs.end(), 0); |
| 1528 | + lhs[getSymbolStartIndex() + symbolDivExpr.getPosition()] = 1; |
| 1529 | + return success(); |
| 1530 | + } |
1485 | 1531 | return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
|
1486 | 1532 | }
|
1487 | 1533 |
|
@@ -1574,6 +1620,7 @@ int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
|
1574 | 1620 | AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
|
1575 | 1621 | unsigned numSymbols) {
|
1576 | 1622 | // Simplify semi-affine expressions separately.
|
| 1623 | + expr.dump(); |
1577 | 1624 | if (!expr.isPureAffine())
|
1578 | 1625 | expr = simplifySemiAffine(expr, numDims, numSymbols);
|
1579 | 1626 |
|
|
0 commit comments