Skip to content

Commit f7715ad

Browse files
Fix bug in visitDivExpr, visitMulExpr and visitModExpr
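Whenever the result of a mul, div or mod affine expression simplifies to a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression. Results that simplify to a single dimension or symbol are likewise written to that dimension's or symbol's index.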
1 parent 529662a commit f7715ad

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

mlir/include/mlir/IR/AffineExprVisitor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,14 @@ class SimpleAffineExprFlattener
418418
AffineExpr localExpr);
419419

420420
private:
421+
/// Flatten binary expression `expr` and add it to `result`. If `expr` is a
422+
/// dimension, symbol or constant, we add it to the appropriate index in `result`.
423+
/// Otherwise we add it in the local variable section. `lhs` and `rhs` are the
424+
/// operands of `expr`.
425+
LogicalResult addExprToFlattenedList(AffineExpr expr, ArrayRef<int64_t> lhs,
426+
ArrayRef<int64_t> rhs,
427+
SmallVectorImpl<int64_t> &result);
428+
421429
/// Adds `localExpr`, which may be mul, ceildiv, floordiv or mod expression
422430
/// representing the affine expression corresponding to the quantifier
423431
/// introduced as the local variable corresponding to `localExpr`. If the

mlir/lib/IR/AffineExpr.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
11771177
if (flatExprs[numDims + numSymbols + it.index()] == 0)
11781178
continue;
11791179
AffineExpr expr = it.value();
1180-
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181-
if (!binaryExpr)
1182-
continue;
1183-
1180+
// A local expression cannot be a dimension, symbol or a constant -- it
1181+
// should be a binary op expression.
1182+
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
11841183
AffineExpr lhs = binaryExpr.getLHS();
11851184
AffineExpr rhs = binaryExpr.getRHS();
11861185
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1274,6 +1273,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
12741273
operandExprStack.reserve(8);
12751274
}
12761275

1276+
// Writes the flattened form of `expr` into `result`.
//
// If `expr` is a constant, dimension or symbol expression, `result` is
// overwritten in place and success() is returned. Otherwise the call is
// forwarded to addLocalVariableSemiAffine and its result is returned.
// `lhs` and `rhs` are only used on that forwarding path.
LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList(
1277+
AffineExpr expr, ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
1278+
SmallVectorImpl<int64_t> &result) {
1279+
// Constant: zero every entry, then store the value at the constant index.
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
1280+
std::fill(result.begin(), result.end(), 0);
1281+
result[getConstantIndex()] = constExpr.getValue();
1282+
return success();
1283+
}
1284+
// Dimension: zero every entry, then set a coefficient of 1 at the entry for
// this dimension's position.
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1285+
std::fill(result.begin(), result.end(), 0);
1286+
result[getDimStartIndex() + dimExpr.getPosition()] = 1;
1287+
return success();
1288+
}
1289+
// Symbol: same as the dimension case, offset from the symbol start index.
if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
1290+
std::fill(result.begin(), result.end(), 0);
1291+
result[getSymbolStartIndex() + symExpr.getPosition()] = 1;
1292+
return success();
1293+
}
1294+
// Anything else is handed to addLocalVariableSemiAffine together with the
// operands, using the current size of `result` as the result size.
return addLocalVariableSemiAffine(lhs, rhs, expr, result, result.size());
1295+
}
1296+
12771297
// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
12781298
//
12791299
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
@@ -1295,7 +1315,7 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
12951315
localExprs, context);
12961316
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
12971317
localExprs, context);
1298-
return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1318+
return addExprToFlattenedList(a * b, mulLhs, rhs, lhs);
12991319
}
13001320

13011321
// Get the RHS constant.
@@ -1347,8 +1367,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13471367
lhs, numDims, numSymbols, localExprs, context);
13481368
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
13491369
localExprs, context);
1350-
AffineExpr modExpr = dividendExpr % divisorExpr;
1351-
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
1370+
return addExprToFlattenedList(dividendExpr % divisorExpr, modLhs, rhs, lhs);
13521371
}
13531372

13541373
int64_t rhsConst = rhs[getConstantIndex()];
@@ -1482,7 +1501,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
14821501
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
14831502
localExprs, context);
14841503
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1485-
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
1504+
return addExprToFlattenedList(divExpr, divLhs, rhs, lhs);
14861505
}
14871506

14881507
// This is a pure affine expr; the RHS is a positive constant.

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,32 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?x
608608
// CHECK-NEXT: return %[[C6]], %[[C7]]
609609
return %a, %b : index, index
610610
}
611+
612+
// -----
613+
614+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)>
615+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)>
616+
// CHECK-LABEL: semiaffine_simplification_local_expr_folded_into_non_binary_expr
617+
// Both affine.apply maps below contain the factor `-s1 + s1 * s2 + 1` with the
// constant 1 bound to `s2`. The FileCheck directives expect the first map to
// simplify to `13 mod s0` applied to the memref dimension, and the second to
// `d0 * 2` applied to the loop induction variable.
func.func @semiaffine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index) {
618+
%c0 = arith.constant 0 : index
619+
%c1 = arith.constant 1 : index
620+
%c4 = arith.constant 4 : index
621+
%c13 = arith.constant 13 : index
622+
// CHECK: %[[DIM:.*]] = memref.dim
623+
%dim = memref.dim %arg0, %c0 : memref<?x?xf32>
624+
// CHECK: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[DIM]]]
625+
// Operands: s0 = 13, s1 = s3 = %dim, s2 = 1.
%c = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 mod (s1 + (-s1 + s3) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1, %dim]
626+
%alloc = memref.alloc() : memref<1xindex>
627+
affine.for %iv = 0 to 1 {
628+
// Operands: d0 = %iv, s1 = %dim, s2 = 1.
%d = affine.apply affine_map<(d0)[s1, s2] -> ((d0 - s1 + s1 * s2) * (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>(%iv)[%dim, %c1]
629+
affine.store %d, %alloc[0] : memref<1xindex>
630+
}
631+
// CHECK: affine.for %[[IV:.*]] = 0 to 1 {
632+
// CHECK-NEXT: %[[VAL:.*]] = affine.apply #[[$MAP1]](%[[IV]])
633+
// CHECK-NEXT: affine.store %[[VAL]], %{{.*}}[0] : memref<1xindex>
634+
// CHECK-NEXT: }
635+
// CHECK: %[[VAL1:.*]] = affine.load %{{.*}}[0]
636+
%d = affine.load %alloc[0] : memref<1xindex>
637+
// CHECK: return %[[VAL0]], %[[VAL1]]
638+
return %c, %d : index, index
639+
}

0 commit comments

Comments
 (0)