Skip to content

Commit 3b947a9

Browse files
committed
[DisablFoldePattern][Arith][Vector] Temporarily disable fold patterns for Habana use-case.
1 parent 0ef39a8 commit 3b947a9

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,8 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
16881688
}
16891689

16901690
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1691+
// @mshahneo: Temoporarily disabling folding for index_cast.
1692+
return OpFoldResult();
16911693
// index_cast(constant) -> constant
16921694
unsigned resultBitwidth = 64; // Default for index integer attributes.
16931695
if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -2365,12 +2367,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
23652367

23662368
// Constant-fold constant operands over non-splat constant condition.
23672369
// select %cst_vec, %cst0, %cst1 => %cst2
2368-
if (auto cond =
2369-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2370-
if (auto lhs =
2371-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2372-
if (auto rhs =
2373-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2370+
if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
2371+
adaptor.getCondition())) {
2372+
if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
2373+
adaptor.getTrueValue())) {
2374+
if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
2375+
adaptor.getFalseValue())) {
23742376
SmallVector<Attribute> results;
23752377
results.reserve(static_cast<size_t>(cond.getNumElements()));
23762378
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2638,7 +2640,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
26382640
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
26392641
case AtomicRMWKind::minimumf:
26402642
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2641-
case AtomicRMWKind::maxnumf:
2643+
case AtomicRMWKind::maxnumf:
26422644
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
26432645
case AtomicRMWKind::minnumf:
26442646
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,9 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) {
17481748

17491749
// Fold extractOp with source coming from ShapeCast op.
17501750
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1751+
// @mshahneo: This is a temporary implementation to disable fold extract from
1752+
// shapecast
1753+
return Value();
17511754
// TODO: Canonicalization for dynamic position not implemented yet.
17521755
if (extractOp.hasDynamicPosition())
17531756
return Value();

0 commit comments

Comments
 (0)