Skip to content

Commit 9792917

Browse files
committed
Lower affine modulo by powers of two using bitwise AND
1 parent bb982e7 commit 9792917

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,24 @@ class AffineApplyExpander
8080
/// let remainder = srem a, b;
8181
/// negative = a < 0 in
8282
/// select negative, remainder + b, remainder.
83+
///
84+
/// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative x.
8385
Value visitModExpr(AffineBinaryOpExpr expr) {
8486
if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
8587
if (rhsConst.getValue() <= 0) {
8688
emitError(loc, "modulo by non-positive value is not supported");
8789
return nullptr;
8890
}
91+
92+
// Special case: x mod n where n is a power of 2 can be optimized to x & (n-1)
93+
int64_t rhsValue = rhsConst.getValue();
94+
if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
95+
auto lhs = visit(expr.getLHS());
96+
assert(lhs && "unexpected affine expr lowering failure");
97+
98+
Value maskCst = builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
99+
return builder.create<arith::AndIOp>(loc, lhs, maskCst);
100+
}
89101
}
90102

91103
auto lhs = visit(expr.getLHS());

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,12 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
927927
// CHECK: scf.reduce.return %[[RES]] : i64
928928
// CHECK: }
929929
// CHECK: }
930+
931+
#map_mod_8 = affine_map<(i) -> (i mod 8)>
932+
// CHECK-LABEL: func @affine_apply_mod_8
933+
func.func @affine_apply_mod_8(%arg0 : index) -> (index) {
934+
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
935+
// CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index
936+
%0 = affine.apply #map_mod_8 (%arg0)
937+
return %0 : index
938+
}

0 commit comments

Comments
 (0)