Skip to content

Commit 454e4e3

Browse files
authored
[mlir][AffineExpr] Order arguments in the commutative affine exprs (#146895)
Order symbol/dim arguments by position and put dims before symbols. This is to help affine simplifications.
1 parent 1f3f987 commit 454e4e3

File tree

4 files changed

+59
-5
lines changed

4 files changed

+59
-5
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,16 +793,45 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
793793
return nullptr;
794794
}
795795

796+
/// Get the canonical order of two commutative exprs arguments.
797+
static std::pair<AffineExpr, AffineExpr>
798+
orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
799+
auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
800+
auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
801+
// Try to order by symbol/dim position first.
802+
if (sym1 && sym2)
803+
return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
804+
: std::pair{expr2, expr1};
805+
806+
auto dim1 = dyn_cast<AffineDimExpr>(expr1);
807+
auto dim2 = dyn_cast<AffineDimExpr>(expr2);
808+
if (dim1 && dim2)
809+
return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
810+
: std::pair{expr2, expr1};
811+
812+
// Put dims before symbols.
813+
if (dim1 && sym2)
814+
return {dim1, sym2};
815+
816+
if (sym1 && dim2)
817+
return {dim2, sym1};
818+
819+
// Otherwise, keep original order.
820+
return {expr1, expr2};
821+
}
822+
796823
AffineExpr AffineExpr::operator+(int64_t v) const {
797824
return *this + getAffineConstantExpr(v, getContext());
798825
}
799826
AffineExpr AffineExpr::operator+(AffineExpr other) const {
800827
if (auto simplified = simplifyAdd(*this, other))
801828
return simplified;
802829

830+
auto [lhs, rhs] = orderCommutativeArgs(*this, other);
831+
803832
StorageUniquer &uniquer = getContext()->getAffineUniquer();
804833
return uniquer.get<AffineBinaryOpExprStorage>(
805-
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
834+
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
806835
}
807836

808837
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
@@ -865,9 +894,11 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
865894
if (auto simplified = simplifyMul(*this, other))
866895
return simplified;
867896

897+
auto [lhs, rhs] = orderCommutativeArgs(*this, other);
898+
868899
StorageUniquer &uniquer = getContext()->getAffineUniquer();
869900
return uniquer.get<AffineBinaryOpExprStorage>(
870-
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
901+
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
871902
}
872903

873904
// Unary minus, delegate to operator*.

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ func.func @test_not_trivially_true_or_false_returning_three_results() -> (index,
508508
// -----
509509

510510
// Test simplification of mod expressions.
511-
// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)>
511+
// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s4 + s3 + (s0 - s1) mod s2)>
512512
// CHECK-DAG: #[[$SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))>
513513
// CHECK-DAG: #[[$MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)>
514514
// CHECK-LABEL: func @semiaffine_simplification_mod
@@ -547,7 +547,7 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: i
547547

548548
// Test simplification of product expressions.
549549
// CHECK-DAG: #[[$PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)>
550-
// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 + s2 * s0 + s3 + s3 * s0 + s3 * s1 + s4 + s4 * s1)>
550+
// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s2 + s0 * s3 + s1 * s3 + s1 * s4 + s2 + s3 + s4)>
551551
// CHECK-LABEL: func @semiaffine_simplification_product
552552
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
553553
func.func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) {

mlir/test/IR/affine-map.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
#map44 = affine_map<(i, j) -> (i - 2*j, j * 6 floordiv 4)>
140140

141141
// Simplifications
142-
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)>
142+
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d1 + d2, (d0 * s0) * 8)>
143143
#map45 = affine_map<(i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)>
144144

145145
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (0, d1, d0 * 2, 0)>

mlir/unittests/IR/AffineExprTest.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ TEST(AffineExprTest, constantFolding) {
8484
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
8585
}
8686

87+
TEST(AffineExprTest, commutative) {
88+
MLIRContext ctx;
89+
OpBuilder b(&ctx);
90+
auto c2 = b.getAffineConstantExpr(1);
91+
auto d0 = b.getAffineDimExpr(0);
92+
auto d1 = b.getAffineDimExpr(1);
93+
auto s0 = b.getAffineSymbolExpr(0);
94+
auto s1 = b.getAffineSymbolExpr(1);
95+
96+
ASSERT_EQ(d0 * d1, d1 * d0);
97+
ASSERT_EQ(s0 + s1, s1 + s0);
98+
ASSERT_EQ(s0 * c2, c2 * s0);
99+
}
100+
87101
TEST(AffineExprTest, divisionSimplification) {
88102
MLIRContext ctx;
89103
OpBuilder b(&ctx);
@@ -147,3 +161,12 @@ TEST(AffineExprTest, simpleAffineExprFlattenerRegression) {
147161
ASSERT_TRUE(isa<AffineConstantExpr>(result));
148162
ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7);
149163
}
164+
165+
TEST(AffineExprTest, simplifyCommutative) {
166+
MLIRContext ctx;
167+
OpBuilder b(&ctx);
168+
auto s0 = b.getAffineSymbolExpr(0);
169+
auto s1 = b.getAffineSymbolExpr(1);
170+
171+
ASSERT_EQ(s0 * s1 - s1 * s0 + 1, 1);
172+
}

0 commit comments

Comments
 (0)