Skip to content

Commit 965ec6d

Browse files
committed
[mlir] Add folder for shape.add
1 parent e9b1c97 commit 965ec6d

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def Shape_AddOp : Shape_Op<"add",
5555
// InferTypeOpInterface
5656
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
5757
}];
58+
59+
let hasFolder = 1;
5860
}
5961

6062
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#include "mlir/Dialect/Shape/IR/Shape.h"
1010

1111
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12+
#include "mlir/Dialect/CommonFolders.h"
1213
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1314
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1415
#include "mlir/Dialect/Traits.h"
1516
#include "mlir/IR/Builders.h"
1617
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/DialectImplementation.h"
19+
#include "mlir/IR/Matchers.h"
1820
#include "mlir/IR/PatternMatch.h"
1921
#include "mlir/IR/TypeUtilities.h"
2022
#include "mlir/Transforms/InliningUtils.h"
@@ -436,6 +438,15 @@ bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
436438
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
437439
}
438440

441+
OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
442+
// add(x, 0) -> x
443+
if (matchPattern(rhs(), m_Zero()))
444+
return lhs();
445+
446+
return constFoldBinaryOp<IntegerAttr>(operands,
447+
[](APInt a, APInt b) { return a + b; });
448+
}
449+
439450
//===----------------------------------------------------------------------===//
440451
// AssumingAllOp
441452
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,19 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
10201020

10211021
// -----
10221022

1023+
// Fold `add` for constant sizes.
1024+
// CHECK-LABEL: @fold_add_size
1025+
func @fold_add_size() -> !shape.size {
1026+
// CHECK: %[[RESULT:.*]] = shape.const_size 5
1027+
// CHECK: return %[[RESULT]] : !shape.size
1028+
%c2 = shape.const_size 2
1029+
%c3 = shape.const_size 3
1030+
%result = shape.add %c2, %c3 : !shape.size, !shape.size -> !shape.size
1031+
return %result : !shape.size
1032+
}
1033+
1034+
// -----
1035+
10231036
// Fold `mul` for constant sizes.
10241037
// CHECK-LABEL: @fold_mul_size
10251038
func @fold_mul_size() -> !shape.size {

0 commit comments

Comments
 (0)