@@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1177
1177
if (flatExprs[numDims + numSymbols + it.index ()] == 0 )
1178
1178
continue ;
1179
1179
AffineExpr expr = it.value ();
1180
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181
- if (!binaryExpr)
1182
- continue ;
1183
-
1180
+ // Local expression cannot be a dimension, symbol or a constant -- it
1181
+ // should be a binary op expression.
1182
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
1184
1183
AffineExpr lhs = binaryExpr.getLHS ();
1185
1184
AffineExpr rhs = binaryExpr.getRHS ();
1186
1185
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1347,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1348
1347
AffineExpr divisorExpr = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1349
1348
localExprs, context);
1350
1349
AffineExpr modExpr = dividendExpr % divisorExpr;
1350
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1351
+ std::fill (lhs.begin (), lhs.end (), 0 );
1352
+ lhs[getConstantIndex ()] = constModExpr.getValue ();
1353
+ return success ();
1354
+ }
1355
+ if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
1356
+ std::fill (lhs.begin (), lhs.end (), 0 );
1357
+ lhs[getDimStartIndex () + dimModExpr.getPosition ()] = 1 ;
1358
+ return success ();
1359
+ }
1360
+ if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
1361
+ std::fill (lhs.begin (), lhs.end (), 0 );
1362
+ lhs[getSymbolStartIndex () + symbolModExpr.getPosition ()] = 1 ;
1363
+ return success ();
1364
+ }
1351
1365
return addLocalVariableSemiAffine (modLhs, rhs, modExpr, lhs, lhs.size ());
1352
1366
}
1353
1367
@@ -1482,6 +1496,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1482
1496
AffineExpr b = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1483
1497
localExprs, context);
1484
1498
AffineExpr divExpr = isCeil ? a.ceilDiv (b) : a.floorDiv (b);
1499
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1500
+ std::fill (lhs.begin (), lhs.end (), 0 );
1501
+ lhs[getConstantIndex ()] = constDivExpr.getValue ();
1502
+ return success ();
1503
+ }
1504
+ if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
1505
+ std::fill (lhs.begin (), lhs.end (), 0 );
1506
+ lhs[getDimStartIndex () + dimDivExpr.getPosition ()] = 1 ;
1507
+ return success ();
1508
+ }
1509
+ if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
1510
+ std::fill (lhs.begin (), lhs.end (), 0 );
1511
+ lhs[getSymbolStartIndex () + symbolDivExpr.getPosition ()] = 1 ;
1512
+ return success ();
1513
+ }
1485
1514
return addLocalVariableSemiAffine (divLhs, rhs, divExpr, lhs, lhs.size ());
1486
1515
}
1487
1516
0 commit comments