|
21 | 21 | #include "llvm/ADT/SmallBitVector.h"
|
22 | 22 | #include "llvm/ADT/TypeSwitch.h"
|
23 | 23 | #include "llvm/Support/Debug.h"
|
| 24 | +#include <numeric> |
24 | 25 |
|
25 | 26 | using namespace mlir;
|
26 | 27 |
|
@@ -578,6 +579,169 @@ OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
|
578 | 579 | return result[0];
|
579 | 580 | }
|
580 | 581 |
|
| 582 | +/// Returns the largest known divisor of `e`. Exploits information from the |
| 583 | +/// values in `operands`. |
| 584 | +static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) { |
| 585 | + // This method isn't aware of `operands`. |
| 586 | + int64_t div = e.getLargestKnownDivisor(); |
| 587 | + |
| 588 | + // We now make use of operands for the case `e` is a dim expression. |
| 589 | + // TODO: More powerful simplification would have to modify |
| 590 | + // getLargestKnownDivisor to take `operands` and exploit that information as |
| 591 | + // well for dim/sym expressions, but in that case, getLargestKnownDivisor |
| 592 | + // can't be part of the IR library but of the `Analysis` library. The IR |
| 593 | + // library can only really depend on simple O(1) checks. |
| 594 | + auto dimExpr = e.dyn_cast<AffineDimExpr>(); |
| 595 | + // If it's not a dim expr, `div` is the best we have. |
| 596 | + if (!dimExpr) |
| 597 | + return div; |
| 598 | + |
| 599 | + // We simply exploit information from loop IVs. |
| 600 | + // We don't need to use mlir::getLargestKnownDivisorOfValue since the other |
| 601 | + // desired simplifications are expected to be part of other |
| 602 | + // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the |
| 603 | + // LoopAnalysis library. |
| 604 | + Value operand = operands[dimExpr.getPosition()]; |
| 605 | + int64_t operandDivisor = 1; |
| 606 | + // TODO: With the right accessors, this can be extended to |
| 607 | + // LoopLikeOpInterface. |
| 608 | + if (AffineForOp forOp = getForInductionVarOwner(operand)) { |
| 609 | + if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) { |
| 610 | + operandDivisor = forOp.getStep(); |
| 611 | + } else { |
| 612 | + uint64_t lbLargestKnownDivisor = |
| 613 | + forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs(); |
| 614 | + operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStep()); |
| 615 | + } |
| 616 | + } |
| 617 | + return operandDivisor; |
| 618 | +} |
| 619 | + |
| 620 | +/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e` |
| 621 | +/// being an affine dim expression or a constant. |
| 622 | +static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands, |
| 623 | + int64_t k) { |
| 624 | + if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) { |
| 625 | + int64_t constVal = constExpr.getValue(); |
| 626 | + return constVal >= 0 && constVal < k; |
| 627 | + } |
| 628 | + auto dimExpr = e.dyn_cast<AffineDimExpr>(); |
| 629 | + if (!dimExpr) |
| 630 | + return false; |
| 631 | + Value operand = operands[dimExpr.getPosition()]; |
| 632 | + // TODO: With the right accessors, this can be extended to |
| 633 | + // LoopLikeOpInterface. |
| 634 | + if (AffineForOp forOp = getForInductionVarOwner(operand)) { |
| 635 | + if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 && |
| 636 | + forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) { |
| 637 | + return true; |
| 638 | + } |
| 639 | + } |
| 640 | + |
| 641 | + // We don't consider other cases like `operand` being defined by a constant or |
| 642 | + // an affine.apply op since such cases will already be handled by other |
| 643 | + // patterns and propagation of loop IVs or constant would happen. |
| 644 | + return false; |
| 645 | +} |
| 646 | + |
| 647 | +/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d. |
| 648 | +/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the |
| 649 | +/// expression is in that form. |
| 650 | +static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div, |
| 651 | + AffineExpr "ientTimesDiv, AffineExpr &rem) { |
| 652 | + auto bin = e.dyn_cast<AffineBinaryOpExpr>(); |
| 653 | + if (!bin || bin.getKind() != AffineExprKind::Add) |
| 654 | + return false; |
| 655 | + |
| 656 | + AffineExpr llhs = bin.getLHS(); |
| 657 | + AffineExpr rlhs = bin.getRHS(); |
| 658 | + div = getLargestKnownDivisor(llhs, operands); |
| 659 | + if (isNonNegativeBoundedBy(rlhs, operands, div)) { |
| 660 | + quotientTimesDiv = llhs; |
| 661 | + rem = rlhs; |
| 662 | + return true; |
| 663 | + } |
| 664 | + div = getLargestKnownDivisor(rlhs, operands); |
| 665 | + if (isNonNegativeBoundedBy(llhs, operands, div)) { |
| 666 | + quotientTimesDiv = rlhs; |
| 667 | + rem = llhs; |
| 668 | + return true; |
| 669 | + } |
| 670 | + return false; |
| 671 | +} |
| 672 | + |
| 673 | +/// Simplify `expr` while exploiting information from the values in `operands`. |
| 674 | +static void simplifyExprAndOperands(AffineExpr &expr, |
| 675 | + ArrayRef<Value> operands) { |
| 676 | + // We do this only for certain floordiv/mod expressions. |
| 677 | + auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>(); |
| 678 | + if (!binExpr) |
| 679 | + return; |
| 680 | + |
| 681 | + // Simplify the child expressions first. |
| 682 | + auto lhs = binExpr.getLHS(); |
| 683 | + auto rhs = binExpr.getRHS(); |
| 684 | + simplifyExprAndOperands(lhs, operands); |
| 685 | + simplifyExprAndOperands(rhs, operands); |
| 686 | + expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs); |
| 687 | + |
| 688 | + binExpr = expr.dyn_cast<AffineBinaryOpExpr>(); |
| 689 | + if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv && |
| 690 | + binExpr.getKind() != AffineExprKind::Mod)) { |
| 691 | + return; |
| 692 | + } |
| 693 | + |
| 694 | + auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); |
| 695 | + if (!rhsConst) |
| 696 | + return; |
| 697 | + |
| 698 | + int64_t rhsConstVal = rhsConst.getValue(); |
| 699 | + AffineExpr quotientTimesDiv, rem; |
| 700 | + int64_t divisor; |
| 701 | + |
| 702 | + // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2) |
| 703 | + // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if |
| 704 | + // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c. |
| 705 | + // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c. |
| 706 | + if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) { |
| 707 | + if (rhsConstVal % divisor == 0 && |
| 708 | + binExpr.getKind() == AffineExprKind::FloorDiv) { |
| 709 | + expr = quotientTimesDiv.floorDiv(rhsConst); |
| 710 | + } else if (divisor % rhsConstVal == 0 && |
| 711 | + binExpr.getKind() == AffineExprKind::Mod) { |
| 712 | + expr = rem % rhsConst; |
| 713 | + } |
| 714 | + return; |
| 715 | + } |
| 716 | + |
| 717 | + // Handle the simple case when the LHS expression can be either upper |
| 718 | + // bounded or is a known multiple of RHS constant. |
| 719 | + // lhs floordiv c -> 0 if 0 <= lhs < c, |
| 720 | + // lhs mod c -> 0 if lhs % c = 0. |
| 721 | + if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) && |
| 722 | + binExpr.getKind() == AffineExprKind::FloorDiv) || |
| 723 | + (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 && |
| 724 | + binExpr.getKind() == AffineExprKind::Mod)) { |
| 725 | + expr = getAffineConstantExpr(0, expr.getContext()); |
| 726 | + } |
| 727 | +} |
| 728 | + |
| 729 | +/// Simplify the map while exploiting information on the values in `operands`. |
| 730 | +// Use "unused attribute" marker to silence warning stemming from the inability |
| 731 | +// to see through the template expansion. |
| 732 | +static void LLVM_ATTRIBUTE_UNUSED |
| 733 | +simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) { |
| 734 | + assert(map.getNumInputs() == operands.size() && "invalid operands for map"); |
| 735 | + SmallVector<AffineExpr> newResults; |
| 736 | + newResults.reserve(map.getNumResults()); |
| 737 | + for (AffineExpr expr : map.getResults()) { |
| 738 | + simplifyExprAndOperands(expr, operands); |
| 739 | + newResults.push_back(expr); |
| 740 | + } |
| 741 | + map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults, |
| 742 | + map.getContext()); |
| 743 | +} |
| 744 | + |
581 | 745 | /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
|
582 | 746 | /// defining AffineApplyOp expression and operands.
|
583 | 747 | /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
|
@@ -1095,6 +1259,7 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
|
1095 | 1259 | SmallVector<Value, 8> resultOperands(oldOperands);
|
1096 | 1260 | composeAffineMapAndOperands(&map, &resultOperands);
|
1097 | 1261 | canonicalizeMapAndOperands(&map, &resultOperands);
|
| 1262 | + simplifyMapWithOperands(map, resultOperands); |
1098 | 1263 | if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
|
1099 | 1264 | resultOperands.begin()))
|
1100 | 1265 | return failure();
|
|
0 commit comments