diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index b96c0360316ad..39b36587e559b 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -3624,11 +3624,11 @@ def fir_LocalitySpecifierOp : fir_Op<"local", [IsolatedFromAbove]> { attr-dict }]; - let builders = [ - OpBuilder<(ins CArg<"mlir::TypeRange">:$result, - CArg<"mlir::StringAttr">:$sym_name, - CArg<"mlir::TypeAttr">:$type)> - ]; + // let builders = [ + // OpBuilder<(ins CArg<"mlir::TypeRange">:$result, + // CArg<"mlir::StringAttr">:$sym_name, + // CArg<"mlir::TypeAttr">:$type)> + // ]; let extraClassDeclaration = [{ mlir::BlockArgument getInitMoldArg() { diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index 687007d957225..cef95a13c5a50 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" diff --git a/flang/lib/Optimizer/Dialect/FIRCG/CGOps.cpp b/flang/lib/Optimizer/Dialect/FIRCG/CGOps.cpp index 19ad6bed512c7..d7d294ab621d8 100644 --- a/flang/lib/Optimizer/Dialect/FIRCG/CGOps.cpp +++ b/flang/lib/Optimizer/Dialect/FIRCG/CGOps.cpp @@ -12,6 +12,7 @@ #include "flang/Optimizer/Dialect/FIRCG/CGOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index ecfa2939e96a6..2381d44e7bb77 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -26,6 +26,7 @@ #include 
"mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index df6ce12215d26..7f2ec418af40a 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/APInt.h" diff --git a/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp b/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp index 1ea69f9059321..3da966ef9cc85 100644 --- a/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp +++ b/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp @@ -8,6 +8,7 @@ #include "Standalone/StandaloneDialect.h" #include "Standalone/StandaloneOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "Standalone/StandaloneTypes.h" using namespace mlir; diff --git a/mlir/examples/standalone/lib/Standalone/StandaloneOps.cpp b/mlir/examples/standalone/lib/Standalone/StandaloneOps.cpp index 55b66b51232f2..ef20049f1b76b 100644 --- a/mlir/examples/standalone/lib/Standalone/StandaloneOps.cpp +++ b/mlir/examples/standalone/lib/Standalone/StandaloneOps.cpp @@ -8,6 +8,7 @@ #include "Standalone/StandaloneOps.h" #include "Standalone/StandaloneDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #define GET_OP_CLASSES #include "Standalone/StandaloneOps.cpp.inc" diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index 489f348c8be52..fc4c3c75dcfda 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ 
-16,6 +16,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" @@ -144,7 +145,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -280,7 +282,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 708855f18cf45..af3cbb385d332 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" @@ -144,7 +145,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. 
- auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -280,7 +282,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 076a75a26619b..e3ab656c193a3 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" @@ -206,7 +207,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -395,7 +397,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. 
- if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index fb7c742a01802..81c9e8f9a143f 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" @@ -206,7 +207,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -395,7 +397,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. 
- if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index fb7c742a01802..81c9e8f9a143f 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" @@ -206,7 +207,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -395,7 +397,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. 
- if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 52881db87d86b..64aa4d8e8995d 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" @@ -429,7 +430,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp index b4b27e97d266e..e24714b408e29 100644 --- a/mlir/examples/transform/Ch2/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp index 4b2123fa71d31..5af5cc8ac5052 100644 --- a/mlir/examples/transform/Ch3/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include 
"mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp index fa0ffc9dc2e8a..86884d3bee7a0 100644 --- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp @@ -13,6 +13,7 @@ #include "MyExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE_MATCHER "transform-matcher" diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h index 2091faa6b0b02..333de6bbd8a05 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -114,6 +114,21 @@ class AffineDmaStartOp AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride = nullptr, Value elementsPerStride = nullptr); + static AffineDmaStartOp + create(OpBuilder &builder, Location location, Value srcMemRef, + AffineMap srcMap, ValueRange srcIndices, Value destMemRef, + AffineMap dstMap, ValueRange destIndices, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, Value numElements, + Value stride = nullptr, Value elementsPerStride = nullptr); + + static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef, + AffineMap srcMap, ValueRange srcIndices, + Value destMemRef, AffineMap dstMap, + ValueRange destIndices, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, + Value numElements, Value stride = nullptr, + Value elementsPerStride = nullptr); + /// Returns the operand index of the source memref. 
unsigned getSrcMemRefOperandIndex() { return 0; } @@ -319,6 +334,12 @@ class AffineDmaWaitOp static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements); + static AffineDmaWaitOp create(OpBuilder &builder, Location location, + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements); + static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, + Value numElements); static StringRef getOperationName() { return "affine.dma_wait"; } diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 7c50c2036ffdc..0fc3db8e993d8 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -59,15 +59,27 @@ class ConstantIntOp : public arith::ConstantOp { /// Build a constant int op that produces an integer of the specified width. static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width); + static ConstantIntOp create(OpBuilder &builder, Location location, + int64_t value, unsigned width); + static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value, + unsigned width); /// Build a constant int op that produces an integer of the specified type, /// which must be an integer type. 
static void build(OpBuilder &builder, OperationState &result, Type type, int64_t value); + static ConstantIntOp create(OpBuilder &builder, Location location, Type type, + int64_t value); + static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type, + int64_t value); /// Build a constant int op that produces an integer from an APInt static void build(OpBuilder &builder, OperationState &result, Type type, const APInt &value); + static ConstantIntOp create(OpBuilder &builder, Location location, Type type, + const APInt &value); + static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type, + const APInt &value); inline int64_t value() { return cast(arith::ConstantOp::getValue()).getInt(); @@ -85,6 +97,10 @@ class ConstantFloatOp : public arith::ConstantOp { /// Build a constant float op that produces a float of the specified type. static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value); + static ConstantFloatOp create(OpBuilder &builder, Location location, + FloatType type, const APFloat &value); + static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type, + const APFloat &value); inline APFloat value() { return cast(arith::ConstantOp::getValue()).getValue(); @@ -100,6 +116,9 @@ class ConstantIndexOp : public arith::ConstantOp { static ::mlir::TypeID resolveTypeID() { return TypeID::get(); } /// Build a constant int op that produces an index. 
static void build(OpBuilder &builder, OperationState &result, int64_t value); + static ConstantIndexOp create(OpBuilder &builder, Location location, + int64_t value); + static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value); inline int64_t value() { return cast(arith::ConstantOp::getValue()).getInt(); diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 4360055e78691..c5c984e09bf67 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -582,9 +582,9 @@ def LinalgCopyToMemrefOp : let assemblyFormat = "$target attr-dict `:` " "functional-type(operands, results) "; - let builders = [ - OpBuilder<(ins "Value":$target)>, - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)>, + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, @@ -997,12 +997,12 @@ def PackGreedilyOp : Op]>:$matmul_inner_dims_order); let results = (outs TransformHandleTypeInterface:$packed_op); - let builders = [ - OpBuilder<(ins "Value":$target, - "ArrayRef":$mixedMatmulPackedSizes, - "ArrayRef":$matmulPaddededSizesNextMultipleOf, - CArg<"ArrayRef", "{}">:$matmulDimsInnerDimsOrder)> - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target, + // "ArrayRef":$mixedMatmulPackedSizes, + // "ArrayRef":$matmulPaddededSizesNextMultipleOf, + // CArg<"ArrayRef", "{}">:$matmulDimsInnerDimsOrder)> + // ]; let assemblyFormat = [{ $target @@ -2509,10 +2509,10 @@ def HoistRedundantVectorTransfersOp : let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; - let builders = [ - OpBuilder<(ins "Value":$target, - CArg<"bool", "false">:$verify_non_zero_trip)>, - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target, + // CArg<"bool", "false">:$verify_non_zero_trip)>, + 
// ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, @@ -2546,9 +2546,9 @@ def HoistRedundantVectorBroadcastsOp : let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; - let builders = [ - OpBuilder<(ins "Value":$target)>, - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)>, + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, @@ -2623,9 +2623,9 @@ def ConvertConv2DToImg2ColOp : Op - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)> + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -2666,9 +2666,9 @@ def FlattenElementwiseLinalgOp : Op - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)> + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -2715,9 +2715,9 @@ def TransposeConv2DOp : Op - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)> + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -2761,9 +2761,9 @@ def TransposeMatmulOp : Op - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)> + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -2801,9 +2801,9 @@ def InsertSliceToCopyOp : let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; - let builders = [ - OpBuilder<(ins "Value":$target)>, - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)>, + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, @@ -2859,9 +2859,9 @@ def MapCopyToThreadsOp : `:` functional-type(operands, results) }]; - let builders = [ - OpBuilder<(ins "Value":$target)>, - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)>, + // ]; let extraClassDeclaration = [{ 
::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, @@ -2910,9 +2910,9 @@ def WinogradConv2DOp : Op - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)> + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -2947,9 +2947,9 @@ def DecomposeWinogradOp : Op - ]; + // let builders = [ + // OpBuilder<(ins "Value":$target)> + // ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ac80926053a2d..118877ea145fd 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1751,9 +1751,9 @@ def DeclareMapperInfoOp : OpenMP_Op<"declare_mapper.info", [ parent DeclareMapperOp. }] # clausesDescription; - let builders = [ - OpBuilder<(ins CArg<"const DeclareMapperInfoOperands &">:$clauses)> - ]; + // let builders = [ + // OpBuilder<(ins CArg<"const DeclareMapperInfoOperands &">:$clauses)> + // ]; let extraClassDeclaration = [{ // Override BlockArgOpenMPOpInterface method because `map` clauses have no diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 1523762efc18f..7fecf89605f06 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -171,9 +171,9 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> { let arguments = (ins TypeAttr:$elementType); let results = (outs AnySignlessIntegerOrIndex:$result); - let builders = [ - OpBuilder<(ins "Type":$elementType)> - ]; + // let builders = [ + // OpBuilder<(ins "Type":$elementType)> + // ]; let assemblyFormat = [{ $elementType attr-dict `:` type($result) }]; diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td index ea94dfd8fbd2a..5a919395c0954 100644 --- 
a/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td @@ -54,9 +54,9 @@ class VariadicIntOp : SMTIntOp { let assemblyFormat = "$inputs attr-dict"; let builders = [ - OpBuilder<(ins "mlir::ValueRange":$inputs), [{ - build($_builder, $_state, $_builder.getType(), inputs); - }]>, + // OpBuilder<(ins "mlir::ValueRange":$inputs), [{ + // build($_builder, $_state, $_builder.getType(), inputs); + // }]>, ]; } diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index ec2c87ca1cf44..5d0c34ff8da34 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -3011,12 +3011,12 @@ def Vector_ScanOp : vector<4x8x16x32xf32>, vector<4x16x32xf32> ``` }]; - let builders = [ - OpBuilder<(ins "Value":$source, "Value":$initial_value, - "CombiningKind":$kind, - CArg<"int64_t", "0">:$reduction_dim, - CArg<"bool", "true">:$inclusive)> - ]; + // let builders = [ + // OpBuilder<(ins "Value":$source, "Value":$initial_value, + // "CombiningKind":$kind, + // CArg<"int64_t", "0">:$reduction_dim, + // CArg<"bool", "true">:$inclusive)> + // ]; let extraClassDeclaration = [{ VectorType getSourceType() { return ::llvm::cast(getSource().getType()); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 663c256c848df..75c3aea0792ac 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -30,6 +30,7 @@ namespace mlir { class Builder; class OpBuilder; +class ImplicitLocOpBuilder; /// This class implements `Optional` functionality for ParseResult. 
We don't /// directly use Optional here, because it provides an implicit conversion diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index f750a34a3b2ba..69cefbbc43e0a 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -71,6 +71,8 @@ class MethodParameter { StringRef getName() const { return name; } /// Returns true if the parameter has a default value. bool hasDefaultValue() const { return !defaultValue.empty(); } + StringRef getDefaultValue() const { return defaultValue; } + bool isOptional() const { return optional; } private: /// The C++ type. diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp index bd8b13c6516e2..d2a83da643a89 100644 --- a/mlir/lib/CAPI/Dialect/PDL.cpp +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" using namespace mlir; diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 910fe1b1d93c1..34e859bc9c86e 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, if (i32 == valTy) return val; return valTy.getWidth() > 32 - ? Value(rewriter.create(loc, i32, val)) - : Value(rewriter.create(loc, i32, val)); + ? 
Value(LLVM::TruncOp::create(rewriter, loc, i32, val)) + : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val)); } static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value) { Type i32 = rewriter.getI32Type(); - return rewriter.create(loc, i32, value); + return LLVM::ConstantOp::create(rewriter, loc, i32, value); } static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value) { Type llvmI1 = rewriter.getI1Type(); - return rewriter.create(loc, llvmI1, value); + return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); } /// Returns the linear index used to access an element in the memref. @@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, ShapedType::isDynamic(stride) ? convertUnsignedToI32(rewriter, loc, memRefDescriptor.stride(rewriter, loc, i)) - : rewriter.create(loc, i32, stride); - increment = rewriter.create(loc, increment, strideValue); + : LLVM::ConstantOp::create(rewriter, loc, i32, stride); + increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue); } index = - index ? rewriter.create(loc, index, increment) : increment; + index ? LLVM::AddOp::create(rewriter, loc, index, increment) : increment; } return index ? index : createI32Constant(rewriter, loc, 0); } @@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); - Value maxThisDim = rewriter.create(loc, size, stride); + Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); maxIndex = maxIndex - ? rewriter.create(loc, maxIndex, maxThisDim) + ? 
LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) : maxThisDim; } Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); - return rewriter.create(loc, maxIndexI32, byteWidthConst); + return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst); } static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, @@ -132,13 +132,13 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value stride; if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { Value cacheStrideZext = - rewriter.create(loc, i16, cacheSwizzleStride); - Value swizzleBit = rewriter.create( + LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); + Value swizzleBit = LLVM::ConstantOp::create(rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); - stride = rewriter.create(loc, cacheStrideZext, swizzleBit, + stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, /*isDisjoint=*/true); } else { - stride = rewriter.create(loc, i16, + stride = LLVM::ConstantOp::create(rewriter, loc, i16, rewriter.getI16IntegerAttr(0)); } // Get the number of elements. @@ -209,18 +209,18 @@ struct FatRawBufferCastLowering : descriptor.alignedPtr(rewriter, loc); Value offset = adaptor.getResetOffset() - ? rewriter.create( + ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(0)) : descriptor.offset(rewriter, loc); bool hasSizes = memrefType.getRank() > 0; // No need to unpack() and pack() all the individual sizes and strides, // so we'll just extract the arrays. - Value sizes = hasSizes ? rewriter.create( + Value sizes = hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, kSizePosInMemRefDescriptor) : Value{}; Value strides = hasSizes - ? rewriter.create( + ? 
LLVM::ExtractValueOp::create(rewriter, loc, descriptor, kStridePosInMemRefDescriptor) : Value{}; @@ -231,16 +231,16 @@ struct FatRawBufferCastLowering Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); - result = rewriter.create( + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor); - result = rewriter.create( + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); - result = rewriter.create(loc, result, offset, + result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, kOffsetPosInMemRefDescriptor); if (hasSizes) { - result = rewriter.create(loc, result, sizes, + result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, kSizePosInMemRefDescriptor); - result = rewriter.create( + result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, kStridePosInMemRefDescriptor); } rewriter.replaceOp(op, result); @@ -343,7 +343,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { if (storeData) { if (llvmBufferValType != llvmWantedDataType) { Value castForStore = - rewriter.create(loc, llvmBufferValType, storeData); + LLVM::BitcastOp::create(rewriter, loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); @@ -352,7 +352,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForCmp = rewriter.create( + Value castForCmp = LLVM::BitcastOp::create(rewriter, loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { @@ -383,17 +383,17 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { indexOffset && *indexOffset > 0) { Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); voffset = - voffset ? rewriter.create(loc, voffset, extraOffsetConst) + voffset ? 
LLVM::AddOp::create(rewriter, loc, voffset, extraOffsetConst) : extraOffsetConst; } - voffset = rewriter.create(loc, voffset, byteWidthConst); + voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); args.push_back(voffset); // SGPR offset. Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); - sgprOffset = rewriter.create(loc, sgprOffset, byteWidthConst); + sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) @@ -403,12 +403,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { llvm::SmallVector resultTypes(gpuOp->getNumResults(), llvmBufferValType); - Operation *lowered = rewriter.create(loc, resultTypes, args, + Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, ArrayRef()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { - replacement = rewriter.create(loc, llvmWantedDataType, + replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, replacement); } rewriter.replaceOp(gpuOp, replacement); @@ -465,12 +465,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { << chipset.majorVersion; Location loc = op->getLoc(); - rewriter.create(loc, ldsOnlyBits); + ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits); rewriter.replaceOpWithNewOp(op); } else { Location loc = op->getLoc(); - rewriter.create(loc, 0); - rewriter.create(loc, -1); + ROCDL::WaitDscntOp::create(rewriter, loc, 0); + ROCDL::BarrierSignalOp::create(rewriter, loc, -1); rewriter.replaceOpWithNewOp(op, -1); } @@ -516,18 +516,18 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); if (auto vectorType = dyn_cast(inputType)) { if (vectorType.getElementType().isBF16() && !allowBf16) - return rewriter.create( + return 
LLVM::BitcastOp::create(rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && vectorType.getNumElements() <= 8) - return rewriter.create( + return LLVM::BitcastOp::create(rewriter, loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); if (isa(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); - return rewriter.create( + return LLVM::BitcastOp::create(rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input); } } @@ -549,8 +549,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); Type outputType = rewriter.getI32Type(); if (auto intType = dyn_cast(inputType)) - return rewriter.create(loc, outputType, input); - return rewriter.create(loc, outputType, input); + return LLVM::ZExtOp::create(rewriter, loc, outputType, input); + return LLVM::BitcastOp::create(rewriter, loc, outputType, input); } /// Push an input operand. If it is a float type, nothing to do. If it is @@ -576,7 +576,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - llvmInput = rewriter.create( + llvmInput = LLVM::BitcastOp::create(rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); @@ -613,7 +613,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. // Add in the zeros here. 
if (numBits < 32) - castInput = rewriter.create(loc, i32, castInput); + castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); operands.push_back(castInput); } @@ -633,7 +633,7 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, auto vectorType = dyn_cast(inputType); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - output = rewriter.create( + output = LLVM::BitcastOp::create(rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { @@ -992,7 +992,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern { }; Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) - lowered = rewriter.create(loc, outType, lowered); + lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } @@ -1093,7 +1093,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { Operation *maybeCastBack = lowered; if (rawOutType != outType) maybeCastBack = - rewriter.create(loc, outType, lowered->getResult(0)); + LLVM::BitcastOp::create(rewriter, loc, outType, lowered->getResult(0)); rewriter.replaceOp(op, maybeCastBack->getResults()); return success(); @@ -1144,21 +1144,21 @@ struct TransposeLoadOpLowering case 4: { assert(numElements == 16); auto rocdlOp = - rewriter.create(loc, rocdlResultType, srcPtr); + ROCDL::ds_read_tr4_b64::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 6: { assert(numElements == 16); auto rocdlOp = - rewriter.create(loc, rocdlResultType, srcPtr); + ROCDL::ds_read_tr6_b96::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 8: { assert(numElements == 8); auto rocdlOp = - rewriter.create(loc, rocdlResultType, srcPtr); + ROCDL::ds_read_tr8_b64::create(rewriter, loc, 
rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } @@ -1311,21 +1311,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { - Value longVec = rewriter.create(loc, v4i8); + Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); if (!sourceVecType) { - longVec = rewriter.create( + longVec = LLVM::InsertElementOp::create(rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (resultVecType) { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, @@ -1377,21 +1377,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( // Extend to a packedVectorType if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { - Value longVec = rewriter.create(loc, packedVecType); + Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType); if (!sourceVecType) { - longVec = rewriter.create( + longVec = LLVM::InsertElementOp::create(rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create(loc, 
longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (isa(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp( @@ -1449,53 +1449,53 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( Value scale = adaptor.getScale(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create(loc, intResultType, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); else - existing = rewriter.create(loc, intResultType); + existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); if (sourceVecType.getNumElements() < 2) { Value c0 = createI32Constant(rewriter, loc, 0); - Value elem0 = rewriter.create(loc, source, c0); + Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); VectorType v2 = VectorType::get(2, sourceElemType); - source = rewriter.create(loc, v2); - source = rewriter.create(loc, source, elem0, c0); + source = LLVM::ZeroOp::create(rewriter, loc, v2); + source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); } Value sourceA, sourceB; if (sourceElemType.isF32()) { Value c0 = createI32Constant(rewriter, loc, 0); Value c1 = createI32Constant(rewriter, loc, 1); - sourceA = rewriter.create(loc, source, c0); - sourceB = rewriter.create(loc, source, c1); + sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); + sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); } Value result; if (sourceElemType.isF32() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkBf8F16Op::create(rewriter, loc, 
intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkFp8F16Op::create(rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkFp4F16Op::create(rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) - result = rewriter.create( + result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else return failure(); @@ -1521,19 +1521,19 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( Value sourceA = adaptor.getSourceA(); Value sourceB = adaptor.getSourceB(); if (!sourceB) - sourceB = rewriter.create(loc, sourceA.getType()); + sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType()); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - 
existing = rewriter.create(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create(loc, i32, sourceA, sourceB, + result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB, existing, op.getWordIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create(loc, i32, sourceA, sourceB, + result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB, existing, op.getWordIndex()); result = rewriter.replaceOpWithNewOp( @@ -1558,16 +1558,16 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( Value stoch = adaptor.getStochiasticParam(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create( + result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch, existing, op.getStoreIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create( + result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch, existing, op.getStoreIndex()); result = rewriter.replaceOpWithNewOp( @@ -1612,14 +1612,14 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern { if (operandType.getIntOrFloatBitWidth() <= 16) { if (llvm::isa(operandType)) { operand = - rewriter.create(loc, llvmSrcIntType, operand); + LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand); } auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); - Value undefVec = rewriter.create(loc, llvmVecType); - operand = rewriter.create( + Value undefVec = LLVM::UndefOp::create(rewriter, loc, 
llvmVecType); + operand = LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); - operand = rewriter.create(loc, llvmType, operand); + operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand); } return operand; }; @@ -1706,14 +1706,14 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern { bool boundCtrl = DppOp->getAttrOfType("bound_ctrl").getValue(); // create a ROCDL_DPPMovOp instruction with the appropriate attributes - auto dppMovOp = rewriter.create( + auto dppMovOp = ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl); Value result = dppMovOp.getRes(); if (srcType.getIntOrFloatBitWidth() < 32) { - result = rewriter.create(loc, llvmSrcIntType, result); + result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result); if (!llvm::isa(srcType)) { - result = rewriter.create(loc, srcType, result); + result = LLVM::BitcastOp::create(rewriter, loc, srcType, result); } } @@ -1747,7 +1747,7 @@ struct AMDGPUSwizzleBitModeLowering SmallVector swizzled; for (Value v : decomposed) { Value res = - rewriter.create(loc, v.getType(), v, maskValue); + ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue); swizzled.emplace_back(res); } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 3b143ca1ef9ce..5a9d946c2fe4e 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc, Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { if (predicate == arith::CmpIPredicate::sgt) - value = builder.create(loc, value, *valueIt); + value = arith::MaxSIOp::create(builder, loc, value, *valueIt); else - value = builder.create(loc, value, *valueIt); + value = arith::MinSIOp::create(builder, loc, value, *valueIt); } 
return value; @@ -154,8 +154,8 @@ class AffineForLowering : public OpRewritePattern { Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); Value step = - rewriter.create(loc, op.getStepAsInt()); - auto scfForOp = rewriter.create(loc, lowerBound, upperBound, + arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt()); + auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, op.getInits()); rewriter.eraseBlock(scfForOp.getBody()); rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(), @@ -197,7 +197,7 @@ class AffineParallelLowering : public OpRewritePattern { } steps.reserve(op.getSteps().size()); for (int64_t step : op.getSteps()) - steps.push_back(rewriter.create(loc, step)); + steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step)); // Get the terminator op. auto affineParOpTerminator = @@ -205,7 +205,7 @@ class AffineParallelLowering : public OpRewritePattern { scf::ParallelOp parOp; if (op.getResults().empty()) { // Case with no reduction operations/return values. 
- parOp = rewriter.create(loc, lowerBoundTuple, + parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, upperBoundTuple, steps, /*bodyBuilderFn=*/nullptr); rewriter.eraseBlock(parOp.getBody()); @@ -233,7 +233,7 @@ class AffineParallelLowering : public OpRewritePattern { identityVals.push_back( arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); } - parOp = rewriter.create( + parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, /*bodyBuilderFn=*/nullptr); @@ -261,7 +261,7 @@ class AffineParallelLowering : public OpRewritePattern { Value reductionResult = arith::getReductionOp( reductionOpValue, rewriter, loc, reductionBody.getArgument(0), reductionBody.getArgument(1)); - rewriter.create(loc, reductionResult); + scf::ReduceReturnOp::create(rewriter, loc, reductionResult); } rewriter.replaceOp(op, parOp.getResults()); return success(); @@ -278,7 +278,7 @@ class AffineIfLowering : public OpRewritePattern { // Now we just have to handle the condition logic. auto integerSet = op.getIntegerSet(); - Value zeroConstant = rewriter.create(loc, 0); + Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector operands(op.getOperands()); auto operandsRef = llvm::ArrayRef(operands); @@ -298,17 +298,17 @@ class AffineIfLowering : public OpRewritePattern { auto pred = isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge; Value cmpVal = - rewriter.create(loc, pred, affResult, zeroConstant); + arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant); cond = cond - ? rewriter.create(loc, cond, cmpVal).getResult() + ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult() : cmpVal; } cond = cond ? 
cond - : rewriter.create(loc, /*value=*/1, + : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1); bool hasElseRegion = !op.getElseRegion().empty(); - auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, + auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond, hasElseRegion); rewriter.inlineRegionBefore(op.getThenRegion(), &ifOp.getThenRegion().back()); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 3596b3235a631..d0155b12a6ec2 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -89,9 +89,9 @@ static Value castF32To(Type desType, Value f32, Location loc, if (elementType.isF32()) return f32; if (elementType.getIntOrFloatBitWidth() < 32) - return rewriter.create(loc, desType, f32); + return arith::TruncFOp::create(rewriter, loc, desType, f32); if (elementType.getIntOrFloatBitWidth() > 32) - return rewriter.create(loc, desType, f32); + return arith::ExtFOp::create(rewriter, loc, desType, f32); llvm_unreachable("The only 32-bit float type is f32"); } @@ -113,7 +113,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Type outElemType = getElementTypeOrSelf(op.getOut().getType()); VectorType extResType = VectorType::get(2, rewriter.getF32Type()); if (!inVecType) { - Value asFloat = rewriter.create( + Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter, loc, rewriter.getF32Type(), in, 0); Value result = castF32To(outElemType, asFloat, loc, rewriter); rewriter.replaceOp(op, result); @@ -121,7 +121,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, } int64_t numElements = inVecType.getNumElements(); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); VectorType outType = cast(op.getOut().getType()); @@ -129,10 +129,10 @@ 
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Value zerodSplat = rewriter.createOrFold(loc, outType, zero); Value scalarIn = - rewriter.create(loc, in, ArrayRef{}); + vector::ExtractOp::create(rewriter, loc, in, ArrayRef{}); Value scalarExt = - rewriter.create(loc, outElemType, scalarIn); - Value result = rewriter.create(loc, scalarExt, zerodSplat, + arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn); + Value result = vector::InsertOp::create(rewriter, loc, scalarExt, zerodSplat, ArrayRef{}); rewriter.replaceOp(op, result); return success(); @@ -145,32 +145,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, if (inVecType.getRank() > 1) { inVecType = VectorType::get(SmallVector{numElements}, inVecType.getElementType()); - in = rewriter.create(loc, inVecType, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; - Value inSlice = rewriter.create( + Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i, elemsThisOp, 1); for (int64_t j = 0; j < elemsThisOp; j += 2) { if (i + j + 1 < numElements) { // Convert two 8-bit elements - Value asFloats = rewriter.create( + Value asFloats = amdgpu::ExtPackedFp8Op::create(rewriter, loc, extResType, inSlice, j / 2); Type desType = VectorType::get(2, outElemType); Value asType = castF32To(desType, asFloats, loc, rewriter); - result = rewriter.create( + result = vector::InsertStridedSliceOp::create(rewriter, loc, asType, result, i + j, 1); } else { // Convert a 8-bit element - Value asFloat = rewriter.create( + Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2); Value asType = castF32To(outElemType, asFloat, loc, rewriter); - result = rewriter.create(loc, asType, result, i + j); + result = vector::InsertOp::create(rewriter, loc, asType, result, i + j); } } } if (inVecType.getRank() != 
outType.getRank()) { - result = rewriter.create(loc, outType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outType, result); } rewriter.replaceOp(op, result); @@ -182,9 +182,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { if (type.isF32()) return value; if (type.getIntOrFloatBitWidth() < 32) - return rewriter.create(loc, rewriter.getF32Type(), value); + return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value); if (type.getIntOrFloatBitWidth() > 32) - return rewriter.create(loc, rewriter.getF32Type(), value); + return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value); llvm_unreachable("The only 32-bit float type is f32"); } @@ -224,13 +224,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc, loc, arith::CmpFPredicate::OEQ, source, negInf); Value isNan = rewriter.createOrFold( loc, arith::CmpFPredicate::UNO, source, source); - Value isNonFinite = rewriter.create( - loc, rewriter.create(loc, isInf, isNegInf), isNan); + Value isNonFinite = arith::OrIOp::create(rewriter, + loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf), isNan); - Value clampedBelow = rewriter.create(loc, source, minCst); - Value clamped = rewriter.create(loc, clampedBelow, maxCst); + Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst); + Value clamped = arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst); Value res = - rewriter.create(loc, isNonFinite, source, clamped); + arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped); return res; } @@ -264,24 +264,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, VectorType truncResType = VectorType::get(4, outElemType); if (!inVectorTy) { Value asFloat = castToF32(in, loc, rewriter); - Value asF8s = rewriter.create( + Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, /*existing=*/nullptr); - Value result = 
rewriter.create(loc, asF8s, 0); + Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0); rewriter.replaceOp(op, result); return success(); } int64_t numElements = outVecType.getNumElements(); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); if (outVecType.getShape().empty()) { Value scalarIn = - rewriter.create(loc, in, ArrayRef{}); + vector::ExtractOp::create(rewriter, loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarTrunc = - rewriter.create(loc, outElemType, scalarIn); - Value result = rewriter.create(loc, scalarTrunc, zero, + arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn); + Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero, ArrayRef{}); rewriter.replaceOp(op, result); return success(); @@ -294,32 +294,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector{numElements}, inVectorTy.getElementType()); - in = rewriter.create(loc, inVectorTy, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value thisResult = nullptr; for (int64_t j = 0; j < elemsThisOp; j += 2) { - Value elemA = rewriter.create(loc, in, i + j); + Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j); Value asFloatA = castToF32(elemA, loc, rewriter); Value asFloatB = nullptr; if (j + 1 < elemsThisOp) { - Value elemB = rewriter.create(loc, in, i + j + 1); + Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1); asFloatB = castToF32(elemB, loc, rewriter); } - thisResult = rewriter.create( + thisResult = amdgpu::PackedTrunc2xFp8Op::create(rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); } if (elemsThisOp < 4) - thisResult = rewriter.create( + thisResult 
= vector::ExtractStridedSliceOp::create(rewriter, loc, thisResult, 0, elemsThisOp, 1); - result = rewriter.create(loc, thisResult, + result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult, result, i, 1); } if (inVectorTy.getRank() != outVecType.getRank()) { - result = rewriter.create(loc, outVecType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result); } rewriter.replaceOp(op, result); @@ -347,10 +347,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( // Handle the case where input type is not a vector type if (!inVectorTy) { - auto sourceB = rewriter.create(loc, rewriter.getF32Type()); + auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type()); Value asF16s = - rewriter.create(loc, truncResType, in, sourceB); - Value result = rewriter.create(loc, asF16s, 0); + ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB); + Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0); rewriter.replaceOp(op, result); return success(); } @@ -362,7 +362,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector{numElements}, inVectorTy.getElementType()); - in = rewriter.create(loc, inVectorTy, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in); } // Handle the vector case. 
We also handle the (uncommon) case where the vector @@ -370,25 +370,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( for (int64_t i = 0; i < numElements; i += 2) { int64_t elemsThisOp = std::min(numElements, i + 2) - i; Value thisResult = nullptr; - Value elemA = rewriter.create(loc, in, i); - Value elemB = rewriter.create(loc, rewriter.getF32Type()); + Value elemA = vector::ExtractOp::create(rewriter, loc, in, i); + Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type()); if (elemsThisOp == 2) { - elemB = rewriter.create(loc, in, i + 1); + elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1); } thisResult = - rewriter.create(loc, truncResType, elemA, elemB); + ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB); // Place back the truncated result into the possibly larger vector. If we // are operating on a size 2 vector, these operations should be folded away - thisResult = rewriter.create( + thisResult = vector::ExtractStridedSliceOp::create(rewriter, loc, thisResult, 0, elemsThisOp, 1); - result = rewriter.create(loc, thisResult, + result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult, result, i, 1); } if (inVectorTy.getRank() != outVecType.getRank()) { - result = rewriter.create(loc, outVecType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp index cbe0b3fda3410..40bf8f1cffb04 100644 --- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp +++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp @@ -74,14 +74,14 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern { VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue()); - auto constantOp1D = rewriter.create(loc, denseAttr1D); + auto 
constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to write vector to tile // slice. - auto nextTile = b.create( + auto nextTile = arm_sme::InsertTileSliceOp::create(b, loc, tileType, constantOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a5c08a6378021..c5b3fdc536f72 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -110,7 +110,7 @@ class CmpFOpConversion : public OpConversionPattern { emitc::CmpPredicate predicate; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: { - auto constant = rewriter.create( + auto constant = emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(/*value=*/false)); rewriter.replaceOp(op, constant); @@ -179,7 +179,7 @@ class CmpFOpConversion : public OpConversionPattern { return success(); } case arith::CmpFPredicate::AlwaysTrue: { - auto constant = rewriter.create( + auto constant = emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(/*value=*/true)); rewriter.replaceOp(op, constant); @@ -189,7 +189,7 @@ class CmpFOpConversion : public OpConversionPattern { // Compare the values naively auto cmpResult = - rewriter.create(op.getLoc(), op.getType(), predicate, + emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate, adaptor.getLhs(), adaptor.getRhs()); // Adjust the results for unordered/ordered semantics @@ -213,7 +213,7 @@ class CmpFOpConversion : public OpConversionPattern { Value isNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) 
const { // A value is NaN exactly when it compares unequal to itself. - return rewriter.create( + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); } @@ -221,7 +221,7 @@ class CmpFOpConversion : public OpConversionPattern { Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is not NaN exactly when it compares equal to itself. - return rewriter.create( + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); } @@ -231,7 +231,7 @@ class CmpFOpConversion : public OpConversionPattern { Location loc, Value first, Value second) const { auto firstIsNaN = isNaN(rewriter, loc, first); auto secondIsNaN = isNaN(rewriter, loc, second); - return rewriter.create(loc, rewriter.getI1Type(), + return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(), firstIsNaN, secondIsNaN); } @@ -241,7 +241,7 @@ class CmpFOpConversion : public OpConversionPattern { Value first, Value second) const { auto firstIsNotNaN = isNotNaN(rewriter, loc, first); auto secondIsNotNaN = isNotNaN(rewriter, loc, second); - return rewriter.create(loc, rewriter.getI1Type(), + return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(), firstIsNotNaN, secondIsNotNaN); } }; @@ -378,9 +378,9 @@ class CastConversion : public OpConversionPattern { Type attrType = (emitc::isPointerWideType(operandType)) ? 
rewriter.getIndexType() : operandType; - auto constOne = rewriter.create( + auto constOne = emitc::ConstantOp::create(rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType)); - auto oneAndOperand = rewriter.create( + auto oneAndOperand = emitc::BitwiseAndOp::create(rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne); rewriter.replaceOpWithNewOp(op, opReturnType, oneAndOperand); @@ -467,7 +467,7 @@ class BinaryUIOpConversion final : public OpConversionPattern { Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); auto newDivOp = - rewriter.create(uiBinOp.getLoc(), unsignedType, + EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType, ArrayRef{lhsAdapted, rhsAdapted}); Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy); rewriter.replaceOp(uiBinOp, resultAdapted); @@ -588,38 +588,38 @@ class ShiftOpConversion : public OpConversionPattern { // Add a runtime check for overflow Value width; if (emitc::isPointerWideType(type)) { - Value eight = rewriter.create( + Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType, rewriter.getIndexAttr(8)); - emitc::CallOpaqueOp sizeOfCall = rewriter.create( + emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef{eight}); - width = rewriter.create(op.getLoc(), rhsType, eight, + width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight, sizeOfCall.getResult(0)); } else { - width = rewriter.create( + width = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType, rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); } - Value excessCheck = rewriter.create( + Value excessCheck = emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); // Any concrete value is a valid refinement of poison. - Value poison = rewriter.create( + Value poison = emitc::ConstantOp::create(rewriter, op.getLoc(), arithmeticType, (isa(arithmeticType) ? 
rewriter.getIntegerAttr(arithmeticType, 0) : rewriter.getIndexAttr(0))); - emitc::ExpressionOp ternary = rewriter.create( + emitc::ExpressionOp ternary = emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false); Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); auto currentPoint = rewriter.getInsertionPoint(); rewriter.setInsertionPointToStart(&bodyBlock); Value arithmeticResult = - rewriter.create(op.getLoc(), arithmeticType, lhs, rhs); - Value resultOrPoison = rewriter.create( + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); + Value resultOrPoison = emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); - rewriter.create(op.getLoc(), resultOrPoison); + emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison); rewriter.setInsertionPoint(op->getBlock(), currentPoint); Value result = adaptValueType(ternary, rewriter, type); @@ -700,11 +700,11 @@ class FtoICastOpConversion : public OpConversionPattern { /*isSigned=*/false); } - Value result = rewriter.create( + Value result = emitc::CastOp::create(rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands()); if (isa(castOp)) { - result = rewriter.create(castOp.getLoc(), dstType, result); + result = emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result); } rewriter.replaceOp(castOp, result); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index f7bf581adc9e3..f7d4963b5e293 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -293,10 +293,10 @@ LogicalResult IndexCastOpLowering::matchAndRewrite( [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { - return rewriter.create(op.getLoc(), llvm1DVectorTy, + return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy, 
adaptor.getIn()); } - return rewriter.create(op.getLoc(), llvm1DVectorTy, + return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy, adaptor.getIn()); }, rewriter); @@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); - Value addOverflow = rewriter.create( + Value addOverflow = LLVM::UAddWithOverflowOp::create(rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = - rewriter.create(loc, addOverflow, 0); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0); Value overflowExtracted = - rewriter.create(loc, addOverflow, 1); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } @@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering::matchAndRewrite( "LLVM dialect should support all signless integer types"); using LLVMExtOp = std::conditional_t; - Value lhsExt = rewriter.create(loc, wideType, adaptor.getLhs()); - Value rhsExt = rewriter.create(loc, wideType, adaptor.getRhs()); - Value mulExt = rewriter.create(loc, wideType, lhsExt, rhsExt); + Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs()); + Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs()); + Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt); // Split the 2*N-bit wide result into two N-bit values. 
- Value low = rewriter.create(loc, resultType, mulExt); - Value shiftVal = rewriter.create(loc, shiftValAttr); - Value highExt = rewriter.create(loc, mulExt, shiftVal); - Value high = rewriter.create(loc, resultType, highExt); + Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt); + Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr); + Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal); + Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt); rewriter.replaceOp(op, {low, high}); return success(); @@ -435,7 +435,7 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create( + return LLVM::ICmpOp::create(rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); @@ -471,7 +471,7 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create( + return LLVM::FCmpOp::create(rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df853a5e..c93cf13f831ef 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -117,11 +117,11 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value, if (auto vectorType = dyn_cast(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); - return builder.create(loc, vectorType, attr); + return 
spirv::ConstantOp::create(builder, loc, vectorType, attr); } if (auto intType = dyn_cast(type)) - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, builder.getIntegerAttr(type, value)); return nullptr; @@ -418,18 +418,18 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. - Value lhsAbs = builder.create(loc, type, lhs); - Value rhsAbs = builder.create(loc, type, rhs); - Value abs = builder.create(loc, lhsAbs, rhsAbs); + Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); + Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); + Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) - isPositive = builder.create(loc, lhs, lhsAbs); + isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); else - isPositive = builder.create(loc, rhs, rhsAbs); - Value absNegate = builder.create(loc, type, abs); - return builder.create(loc, type, isPositive, abs, absNegate); + isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); + Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); + return spirv::SelectOp::create(builder, loc, type, isPositive, abs, absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. 
@@ -601,12 +601,12 @@ struct ExtSII1Pattern final : public OpConversionPattern { Value allOnes; if (auto intTy = dyn_cast(dstType)) { unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create( + allOnes = spirv::ConstantOp::create(rewriter, loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create( + allOnes = spirv::ConstantOp::create(rewriter, loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); @@ -653,7 +653,7 @@ struct ExtSIPattern final : public OpConversionPattern { // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. - auto shiftLOp = rewriter.create( + auto shiftLOp = spirv::ShiftLeftLogicalOp::create(rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right @@ -757,9 +757,9 @@ struct TruncII1Pattern final : public OpConversionPattern { auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. 
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create( + Value maskedSrc = spirv::BitwiseAndOp::create(rewriter, loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create(loc, maskedSrc, mask); + Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -914,9 +914,9 @@ class CmpIOpBooleanPattern final : public OpConversionPattern { if (auto vectorType = dyn_cast(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = - rewriter.create(op.getLoc(), type, adaptor.getLhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); Value extRhs = - rewriter.create(op.getLoc(), type, adaptor.getRhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, extRhs); @@ -1067,12 +1067,12 @@ class CmpFOpNanNonePattern final : public OpConversionPattern { replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - replace = rewriter.create(loc, lhsIsNan, rhsIsNan); + replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create(loc, replace); + replace = spirv::LogicalNotOp::create(rewriter, loc, replace); } rewriter.replaceOp(op, replace); @@ -1094,17 +1094,17 @@ class AddUIExtendedOpPattern final ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); - Value result = rewriter.create(loc, adaptor.getLhs(), 
+ Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value sumResult = rewriter.create( + Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create( + Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, llvm::ArrayRef(1)); // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create(loc, carryValue, one); + Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); @@ -1125,11 +1125,11 @@ class MulIExtendedOpPattern final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = - rewriter.create(loc, adaptor.getLhs(), adaptor.getRhs()); + SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value low = rewriter.create(loc, result, + Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, llvm::ArrayRef(0)); - Value high = rewriter.create(loc, result, + Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); @@ -1183,19 +1183,19 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern { Location loc = op.getLoc(); Value spirvOp = - rewriter.create(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { rewriter.replaceOp(op, spirvOp); return success(); } - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create(loc, 
dstType, lhsIsNan, + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create(loc, dstType, rhsIsNan, + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); @@ -1237,7 +1237,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern { Location loc = op.getLoc(); Value spirvOp = - rewriter.create(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards() || bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { @@ -1245,12 +1245,12 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern { return success(); } - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create(loc, dstType, lhsIsNan, + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, adaptor.getRhs(), spirvOp); - Value select2 = rewriter.create(loc, dstType, rhsIsNan, + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 9c6de938a7108..2295d661d44a5 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -41,10 +41,10 @@ class Sdot2dLoweringPattern : public OpRewritePattern { Value c2d = op.getC(); Location loc = op.getLoc(); Value b1d = - rewriter.create(loc, flattenedVectorType, b2d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d); Value c1d = - rewriter.create(loc, 
flattenedVectorType, c2d); - Value newOp = rewriter.create(loc, op.getRes().getType(), op.getA(), + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d); + Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(), op.getA(), b1d, c1d); rewriter.replaceOp(op, {newOp}); return success(); diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 21ea444e31821..cf46e0c10df09 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -45,37 +45,37 @@ static Operation *createLoadTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( + return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( + return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( + return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( + return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( + return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( + return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( + return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( + return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, 
maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( + return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( + return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); break; } @@ -91,37 +91,37 @@ static Operation *createStoreTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( + return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( + return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( + return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( + return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( + return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( + return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( + return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( + return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( + return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); case 
arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( + return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr, tileId, tileSliceI32); } } @@ -146,15 +146,15 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, // Move to the first operation in the function. rewriter.setInsertionPointToStart(&func.getBlocks().front()); // Create an alloca matching the tile size of the `tileOp`. - auto vscale = rewriter.create(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto tileElementType = tileOp.getTileType().getElementType(); auto memrefType = MemRefType::get( {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); auto minElementsOp = - rewriter.create(loc, minElements); - auto vectorLen = rewriter.create(loc, vscale, minElementsOp); - auto alloca = rewriter.create( + arith::ConstantIndexOp::create(rewriter, loc, minElements); + auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp); + auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType, ValueRange{vectorLen, vectorLen}); return alloca; } @@ -293,9 +293,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { Value tileMemory, Value sliceIndex) const { auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); auto descriptor = - rewriter.create(loc, llvmType, tileMemory); - auto zero = rewriter.create(loc, 0, /*width=*/64); - auto sliceIndexI64 = rewriter.create( + UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory); + auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64); + auto sliceIndexI64 = arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), sliceIndex); return getStridedElementPtr( static_cast(rewriter), loc, @@ -309,27 +309,27 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { arm_sme::ArmSMETileType tileType, VectorType sliceType, IntegerAttr 
tileId, Value sliceIndex) const { // Cast the slice index to an i32. - auto sliceIndexI32 = rewriter.create( + auto sliceIndexI32 = arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create an all-true predicate for the slice. auto predicateType = sliceType.clone(rewriter.getI1Type()); - auto allTruePredicate = rewriter.create( + auto allTruePredicate = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Create padding vector (never used due to all-true predicate). - auto padVector = rewriter.create(loc, sliceType); + auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType); // Get a pointer to the current slice. auto slicePtr = getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); // Read the value of the current slice from ZA. - auto currentTileSlice = rewriter.create( + auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(rewriter, loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); // Load the new tile slice back from memory into ZA. createLoadTileSliceIntrinsic( rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, allTruePredicate, slicePtr, tileId, sliceIndexI32); // Store the current tile slice to memory. - auto zero = rewriter.create(loc, 0); - rewriter.create(loc, currentTileSlice, tileAlloca, + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca, ValueRange{sliceIndex, zero}); } @@ -341,12 +341,12 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { RewriterBase::InsertionGuard guard(rewriter); // Create an scf.for over all tile slices. 
auto minNumElts = - rewriter.create(loc, sliceType.getDimSize(0)); - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create( - loc, minNumElts, rewriter.create(loc)); - auto step = rewriter.create(loc, 1); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = arith::MulIOp::create(rewriter, + loc, minNumElts, vector::VectorScaleOp::create(rewriter, loc)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Emit a swap for each tile slice. rewriter.setInsertionPointToStart(forOp.getBody()); auto sliceIndex = forOp.getInductionVar(); @@ -479,7 +479,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern { // // This holds for all tile sizes. int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); - rewriter.create( + arm_sme::aarch64_sme_zero::create(rewriter, loc, rewriter.getI32IntegerAttr(zeroMask)); // Create a placeholder op to preserve dataflow. @@ -513,7 +513,7 @@ struct LoadTileSliceConversion auto tileSlice = loadTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create( + auto tileSliceI32 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. @@ -559,7 +559,7 @@ struct StoreTileSliceConversion auto tileSlice = storeTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create( + auto tileSliceI32 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI32Type(), tileSlice); auto maskOp = storeTileSliceOp.getMask(); @@ -595,26 +595,26 @@ struct InsertTileSliceConversion auto tileSlice = insertTileSliceOp.getTileSliceIndex(); // Cast tile slice from index to i32 for intrinsic. 
- auto tileSliceI32 = rewriter.create( + auto tileSliceI32 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. - auto one = rewriter.create( + auto one = arith::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create(loc, predTy, one); + auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one); // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. switch (insertTileSliceOp.getLayout()) { case arm_sme::TileSliceLayout::Horizontal: - rewriter.create( + arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId, tileSliceI32, allActiveMask, insertTileSliceOp.getVector()); break; case arm_sme::TileSliceLayout::Vertical: - rewriter.create( + arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId, tileSliceI32, allActiveMask, insertTileSliceOp.getVector()); break; @@ -646,15 +646,15 @@ struct ExtractTileSliceConversion // Create an 'all true' predicate for the tile slice. auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); - auto allTruePredicate = rewriter.create( + auto allTruePredicate = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Zero destination/fallback for tile slice extraction. - auto zeroVector = rewriter.create( + auto zeroVector = arith::ConstantOp::create(rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType)); // Cast tile slice from index to i32 for intrinsic. - auto sliceIndexI32 = rewriter.create( + auto sliceIndexI32 = arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. 
@@ -743,7 +743,7 @@ struct OuterProductOpConversion Value acc = outerProductOp.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create(loc, resultVectorType); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType); zero.setTileId(tileId); acc = zero; } @@ -754,14 +754,14 @@ struct OuterProductOpConversion if (!lhsMask || !rhsMask) { auto predTy = outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create( + Value allActiveMask = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } // Create 'arm_sme.intr.mopa' outer product intrinsic. - rewriter.create(loc, tileId, lhsMask, rhsMask, + arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask, outerProductOp.getLhs(), outerProductOp.getRhs()); @@ -792,7 +792,7 @@ struct OuterProductWideningOpConversion Value acc = op.getAcc(); if (!acc) { // Initalize accumulator with zero. 
- auto zero = rewriter.create(loc, op.getResultType()); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType()); zero.setTileId(tileId); acc = zero; } @@ -801,13 +801,13 @@ struct OuterProductWideningOpConversion Value rhsMask = op.getRhsMask(); if (!lhsMask || !rhsMask) { auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create( + Value allActiveMask = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } - rewriter.create( + OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs()); // The outerproduct intrinsics have no result, replace @@ -843,13 +843,13 @@ struct StreamingVLOpConversion auto *intrOp = [&]() -> Operation * { switch (streamingVlOp.getTypeSize()) { case arm_sme::TypeSize::Byte: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Half: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Word: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Double: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); } llvm_unreachable("unknown type size in StreamingVLOpConversion"); }(); @@ -872,7 +872,7 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) { if (zeroOpsToMerge.size() <= 1) return; IRRewriter rewriter(zeroOpsToMerge.front()); - rewriter.create( + arm_sme::aarch64_sme_zero::create(rewriter, zeroOpsToMerge.front().getLoc(), rewriter.getI32IntegerAttr(mergedZeroMask)); for (auto zeroOp : zeroOpsToMerge) diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 458628c29c6ac..3591ed65f8d84 
100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -39,7 +39,7 @@ SmallVector getMemrefIndices(ValueRange indices, unsigned rank, auto tileSliceOffset = tileSliceIndex; auto baseIndexPlusTileSliceOffset = - rewriter.create(loc, indices[0], tileSliceOffset); + arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); outIndices.push_back(indices[1]); @@ -59,10 +59,10 @@ FailureOr createLoadStoreForOverTileSlices( if (memrefIndices.size() != 2) return rewriter.notifyMatchFailure(loc, "invalid number of indices"); - auto minTileSlices = rewriter.create( + auto minTileSlices = arith::ConstantIndexOp::create(rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = - rewriter.create(loc, rewriter.getIndexType()); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); @@ -70,7 +70,7 @@ FailureOr createLoadStoreForOverTileSlices( // elements in a vector of SVL bits for a given element type (SVL_B, // SVL_H, ..., SVL_Q). auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); Value predicate; Value upperBound; @@ -82,28 +82,28 @@ FailureOr createLoadStoreForOverTileSlices( // The upper bound of the loop must be clamped at `numTileSlices` as // `vector.create_mask` allows operands to be greater than the size of a // dimension. 
- auto numRowI64 = rewriter.create( + auto numRowI64 = arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), maskDim0); - auto numTileSlicesI64 = rewriter.create( + auto numTileSlicesI64 = arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), numTileSlices); auto upperBoundI64 = - rewriter.create(loc, numRowI64, numTileSlicesI64); - upperBound = rewriter.create( + arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64); + upperBound = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), upperBoundI64); predicate = - rewriter.create(loc, predicateType, maskDim1); + vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1); } else { upperBound = numTileSlices; // No mask. Create an 'all true' predicate for the tile slice. - predicate = rewriter.create( + predicate = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(predicateType, true)); } bool hasCarriedArgs = bool(initTile); - auto lowerBound = rewriter.create(loc, 0); - auto step = rewriter.create(loc, 1); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step, + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, hasCarriedArgs ? ValueRange{initTile} : ValueRange{}); @@ -118,7 +118,7 @@ FailureOr createLoadStoreForOverTileSlices( assert(bool(nextTile) == hasCarriedArgs); if (nextTile) - rewriter.create(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } @@ -194,9 +194,9 @@ struct TileLoadOpConversion : public OpRewritePattern { // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive // rows however, no load will occur so these need to be zeroed. 
- initTile = rewriter.create(loc, tileType); + initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType); } else { - initTile = rewriter.create(loc, tileType); + initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); } // Create a loop to load the active tile slices from memory. @@ -207,7 +207,7 @@ struct TileLoadOpConversion : public OpRewritePattern { Value currentTile) -> Value { // Create 'arm_sme.load_tile_slice' to load tile slice from memory // into tile. - return rewriter.create( + return arm_sme::LoadTileSliceOp::create(rewriter, loc, tileType, tileLoadOp.getBase(), predicate, currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); }); @@ -283,21 +283,21 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto numRows = createMaskOp.getOperands()[0]; auto numCols = createMaskOp.getOperands()[1]; - auto numColsI32 = rewriter.create( + auto numColsI32 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI32Type(), numCols); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. 
- auto step = rewriter.create(loc, 1); - auto minTileSlices = rewriter.create( + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create(rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); - auto forOp = rewriter.create(loc, lowerBound, numTileSlices, + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -306,16 +306,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto currentTile = forOp.getRegionIterArg(0); // Combine masks. - auto rowIsActive = rewriter.create( + auto rowIsActive = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); - auto rowIsActiveI32 = rewriter.create( + auto rowIsActiveI32 = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), rowIsActive); - auto mask = rewriter.create(loc, rowIsActiveI32, numColsI32); + auto mask = arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32); auto maskIndex = - rewriter.create(loc, rewriter.getIndexType(), mask); + arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), mask); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); - auto maskOp1D = rewriter.create( + auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType, maskIndex.getResult()); auto memrefIndices = getMemrefIndices( @@ -324,17 +324,17 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion // Splat pad into 1-D vector matching type of tile slice. 
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - auto pad1DOp = rewriter.create(loc, tileSliceType, padOp); + auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp); - auto loadSlice = rewriter.create( + auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D, /*passthru=*/pad1DOp); // Create 'arm_sme.insert_tile_slice' to insert slice into tile. - auto insertSlice = rewriter.create( + auto insertSlice = arm_sme::InsertTileSliceOp::create(rewriter, loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex, tileLoadOp.getLayout()); - rewriter.create(loc, insertSlice.getResult()); + scf::YieldOp::create(rewriter, loc, insertSlice.getResult()); rewriter.setInsertionPointAfter(forOp); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 94f7caa315cf7..e1d28f19ee236 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -203,7 +203,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) return; - builder.create(name, type).setPrivate(); + func::FuncOp::create(builder, name, type).setPrivate(); }; MLIRContext *ctx = module.getContext(); @@ -254,15 +254,15 @@ static void addResumeFunction(ModuleOp module) { auto voidTy = LLVM::LLVMVoidType::get(ctx); Type ptrType = AsyncAPI::opaquePointerType(ctx); - auto resumeOp = moduleBuilder.create( + auto resumeOp = LLVM::LLVMFuncOp::create(moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); - blockBuilder.create(resumeOp.getArgument(0)); - blockBuilder.create(ValueRange()); + LLVM::CoroResumeOp::create(blockBuilder, 
resumeOp.getArgument(0)); + LLVM::ReturnOp::create(blockBuilder, ValueRange()); } //===----------------------------------------------------------------------===// @@ -282,7 +282,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter { // in patterns for other dialects. auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - auto cast = builder.create(loc, type, inputs); + auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return cast.getResult(0); }; @@ -343,8 +343,8 @@ class CoroIdOpConversion : public AsyncOpConversionPattern { // Constants for initializing coroutine frame. auto constZero = - rewriter.create(loc, rewriter.getI32Type(), 0); - auto nullPtr = rewriter.create(loc, ptrType); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType); // Get coroutine id: @llvm.coro.id. rewriter.replaceOpWithNewOp( @@ -372,32 +372,32 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern { // Get coroutine frame size: @llvm.coro.size.i64. Value coroSize = - rewriter.create(loc, rewriter.getI64Type()); + LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type()); // Get coroutine frame alignment: @llvm.coro.align.i64. Value coroAlign = - rewriter.create(loc, rewriter.getI64Type()); + LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type()); // Round up the size to be multiple of the alignment. Since aligned_alloc // requires the size parameter be an integral multiple of the alignment // parameter. 
auto makeConstant = [&](uint64_t c) { - return rewriter.create(op->getLoc(), + return LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(), c); }; - coroSize = rewriter.create(op->getLoc(), coroSize, coroAlign); + coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign); coroSize = - rewriter.create(op->getLoc(), coroSize, makeConstant(1)); + LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1)); Value negCoroAlign = - rewriter.create(op->getLoc(), makeConstant(0), coroAlign); + LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign); coroSize = - rewriter.create(op->getLoc(), coroSize, negCoroAlign); + LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign); // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( rewriter, op->getParentOfType(), rewriter.getI64Type()); if (failed(allocFuncOp)) return failure(); - auto coroAlloc = rewriter.create( + auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. @@ -427,7 +427,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern { // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = - rewriter.create(loc, ptrType, adaptor.getOperands()); + LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands()); // Free the memory. auto freeFuncOp = @@ -455,13 +455,13 @@ class CoroEndOpConversion : public OpConversionPattern { matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. 
- auto constFalse = rewriter.create( + auto constFalse = LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); - auto noneToken = rewriter.create(op->getLoc()); + auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc()); // Mark the end of a coroutine: @llvm.coro.end. auto coroHdl = adaptor.getHandle(); - rewriter.create( + LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), ValueRange({coroHdl, constFalse, noneToken})); rewriter.eraseOp(op); @@ -534,12 +534,12 @@ class CoroSuspendOpConversion : public OpConversionPattern { auto loc = op->getLoc(); // This is not a final suspension point. - auto constFalse = rewriter.create( + auto constFalse = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend auto coroState = adaptor.getState(); - auto coroSuspend = rewriter.create( + auto coroSuspend = LLVM::CoroSuspendOp::create(rewriter, loc, i8, ValueRange({coroState, constFalse})); // Cast return code to i32. 
@@ -551,7 +551,7 @@ class CoroSuspendOpConversion : public OpConversionPattern { llvm::SmallVector caseDest = {op.getResumeDest(), op.getCleanupDest()}; rewriter.replaceOpWithNewOp( - op, rewriter.create(loc, i32, coroSuspend.getResult()), + op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()), /*defaultDestination=*/op.getSuspendDest(), /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, @@ -602,11 +602,11 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern { // %Size = getelementptr %T* null, int 1 // %SizeI = ptrtoint %T* %Size to i64 - auto nullPtr = rewriter.create(loc, storagePtrType); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType); auto gep = - rewriter.create(loc, storagePtrType, storedType, + LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType, nullPtr, ArrayRef{1}); - return rewriter.create(loc, i64, gep); + return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep); }; rewriter.replaceOpWithNewOp(op, kCreateValue, resultType, @@ -739,7 +739,7 @@ class RuntimeAwaitOpLowering : public OpConversionPattern { .Case([](Type) { return kAwaitValue; }) .Case([](Type) { return kAwaitGroup; }); - rewriter.create(op->getLoc(), apiFuncName, TypeRange(), + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), adaptor.getOperands()); rewriter.eraseOp(op); @@ -772,11 +772,11 @@ class RuntimeAwaitAndResumeOpLowering // A pointer to coroutine resume intrinsic wrapper. 
addResumeFunction(op->getParentOfType()); - auto resumePtr = rewriter.create( + auto resumePtr = LLVM::AddressOfOp::create(rewriter, op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); - rewriter.create( + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), ValueRange({operand, handle, resumePtr.getRes()})); rewriter.eraseOp(op); @@ -801,7 +801,7 @@ class RuntimeResumeOpLowering ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); - auto resumePtr = rewriter.create( + auto resumePtr = LLVM::AddressOfOp::create(rewriter, op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); @@ -832,7 +832,7 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern { // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create( + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. @@ -845,7 +845,7 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern { Value castedStoragePtr = storagePtr.getResult(0); // Store the yielded value into the async value storage. auto value = adaptor.getValue(); - rewriter.create(loc, value, castedStoragePtr); + LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr); // Erase the original runtime store operation. rewriter.eraseOp(op); @@ -872,7 +872,7 @@ class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern { // Get a pointer to the async value storage from the runtime. 
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create( + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. @@ -960,7 +960,7 @@ class RefCountingOpLowering : public OpConversionPattern { LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = rewriter.create( + auto count = arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(op.getCount())); diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index b9991f36cdaaf..a63a1ceb00ebb 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -47,26 +47,26 @@ struct CloneOpConversion : public OpConversionPattern { if (auto unrankedType = dyn_cast(type)) { // Constants - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); // Dynamically evaluate the size and shape of the unranked memref - Value rank = rewriter.create(loc, op.getInput()); + Value rank = memref::RankOp::create(rewriter, loc, op.getInput()); MemRefType allocType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); - Value shape = rewriter.create(loc, allocType, rank); + Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank); // Create a loop to query dimension sizes, store them as a shape, and // compute the total size of the memref auto loopBody = [&](OpBuilder &builder, Location loc, Value i, ValueRange args) { auto acc = args.front(); - auto dim = 
rewriter.create(loc, op.getInput(), i); + auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i); - rewriter.create(loc, dim, shape, i); - acc = rewriter.create(loc, acc, dim); + memref::StoreOp::create(rewriter, loc, dim, shape, i); + acc = arith::MulIOp::create(rewriter, loc, acc, dim); - rewriter.create(loc, acc); + scf::YieldOp::create(rewriter, loc, acc); }; auto size = rewriter .create(loc, zero, rank, one, ValueRange(one), @@ -78,9 +78,9 @@ struct CloneOpConversion : public OpConversionPattern { // Allocate new memref with 1D dynamic shape, then reshape into the // shape of the original unranked memref - alloc = rewriter.create(loc, memrefType, size); + alloc = memref::AllocOp::create(rewriter, loc, memrefType, size); alloc = - rewriter.create(loc, unrankedType, alloc, shape); + memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape); } else { MemRefType memrefType = cast(type); MemRefLayoutAttrInterface layout; @@ -103,14 +103,14 @@ struct CloneOpConversion : public OpConversionPattern { } // Allocate a memref with identity layout. - alloc = rewriter.create(loc, allocType, dynamicOperands); + alloc = memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands); // Cast the allocation to the specified type if needed. 
if (memrefType != allocType) alloc = - rewriter.create(op->getLoc(), memrefType, alloc); + memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc); } - rewriter.create(loc, op.getInput(), alloc); + memref::CopyOp::create(rewriter, loc, op.getInput(), alloc); rewriter.replaceOp(op, alloc); return success(); } diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp index 70b22386f1eea..1961e1004ba8e 100644 --- a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp +++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp @@ -23,41 +23,41 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic( ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm) { - Value rhsSqNorm = rewriter.create( - loc, rewriter.create(loc, rhsRe, rhsRe, fmf), - rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + Value rhsSqNorm = LLVM::FAddOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf); - Value realNumerator = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRe, fmf), - rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); + Value realNumerator = LLVM::FAddOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf); - Value imagNumerator = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + Value imagNumerator = LLVM::FSubOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); - *resultRe = rewriter.create(loc, realNumerator, rhsSqNorm, fmf); - *resultIm = rewriter.create(loc, imagNumerator, rhsSqNorm, fmf); + *resultRe = LLVM::FDivOp::create(rewriter, loc, 
realNumerator, rhsSqNorm, fmf); + *resultIm = LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf); } void mlir::complex::convertDivToStandardUsingAlgebraic( ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm) { - Value rhsSqNorm = rewriter.create( - loc, rewriter.create(loc, rhsRe, rhsRe, fmf), - rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + Value rhsSqNorm = arith::AddFOp::create(rewriter, + loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf); - Value realNumerator = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRe, fmf), - rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); - Value imagNumerator = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + Value realNumerator = arith::AddFOp::create(rewriter, + loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf); + Value imagNumerator = arith::SubFOp::create(rewriter, + loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); *resultRe = - rewriter.create(loc, realNumerator, rhsSqNorm, fmf); + arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf); *resultIm = - rewriter.create(loc, imagNumerator, rhsSqNorm, fmf); + arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf); } // Smith's algorithm to divide complex numbers. 
It is just a bit smarter @@ -94,180 +94,180 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction( auto elementType = cast(rhsRe.getType()); Value rhsRealImagRatio = - rewriter.create(loc, rhsRe, rhsIm, fmf); - Value rhsRealImagDenom = rewriter.create( + LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = LLVM::FAddOp::create(rewriter, loc, rhsIm, - rewriter.create(loc, rhsRealImagRatio, rhsRe, fmf), fmf); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRealImagRatio, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = LLVM::FAddOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm, fmf); Value resultReal1 = - rewriter.create(loc, realNumerator1, rhsRealImagDenom, fmf); - Value imagNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRealImagRatio, fmf), + LLVM::FDivOp::create(rewriter, loc, realNumerator1, rhsRealImagDenom, fmf); + Value imagNumerator1 = LLVM::FSubOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe, fmf); Value resultImag1 = - rewriter.create(loc, imagNumerator1, rhsRealImagDenom, fmf); + LLVM::FDivOp::create(rewriter, loc, imagNumerator1, rhsRealImagDenom, fmf); Value rhsImagRealRatio = - rewriter.create(loc, rhsIm, rhsRe, fmf); - Value rhsImagRealDenom = rewriter.create( + LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = LLVM::FAddOp::create(rewriter, loc, rhsRe, - rewriter.create(loc, rhsImagRealRatio, rhsIm, fmf), fmf); - Value realNumerator2 = rewriter.create( + LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = LLVM::FAddOp::create(rewriter, loc, lhsRe, - rewriter.create(loc, lhsIm, rhsImagRealRatio, fmf), fmf); + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf); Value resultReal2 = - rewriter.create(loc, 
realNumerator2, rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create( + LLVM::FDivOp::create(rewriter, loc, realNumerator2, rhsImagRealDenom, fmf); + Value imagNumerator2 = LLVM::FSubOp::create(rewriter, loc, lhsIm, - rewriter.create(loc, lhsRe, rhsImagRealRatio, fmf), fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf); Value resultImag2 = - rewriter.create(loc, imagNumerator2, rhsImagRealDenom, fmf); + LLVM::FDivOp::create(rewriter, loc, imagNumerator2, rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create( + Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsRe, fmf); - Value rhsRealIsZero = rewriter.create( + Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf); + Value rhsRealIsZero = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsIm, fmf); - Value rhsImagIsZero = rewriter.create( + Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf); + Value rhsImagIsZero = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero); Value lhsRealIsNotNaN = - rewriter.create(loc, LLVM::FCmpPredicate::ord, lhsRe, zero); + LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero); Value lhsImagIsNotNaN = - rewriter.create(loc, LLVM::FCmpPredicate::ord, lhsIm, zero); + LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( + LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = LLVM::AndOp::create(rewriter, loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = 
rewriter.create( + LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = LLVM::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfrhsReal = - rewriter.create(loc, inf, rhsRe); + LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe); Value infinityResultReal = - rewriter.create(loc, infWithSignOfrhsReal, lhsRe, fmf); + LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf); Value infinityResultImag = - rewriter.create(loc, infWithSignOfrhsReal, lhsIm, fmf); + LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create( + Value rhsRealFinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create( + Value rhsImagFinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf); Value rhsFinite = - rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsRe, fmf); - Value lhsRealInfinite = rewriter.create( + LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf); + Value lhsRealInfinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsIm, fmf); - Value lhsImagInfinite = rewriter.create( + Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf); + Value lhsImagInfinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( + LLVM::AndOp::create(rewriter, loc, 
lhsInfinite, rhsFinite); + Value one = LLVM::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsRealInfinite, one, zero), + Value lhsRealIsInfWithSign = LLVM::CopySignOp::create(rewriter, + loc, LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe); - Value lhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsImagInfinite, one, zero), + Value lhsImagIsInfWithSign = LLVM::CopySignOp::create(rewriter, + loc, LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm); Value lhsRealIsInfWithSignTimesrhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsRe, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf); Value lhsImagIsInfWithSignTimesrhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsIm, fmf); - Value resultReal3 = rewriter.create( + LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = LLVM::FMulOp::create(rewriter, loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesrhsReal, + LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal, lhsImagIsInfWithSignTimesrhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesrhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsIm, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf); Value lhsImagIsInfWithSignTimesrhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsRe, fmf); - Value resultImag3 = rewriter.create( + LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = LLVM::FMulOp::create(rewriter, loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesrhsReal, + LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal, lhsRealIsInfWithSignTimesrhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. 
- Value lhsRealFinite = rewriter.create( + Value lhsRealFinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create( + Value lhsImagFinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf); Value lhsFinite = - rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create( + LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create( + Value rhsImagInfinite = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsRealInfinite, one, zero), + LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = LLVM::CopySignOp::create(rewriter, + loc, LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe); - Value rhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsImagInfinite, one, zero), + Value rhsImagIsInfWithSign = LLVM::CopySignOp::create(rewriter, + loc, LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm); Value rhsRealIsInfWithSignTimeslhsReal = - rewriter.create(loc, lhsRe, rhsRealIsInfWithSign, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimeslhsImag = - rewriter.create(loc, lhsIm, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create( + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = LLVM::FMulOp::create(rewriter, loc, zero, - rewriter.create(loc, 
rhsRealIsInfWithSignTimeslhsReal, + LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal, rhsImagIsInfWithSignTimeslhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimeslhsImag = - rewriter.create(loc, lhsIm, rhsRealIsInfWithSign, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimeslhsReal = - rewriter.create(loc, lhsRe, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create( + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = LLVM::FMulOp::create(rewriter, loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimeslhsImag, + LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag, rhsImagIsInfWithSignTimeslhsReal, fmf), fmf); - Value realAbsSmallerThanImagAbs = rewriter.create( + Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs); - Value resultReal5 = rewriter.create( + Value resultReal5 = LLVM::SelectOp::create(rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag5 = rewriter.create( + Value resultImag5 = LLVM::SelectOp::create(rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create( + Value resultRealSpecialCase3 = LLVM::SelectOp::create(rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5); - Value resultImagSpecialCase3 = rewriter.create( + Value resultImagSpecialCase3 = LLVM::SelectOp::create(rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5); - Value resultRealSpecialCase2 = rewriter.create( + Value resultRealSpecialCase2 = LLVM::SelectOp::create(rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create( + Value resultImagSpecialCase2 = LLVM::SelectOp::create(rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = 
rewriter.create( + Value resultRealSpecialCase1 = LLVM::SelectOp::create(rewriter, loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create( + Value resultImagSpecialCase1 = LLVM::SelectOp::create(rewriter, loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = rewriter.create( + Value resultRealIsNaN = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero); - Value resultImagIsNaN = rewriter.create( + Value resultImagIsNaN = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero); Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN); - *resultRe = rewriter.create( + *resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN, resultRealSpecialCase1, resultReal5); - *resultIm = rewriter.create( + *resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN, resultImagSpecialCase1, resultImag5); } @@ -278,179 +278,179 @@ void mlir::complex::convertDivToStandardUsingRangeReduction( auto elementType = cast(rhsRe.getType()); Value rhsRealImagRatio = - rewriter.create(loc, rhsRe, rhsIm, fmf); - Value rhsRealImagDenom = rewriter.create( + arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = arith::AddFOp::create(rewriter, loc, rhsIm, - rewriter.create(loc, rhsRealImagRatio, rhsRe, fmf), fmf); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsRe, rhsRealImagRatio, fmf), + arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = arith::AddFOp::create(rewriter, + loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm, fmf); - Value resultReal1 = rewriter.create(loc, realNumerator1, + Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1, rhsRealImagDenom, fmf); - Value imagNumerator1 = 
rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRealImagRatio, fmf), + Value imagNumerator1 = arith::SubFOp::create(rewriter, + loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe, fmf); - Value resultImag1 = rewriter.create(loc, imagNumerator1, + Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1, rhsRealImagDenom, fmf); Value rhsImagRealRatio = - rewriter.create(loc, rhsIm, rhsRe, fmf); - Value rhsImagRealDenom = rewriter.create( + arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = arith::AddFOp::create(rewriter, loc, rhsRe, - rewriter.create(loc, rhsImagRealRatio, rhsIm, fmf), fmf); - Value realNumerator2 = rewriter.create( + arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = arith::AddFOp::create(rewriter, loc, lhsRe, - rewriter.create(loc, lhsIm, rhsImagRealRatio, fmf), fmf); - Value resultReal2 = rewriter.create(loc, realNumerator2, + arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2, rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create( + Value imagNumerator2 = arith::SubFOp::create(rewriter, loc, lhsIm, - rewriter.create(loc, lhsRe, rhsImagRealRatio, fmf), fmf); - Value resultImag2 = rewriter.create(loc, imagNumerator2, + arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2, rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. 
- Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsRe, fmf); - Value rhsRealIsZero = rewriter.create( + Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf); + Value rhsRealIsZero = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsIm, fmf); - Value rhsImagIsZero = rewriter.create( + Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf); + Value rhsImagIsZero = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = rewriter.create( + Value lhsRealIsNotNaN = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero); - Value lhsImagIsNotNaN = rewriter.create( + Value lhsImagIsNotNaN = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( + arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = arith::AndIOp::create(rewriter, loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( + arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfRhsReal = - rewriter.create(loc, inf, rhsRe); + math::CopySignOp::create(rewriter, loc, inf, rhsRe); Value infinityResultReal = - rewriter.create(loc, infWithSignOfRhsReal, lhsRe, fmf); + arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf); Value infinityResultImag = - rewriter.create(loc, infWithSignOfRhsReal, lhsIm, fmf); + arith::MulFOp::create(rewriter, loc, 
infWithSignOfRhsReal, lhsIm, fmf); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create( + Value rhsRealFinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create( + Value rhsImagFinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); Value rhsFinite = - rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsRe, fmf); - Value lhsRealInfinite = rewriter.create( + arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf); + Value lhsRealInfinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsIm, fmf); - Value lhsImagInfinite = rewriter.create( + Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf); + Value lhsImagInfinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( + arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite); + Value one = arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsRealInfinite, one, zero), + Value lhsRealIsInfWithSign = math::CopySignOp::create(rewriter, + loc, arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe); - Value lhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsImagInfinite, one, zero), + Value lhsImagIsInfWithSign = math::CopySignOp::create(rewriter, + loc, arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, 
zero), lhsIm); Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsRe, fmf); + arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf); Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsIm, fmf); - Value resultReal3 = rewriter.create( + arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = arith::MulFOp::create(rewriter, loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, + arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal, lhsImagIsInfWithSignTimesRhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsIm, fmf); + arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf); Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsRe, fmf); - Value resultImag3 = rewriter.create( + arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = arith::MulFOp::create(rewriter, loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, + arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal, lhsRealIsInfWithSignTimesRhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. 
- Value lhsRealFinite = rewriter.create( + Value lhsRealFinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create( + Value lhsImagFinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); Value lhsFinite = - rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create( + arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create( + Value rhsImagInfinite = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsRealInfinite, one, zero), + arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = math::CopySignOp::create(rewriter, + loc, arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe); - Value rhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsImagInfinite, one, zero), + Value rhsImagIsInfWithSign = math::CopySignOp::create(rewriter, + loc, arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm); Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsRe, rhsRealIsInfWithSign, fmf); + arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsIm, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create( + arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = arith::MulFOp::create(rewriter, loc, zero, - 
rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, + arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal, rhsImagIsInfWithSignTimesLhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsIm, rhsRealIsInfWithSign, fmf); + arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsRe, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create( + arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = arith::MulFOp::create(rewriter, loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, + arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag, rhsImagIsInfWithSignTimesLhsReal, fmf), fmf); - Value realAbsSmallerThanImagAbs = rewriter.create( + Value realAbsSmallerThanImagAbs = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); - Value resultReal5 = rewriter.create( + Value resultReal5 = arith::SelectOp::create(rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag5 = rewriter.create( + Value resultImag5 = arith::SelectOp::create(rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create( + Value resultRealSpecialCase3 = arith::SelectOp::create(rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5); - Value resultImagSpecialCase3 = rewriter.create( + Value resultImagSpecialCase3 = arith::SelectOp::create(rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5); - Value resultRealSpecialCase2 = rewriter.create( + Value resultRealSpecialCase2 = arith::SelectOp::create(rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create( + Value resultImagSpecialCase2 = arith::SelectOp::create(rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value 
resultRealSpecialCase1 = rewriter.create( + Value resultRealSpecialCase1 = arith::SelectOp::create(rewriter, loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create( + Value resultImagSpecialCase1 = arith::SelectOp::create(rewriter, loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = rewriter.create( + Value resultRealIsNaN = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero); - Value resultImagIsNaN = rewriter.create( + Value resultImagIsNaN = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero); Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN); - *resultRe = rewriter.create( + *resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN, resultRealSpecialCase1, resultReal5); - *resultIm = rewriter.create( + *resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN, resultImagSpecialCase1, resultImag5); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index e5e862315941d..a06935c4387c6 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder, Location loc, Type type) { - Value val = builder.create(loc, type); + Value val = LLVM::PoisonOp::create(builder, loc, type); return ComplexStructBuilder(val); } @@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value sqNorm = rewriter.create( - loc, rewriter.create(loc, real, real, fmf), - 
rewriter.create(loc, imag, imag, fmf), fmf); + Value sqNorm = LLVM::FAddOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf), + LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); @@ -192,9 +192,9 @@ struct AddOpConversion : public ConvertOpToLLVMPattern { op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); + LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); + LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern { Value lhsRe = arg.lhs.real(); Value lhsIm = arg.lhs.imag(); - Value real = rewriter.create( - loc, rewriter.create(loc, rhsRe, lhsRe, fmf), - rewriter.create(loc, rhsIm, lhsIm, fmf), fmf); + Value real = LLVM::FSubOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf); - Value imag = rewriter.create( - loc, rewriter.create(loc, lhsIm, rhsRe, fmf), - rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + Value imag = LLVM::FAddOp::create(rewriter, + loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -314,9 +314,9 @@ struct SubOpConversion : public ConvertOpToLLVMPattern { op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); + LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = - rewriter.create(loc, arg.lhs.imag(), 
arg.rhs.imag(), fmf); + LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp index 56269d189873a..4800f65c675e6 100644 --- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp +++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp @@ -84,7 +84,7 @@ LogicalResult ScalarOpToLibmCall::matchAndRewrite( rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create(rewriter.getUnknownLoc(), name, + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0c832c452718b..d14e3888b118c 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -31,44 +31,44 @@ enum class AbsFn { abs, sqrt, rsqrt }; // Returns the absolute value, its square root or its reciprocal square root. 
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { - Value one = b.create(real.getType(), + Value one = arith::ConstantOp::create(b, real.getType(), b.getFloatAttr(real.getType(), 1.0)); - Value absReal = b.create(real, fmf); - Value absImag = b.create(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value max = b.create(absReal, absImag, fmf); - Value min = b.create(absReal, absImag, fmf); + Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf); + Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf); // The lowering below requires NaNs and infinities to work correctly. arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value ratio = b.create(min, max, fmfWithNaNInf); - Value ratioSq = b.create(ratio, ratio, fmfWithNaNInf); - Value ratioSqPlusOne = b.create(ratioSq, one, fmfWithNaNInf); + Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf); + Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf); + Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf); Value result; if (fn == AbsFn::rsqrt) { - ratioSqPlusOne = b.create(ratioSqPlusOne, fmfWithNaNInf); - min = b.create(min, fmfWithNaNInf); - max = b.create(max, fmfWithNaNInf); + ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + min = math::RsqrtOp::create(b, min, fmfWithNaNInf); + max = math::RsqrtOp::create(b, max, fmfWithNaNInf); } if (fn == AbsFn::sqrt) { - Value quarter = b.create( + Value quarter = arith::ConstantOp::create(b, real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. 
- Value sqrt = b.create(max, fmfWithNaNInf); - Value p025 = b.create(ratioSqPlusOne, quarter, fmfWithNaNInf); - result = b.create(sqrt, p025, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf); + Value p025 = math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf); + result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf); } else { - Value sqrt = b.create(ratioSqPlusOne, fmfWithNaNInf); - result = b.create(max, sqrt, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf); } - Value isNaN = b.create(arith::CmpFPredicate::UNO, result, + Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result, result, fmfWithNaNInf); - return b.create(isNaN, min, result); + return arith::SelectOp::create(b, isNaN, min, result); } struct AbsOpConversion : public OpConversionPattern { @@ -81,8 +81,8 @@ struct AbsOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); return success(); @@ -105,28 +105,28 @@ struct Atan2OpConversion : public OpConversionPattern { Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); - Value rhsSquared = b.create(type, rhs, rhs, fmf); - Value lhsSquared = b.create(type, lhs, lhs, fmf); + Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf); + Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf); Value rhsSquaredPlusLhsSquared = - b.create(type, rhsSquared, lhsSquared, fmf); + complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf); Value sqrtOfRhsSquaredPlusLhsSquared = - b.create(type, rhsSquaredPlusLhsSquared, fmf); + complex::SqrtOp::create(b, type, 
rhsSquaredPlusLhsSquared, fmf); Value zero = - b.create(elementType, b.getZeroAttr(elementType)); - Value one = b.create(elementType, + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); + Value one = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, 1)); - Value i = b.create(type, zero, one); - Value iTimesLhs = b.create(i, lhs, fmf); - Value rhsPlusILhs = b.create(rhs, iTimesLhs, fmf); + Value i = complex::CreateOp::create(b, type, zero, one); + Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf); + Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf); - Value divResult = b.create( + Value divResult = complex::DivOp::create(b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); - Value logResult = b.create(divResult, fmf); + Value logResult = complex::LogOp::create(b, divResult, fmf); - Value negativeOne = b.create( + Value negativeOne = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, -1)); - Value negativeI = b.create(type, zero, negativeOne); + Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne); rewriter.replaceOpWithNewOp(op, negativeI, logResult, fmf); return success(); @@ -146,14 +146,14 @@ struct ComparisonOpConversion : public OpConversionPattern { auto loc = op.getLoc(); auto type = cast(adaptor.getLhs().getType()).getElementType(); - Value realLhs = rewriter.create(loc, type, adaptor.getLhs()); - Value imagLhs = rewriter.create(loc, type, adaptor.getLhs()); - Value realRhs = rewriter.create(loc, type, adaptor.getRhs()); - Value imagRhs = rewriter.create(loc, type, adaptor.getRhs()); + Value realLhs = complex::ReOp::create(rewriter, loc, type, adaptor.getLhs()); + Value imagLhs = complex::ImOp::create(rewriter, loc, type, adaptor.getLhs()); + Value realRhs = complex::ReOp::create(rewriter, loc, type, adaptor.getRhs()); + Value imagRhs = complex::ImOp::create(rewriter, loc, type, adaptor.getRhs()); Value realComparison = - rewriter.create(loc, p, realLhs, 
realRhs); + arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs); Value imagComparison = - rewriter.create(loc, p, imagLhs, imagRhs); + arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); @@ -176,13 +176,13 @@ struct BinaryComplexOpConversion : public OpConversionPattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value realLhs = b.create(elementType, adaptor.getLhs()); - Value realRhs = b.create(elementType, adaptor.getRhs()); - Value resultReal = b.create(elementType, realLhs, realRhs, + Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value resultReal = BinaryStandardOp::create(b, elementType, realLhs, realRhs, fmf.getValue()); - Value imagLhs = b.create(elementType, adaptor.getLhs()); - Value imagRhs = b.create(elementType, adaptor.getRhs()); - Value resultImag = b.create(elementType, imagLhs, imagRhs, + Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs()); + Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs, imagRhs, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -205,20 +205,20 @@ struct TrigonometricOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); // Trigonometric ops use a set of common building blocks to convert to real // ops. 
Here we create these building blocks and call into an op-specific // implementation in the subclass to combine them. - Value half = rewriter.create( + Value half = arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); - Value exp = rewriter.create(loc, imag, fmf); - Value scaledExp = rewriter.create(loc, half, exp, fmf); - Value reciprocalExp = rewriter.create(loc, half, exp, fmf); - Value sin = rewriter.create(loc, real, fmf); - Value cos = rewriter.create(loc, real, fmf); + Value exp = math::ExpOp::create(rewriter, loc, imag, fmf); + Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf); + Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf); + Value sin = math::SinOp::create(rewriter, loc, real, fmf); + Value cos = math::CosOp::create(rewriter, loc, real, fmf); auto resultPair = combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); @@ -251,11 +251,11 @@ struct CosOpConversion : public TrigonometricOpConversion { // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x Value sum = - rewriter.create(loc, reciprocalExp, scaledExp, fmf); - Value resultReal = rewriter.create(loc, sum, cos, fmf); + arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf); Value diff = - rewriter.create(loc, reciprocalExp, scaledExp, fmf); - Value resultImag = rewriter.create(loc, diff, sin, fmf); + arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf); return {resultReal, resultImag}; } }; @@ -275,13 +275,13 @@ struct DivOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value lhsReal = - rewriter.create(loc, elementType, adaptor.getLhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value lhsImag = - 
rewriter.create(loc, elementType, adaptor.getLhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value rhsReal = - rewriter.create(loc, elementType, adaptor.getRhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value rhsImag = - rewriter.create(loc, elementType, adaptor.getRhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value resultReal, resultImag; @@ -318,16 +318,16 @@ struct ExpOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); - Value expReal = rewriter.create(loc, real, fmf.getValue()); - Value cosImag = rewriter.create(loc, imag, fmf.getValue()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); + Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); Value resultReal = - rewriter.create(loc, expReal, cosImag, fmf.getValue()); - Value sinImag = rewriter.create(loc, imag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); + Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); Value resultImag = - rewriter.create(loc, expReal, sinImag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -340,11 +340,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, arith::FastMathFlagsAttr fmf) { auto argType = mlir::cast(arg.getType()); Value poly = - b.create(b.getFloatAttr(argType, coefficients[0])); + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0])); for (unsigned i = 1; i < coefficients.size(); ++i) { - 
poly = b.create( + poly = math::FmaOp::create(b, poly, arg, - b.create(b.getFloatAttr(argType, coefficients[i])), + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])), fmf); } return poly; @@ -365,26 +365,26 @@ struct Expm1OpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value zero = b.create(b.getFloatAttr(elemType, 0.0)); - Value one = b.create(b.getFloatAttr(elemType, 1.0)); + Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0)); + Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0)); - Value expm1Real = b.create(real, fmf); - Value expReal = b.create(expm1Real, one, fmf); + Value expm1Real = math::ExpM1Op::create(b, real, fmf); + Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf); - Value sinImag = b.create(imag, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); Value cosm1Imag = emitCosm1(imag, fmf, b); - Value cosImag = b.create(cosm1Imag, one, fmf); + Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf); - Value realResult = b.create( - b.create(expm1Real, cosImag, fmf), cosm1Imag, fmf); + Value realResult = arith::AddFOp::create(b, + arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf); - Value imagIsZero = b.create(arith::CmpFPredicate::OEQ, imag, + Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf.getValue()); - Value imagResult = b.create( - imagIsZero, zero, b.create(expReal, sinImag, fmf)); + Value imagResult = arith::SelectOp::create(b, + imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf)); rewriter.replaceOpWithNewOp(op, type, realResult, imagResult); @@ -395,8 +395,8 @@ struct 
Expm1OpConversion : public OpConversionPattern { Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, ImplicitLocOpBuilder &b) const { auto argType = mlir::cast(arg.getType()); - auto negHalf = b.create(b.getFloatAttr(argType, -0.5)); - auto negOne = b.create(b.getFloatAttr(argType, -1.0)); + auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5)); + auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0)); // Algorithm copied from cephes cosm1. SmallVector kCoeffs{ @@ -405,23 +405,23 @@ struct Expm1OpConversion : public OpConversionPattern { 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 4.1666666666666666609054E-2, }; - Value cos = b.create(arg, fmf); - Value forLargeArg = b.create(cos, negOne, fmf); + Value cos = math::CosOp::create(b, arg, fmf); + Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf); - Value argPow2 = b.create(arg, arg, fmf); - Value argPow4 = b.create(argPow2, argPow2, fmf); + Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf); + Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf); Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); auto forSmallArg = - b.create(b.create(argPow4, poly, fmf), - b.create(negHalf, argPow2, fmf)); + arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf), + arith::MulFOp::create(b, negHalf, argPow2, fmf)); // (pi/4)^2 is approximately 0.61685 Value piOver4Pow2 = - b.create(b.getFloatAttr(argType, 0.61685)); - Value cond = b.create(arith::CmpFPredicate::OGE, argPow2, + arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685)); + Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2, piOver4Pow2, fmf.getValue()); - return b.create(cond, forLargeArg, forSmallArg); + return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg); } }; @@ -436,13 +436,13 @@ struct LogOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder 
b(op.getLoc(), rewriter); - Value abs = b.create(elementType, adaptor.getComplex(), + Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf.getValue()); - Value resultReal = b.create(elementType, abs, fmf.getValue()); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value resultImag = - b.create(elementType, imag, real, fmf.getValue()); + math::Atan2Op::create(b, elementType, imag, real, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); @@ -460,40 +460,40 @@ struct Log1pOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value half = b.create(elementType, + Value half = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, 0.5)); - Value one = b.create(elementType, + Value one = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create(real, one, fmf); - Value absRealPlusOne = b.create(realPlusOne, fmf); - Value absImag = b.create(imag, fmf); + Value realPlusOne = arith::AddFOp::create(b, real, one, fmf); + Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value maxAbs = b.create(absRealPlusOne, absImag, fmf); - Value minAbs = b.create(absRealPlusOne, absImag, fmf); + Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, 
fmf); + Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf); - Value useReal = b.create(arith::CmpFPredicate::OGT, + Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, realPlusOne, absImag, fmf); - Value maxMinusOne = b.create(maxAbs, one, fmf); + Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf); Value maxAbsOfRealPlusOneAndImagMinusOne = - b.create(useReal, real, maxMinusOne); + arith::SelectOp::create(b, useReal, real, maxMinusOne); arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value minMaxRatio = b.create(minAbs, maxAbs, fmfWithNaNInf); + Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf); Value logOfMaxAbsOfRealPlusOneAndImag = - b.create(maxAbsOfRealPlusOneAndImagMinusOne, fmf); - Value logOfSqrtPart = b.create( - b.create(minMaxRatio, minMaxRatio, fmfWithNaNInf), + math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf); + Value logOfSqrtPart = math::Log1pOp::create(b, + arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf), fmfWithNaNInf); - Value r = b.create( - b.create(half, logOfSqrtPart, fmfWithNaNInf), + Value r = arith::AddFOp::create(b, + arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf), logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); - Value resultReal = b.create( - b.create(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), + Value resultReal = arith::SelectOp::create(b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), minAbs, r); - Value resultImag = b.create(imag, realPlusOne, fmf); + Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); @@ -511,21 +511,21 @@ struct MulOpConversion : public OpConversionPattern { auto elementType = cast(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); auto fmfValue = 
fmf.getValue(); - Value lhsReal = b.create(elementType, adaptor.getLhs()); - Value lhsImag = b.create(elementType, adaptor.getLhs()); - Value rhsReal = b.create(elementType, adaptor.getRhs()); - Value rhsImag = b.create(elementType, adaptor.getRhs()); + Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs()); Value lhsRealTimesRhsReal = - b.create(lhsReal, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue); Value lhsImagTimesRhsImag = - b.create(lhsImag, rhsImag, fmfValue); - Value real = b.create(lhsRealTimesRhsReal, + arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue); + Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal, lhsImagTimesRhsImag, fmfValue); Value lhsImagTimesRhsReal = - b.create(lhsImag, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue); Value lhsRealTimesRhsImag = - b.create(lhsReal, rhsImag, fmfValue); - Value imag = b.create(lhsImagTimesRhsReal, + arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue); + Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal, lhsRealTimesRhsImag, fmfValue); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); @@ -543,11 +543,11 @@ struct NegOpConversion : public OpConversionPattern { auto elementType = cast(type.getElementType()); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); - Value negReal = rewriter.create(loc, real); - Value negImag = rewriter.create(loc, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negReal = arith::NegFOp::create(rewriter, loc, real); + Value negImag = 
arith::NegFOp::create(rewriter, loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); return success(); } @@ -570,11 +570,11 @@ struct SinOpConversion : public TrigonometricOpConversion { // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x Value sum = - rewriter.create(loc, scaledExp, reciprocalExp, fmf); - Value resultReal = rewriter.create(loc, sum, sin, fmf); + arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf); Value diff = - rewriter.create(loc, scaledExp, reciprocalExp, fmf); - Value resultImag = rewriter.create(loc, diff, cos, fmf); + arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf); return {resultReal, resultImag}; } }; @@ -593,64 +593,64 @@ struct SqrtOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create(elementType, + return arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); - Value half = b.create(elementType, + Value half = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, 0.5)); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); - Value argArg = b.create(imag, real, fmf); - Value sqrtArg = b.create(argArg, half, fmf); - Value cos = b.create(sqrtArg, fmf); - Value sin = b.create(sqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value sqrtArg = 
arith::MulFOp::create(b, argArg, half, fmf); + Value cos = math::CosOp::create(b, sqrtArg, fmf); + Value sin = math::SinOp::create(b, sqrtArg, fmf); // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply // 0 * inf. Value sinIsZero = - b.create(arith::CmpFPredicate::OEQ, sin, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf); - Value resultReal = b.create(absSqrt, cos, fmf); - Value resultImag = b.create( - sinIsZero, zero, b.create(absSqrt, sin, fmf)); + Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf); + Value resultImag = arith::SelectOp::create(b, + sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf)); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { Value inf = cst(APFloat::getInf(floatSemantics)); Value negInf = cst(APFloat::getInf(floatSemantics, true)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absImag = b.create(elementType, imag, fmf); + Value absImag = math::AbsFOp::create(b, elementType, imag, fmf); Value absImagIsInf = - b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absImag, inf, fmf); Value absImagIsNotInf = - b.create(arith::CmpFPredicate::ONE, absImag, inf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::ONE, absImag, inf, fmf); Value realIsInf = - b.create(arith::CmpFPredicate::OEQ, real, inf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf); Value realIsNegInf = - b.create(arith::CmpFPredicate::OEQ, real, negInf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, negInf, fmf); - resultReal = b.create( - b.create(realIsNegInf, absImagIsNotInf), zero, + resultReal = arith::SelectOp::create(b, + arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero, resultReal); - resultReal = b.create( - b.create(absImagIsInf, realIsInf), inf, resultReal); + resultReal = arith::SelectOp::create(b, + 
arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal); - Value imagSignInf = b.create(inf, imag, fmf); - resultImag = b.create( - b.create(arith::CmpFPredicate::UNO, absSqrt, absSqrt), + Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf); + resultImag = arith::SelectOp::create(b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt), nan, resultImag); - resultImag = b.create( - b.create(absImagIsInf, realIsNegInf), imagSignInf, + resultImag = arith::SelectOp::create(b, + arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf, resultImag); } Value resultIsZero = - b.create(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); - resultReal = b.create(resultIsZero, zero, resultReal); - resultImag = b.create(resultIsZero, zero, resultImag); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); + resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -669,19 +669,19 @@ struct SignOpConversion : public OpConversionPattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value zero = - b.create(elementType, b.getZeroAttr(elementType)); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); Value realIsZero = - b.create(arith::CmpFPredicate::OEQ, real, zero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero); Value imagIsZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create(realIsZero, imagIsZero); - auto abs = b.create(elementType, adaptor.getComplex(), 
fmf); - Value realSign = b.create(real, abs, fmf); - Value imagSign = b.create(imag, abs, fmf); - Value sign = b.create(type, realSign, imagSign); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero); + auto abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf); + Value realSign = arith::DivFOp::create(b, real, abs, fmf); + Value imagSign = arith::DivFOp::create(b, imag, abs, fmf); + Value sign = complex::CreateOp::create(b, type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(), sign); return success(); @@ -703,84 +703,84 @@ struct TanTanhOpConversion : public OpConversionPattern { const auto &floatSemantics = elementType.getFloatSemantics(); Value real = - b.create(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(b, loc, elementType, adaptor.getComplex()); Value imag = - b.create(loc, elementType, adaptor.getComplex()); - Value negOne = b.create( + complex::ImOp::create(b, loc, elementType, adaptor.getComplex()); + Value negOne = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, -1.0)); if constexpr (std::is_same_v) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(real, imag); - real = b.create(real, negOne, fmf); + real = arith::MulFOp::create(b, real, negOne, fmf); } auto cst = [&](APFloat v) { - return b.create(elementType, + return arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, v)); }; Value inf = cst(APFloat::getInf(floatSemantics)); - Value four = b.create(elementType, + Value four = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, 4.0)); - Value twoReal = b.create(real, real, fmf); - Value negTwoReal = b.create(negOne, twoReal, fmf); + Value twoReal = arith::AddFOp::create(b, real, real, fmf); + Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf); - Value expTwoRealMinusOne = b.create(twoReal, fmf); - Value expNegTwoRealMinusOne = 
b.create(negTwoReal, fmf); + Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf); + Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf); Value realNum = - b.create(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); + arith::SubFOp::create(b, expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); - Value cosImag = b.create(imag, fmf); - Value cosImagSq = b.create(cosImag, cosImag, fmf); - Value twoCosTwoImagPlusOne = b.create(cosImagSq, four, fmf); - Value sinImag = b.create(imag, fmf); + Value cosImag = math::CosOp::create(b, imag, fmf); + Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf); + Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); - Value imagNum = b.create( - four, b.create(cosImag, sinImag, fmf), fmf); + Value imagNum = arith::MulFOp::create(b, + four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf); Value expSumMinusTwo = - b.create(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); + arith::AddFOp::create(b, expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); Value denom = - b.create(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); + arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf); - Value isInf = b.create(arith::CmpFPredicate::OEQ, + Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, expSumMinusTwo, inf, fmf); - Value realLimit = b.create(negOne, real, fmf); + Value realLimit = math::CopySignOp::create(b, negOne, real, fmf); - Value resultReal = b.create( - isInf, realLimit, b.create(realNum, denom, fmf)); - Value resultImag = b.create(imagNum, denom, fmf); + Value resultReal = arith::SelectOp::create(b, + isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf)); + Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value absReal = b.create(real, fmf); - Value zero = b.create( + 
Value absReal = math::AbsFOp::create(b, real, fmf); + Value zero = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, 0.0)); Value nan = cst(APFloat::getNaN(floatSemantics)); Value absRealIsInf = - b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absReal, inf, fmf); Value imagIsZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value absRealIsNotInf = b.create( - absRealIsInf, b.create(true, /*width=*/1)); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value absRealIsNotInf = arith::XOrIOp::create(b, + absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1)); - Value imagNumIsNaN = b.create(arith::CmpFPredicate::UNO, + Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, imagNum, imagNum, fmf); Value resultRealIsNaN = - b.create(imagNumIsNaN, absRealIsNotInf); - Value resultImagIsZero = b.create( - imagIsZero, b.create(absRealIsInf, imagNumIsNaN)); + arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf); + Value resultImagIsZero = arith::OrIOp::create(b, + imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN)); - resultReal = b.create(resultRealIsNaN, nan, resultReal); + resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal); resultImag = - b.create(resultImagIsZero, zero, resultImag); + arith::SelectOp::create(b, resultImagIsZero, zero, resultImag); } if constexpr (std::is_same_v) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(resultReal, resultImag); - resultImag = b.create(resultImag, negOne, fmf); + resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf); } rewriter.replaceOpWithNewOp(op, type, resultReal, @@ -799,10 +799,10 @@ struct ConjOpConversion : public OpConversionPattern { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); Value real = - rewriter.create(loc, elementType, adaptor.getComplex()); + 
complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create(loc, elementType, adaptor.getComplex()); - Value negImag = rewriter.create(loc, elementType, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag); rewriter.replaceOpWithNewOp(op, type, real, negImag); @@ -818,97 +818,97 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, arith::FastMathFlags fmf) { auto elementType = cast(type.getElementType()); - Value a = builder.create(lhs); - Value b = builder.create(lhs); + Value a = complex::ReOp::create(builder, lhs); + Value b = complex::ImOp::create(builder, lhs); - Value abs = builder.create(lhs, fmf); - Value absToC = builder.create(abs, c, fmf); + Value abs = complex::AbsOp::create(builder, lhs, fmf); + Value absToC = math::PowFOp::create(builder, abs, c, fmf); - Value negD = builder.create(d, fmf); - Value argLhs = builder.create(b, a, fmf); - Value negDArgLhs = builder.create(negD, argLhs, fmf); - Value expNegDArgLhs = builder.create(negDArgLhs, fmf); + Value negD = arith::NegFOp::create(builder, d, fmf); + Value argLhs = math::Atan2Op::create(builder, b, a, fmf); + Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf); + Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf); - Value coeff = builder.create(absToC, expNegDArgLhs, fmf); - Value lnAbs = builder.create(abs, fmf); - Value cArgLhs = builder.create(c, argLhs, fmf); - Value dLnAbs = builder.create(d, lnAbs, fmf); - Value q = builder.create(cArgLhs, dLnAbs, fmf); - Value cosQ = builder.create(q, fmf); - Value sinQ = builder.create(q, fmf); + Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf); + Value lnAbs = math::LogOp::create(builder, abs, fmf); + Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf); + Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf); + Value q = 
arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf); + Value cosQ = math::CosOp::create(builder, q, fmf); + Value sinQ = math::SinOp::create(builder, q, fmf); - Value inf = builder.create( + Value inf = arith::ConstantOp::create(builder, elementType, builder.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value zero = builder.create( + Value zero = arith::ConstantOp::create(builder, elementType, builder.getFloatAttr(elementType, 0.0)); - Value one = builder.create( + Value one = arith::ConstantOp::create(builder, elementType, builder.getFloatAttr(elementType, 1.0)); - Value complexOne = builder.create(type, one, zero); - Value complexZero = builder.create(type, zero, zero); - Value complexInf = builder.create(type, inf, zero); + Value complexOne = complex::CreateOp::create(builder, type, one, zero); + Value complexZero = complex::CreateOp::create(builder, type, zero, zero); + Value complexInf = complex::CreateOp::create(builder, type, inf, zero); // Case 0: // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. 
Value absEqZero = - builder.create(arith::CmpFPredicate::OEQ, abs, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf); Value dEqZero = - builder.create(arith::CmpFPredicate::OEQ, d, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf); Value cEqZero = - builder.create(arith::CmpFPredicate::OEQ, c, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf); Value bEqZero = - builder.create(arith::CmpFPredicate::OEQ, b, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf); Value zeroLeC = - builder.create(arith::CmpFPredicate::OLE, zero, c, fmf); - Value coeffCosQ = builder.create(coeff, cosQ, fmf); - Value coeffSinQ = builder.create(coeff, sinQ, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf); + Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf); + Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf); Value complexOneOrZero = - builder.create(cEqZero, complexOne, complexZero); + arith::SelectOp::create(builder, cEqZero, complexOne, complexZero); Value coeffCosSin = - builder.create(type, coeffCosQ, coeffSinQ); - Value cutoff0 = builder.create( - builder.create( - builder.create(absEqZero, dEqZero), zeroLeC), + complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ); + Value cutoff0 = arith::SelectOp::create(builder, + arith::AndIOp::create(builder, + arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC), complexOneOrZero, coeffCosSin); // Case 1: // x^0 is defined to be 1 for any x, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. 
- Value rhsEqZero = builder.create(cEqZero, dEqZero); + Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero); Value cutoff1 = - builder.create(rhsEqZero, complexOne, cutoff0); + arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0); // Case 2: // 1^(c + d*i) = 1 + 0*i - Value lhsEqOne = builder.create( - builder.create(arith::CmpFPredicate::OEQ, a, one, fmf), + Value lhsEqOne = arith::AndIOp::create(builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf), bEqZero); Value cutoff2 = - builder.create(lhsEqOne, complexOne, cutoff1); + arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1); // Case 3: // inf^(c + 0*i) = inf + 0*i, c > 0 - Value lhsEqInf = builder.create( - builder.create(arith::CmpFPredicate::OEQ, a, inf, fmf), + Value lhsEqInf = arith::AndIOp::create(builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf), bEqZero); - Value rhsGt0 = builder.create( + Value rhsGt0 = arith::AndIOp::create(builder, dEqZero, - builder.create(arith::CmpFPredicate::OGT, c, zero, fmf)); - Value cutoff3 = builder.create( - builder.create(lhsEqInf, rhsGt0), complexInf, cutoff2); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf)); + Value cutoff3 = arith::SelectOp::create(builder, + arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf, cutoff2); // Case 4: // inf^(c + 0*i) = 0 + 0*i, c < 0 - Value rhsLt0 = builder.create( + Value rhsLt0 = arith::AndIOp::create(builder, dEqZero, - builder.create(arith::CmpFPredicate::OLT, c, zero, fmf)); - Value cutoff4 = builder.create( - builder.create(lhsEqInf, rhsLt0), complexZero, cutoff3); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf)); + Value cutoff4 = arith::SelectOp::create(builder, + arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero, cutoff3); return cutoff4; } @@ -923,8 +923,8 @@ struct PowOpConversion : public OpConversionPattern { auto type = cast(adaptor.getLhs().getType()); 
auto elementType = cast(type.getElementType()); - Value c = builder.create(elementType, adaptor.getRhs()); - Value d = builder.create(elementType, adaptor.getRhs()); + Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs()); + Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs()); rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), c, d, op.getFastmath())}); @@ -945,64 +945,64 @@ struct RsqrtOpConversion : public OpConversionPattern { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create(elementType, + return arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); Value inf = cst(APFloat::getInf(floatSemantics)); - Value negHalf = b.create( + Value negHalf = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, -0.5)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); - Value argArg = b.create(imag, real, fmf); - Value rsqrtArg = b.create(argArg, negHalf, fmf); - Value cos = b.create(rsqrtArg, fmf); - Value sin = b.create(rsqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf); + Value cos = math::CosOp::create(b, rsqrtArg, fmf); + Value sin = math::SinOp::create(b, rsqrtArg, fmf); - Value resultReal = b.create(absRsqrt, cos, fmf); - Value resultImag = b.create(absRsqrt, sin, fmf); + Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf); + Value resultImag = 
arith::MulFOp::create(b, absRsqrt, sin, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value negOne = b.create( + Value negOne = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, -1)); - Value realSignedZero = b.create(zero, real, fmf); - Value imagSignedZero = b.create(zero, imag, fmf); + Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf); + Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf); Value negImagSignedZero = - b.create(negOne, imagSignedZero, fmf); + arith::MulFOp::create(b, negOne, imagSignedZero, fmf); - Value absReal = b.create(real, fmf); - Value absImag = b.create(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); Value absImagIsInf = - b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absImag, inf, fmf); Value realIsNan = - b.create(arith::CmpFPredicate::UNO, real, real, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf); Value realIsInf = - b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf); - Value inIsNanInf = b.create(absImagIsInf, realIsNan); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absReal, inf, fmf); + Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan); - Value resultIsZero = b.create(inIsNanInf, realIsInf); + Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf); resultReal = - b.create(resultIsZero, realSignedZero, resultReal); - resultImag = b.create(resultIsZero, negImagSignedZero, + arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero, resultImag); } Value isRealZero = - b.create(arith::CmpFPredicate::OEQ, real, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf); Value isImagZero = - 
b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value isZero = b.create(isRealZero, isImagZero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero); - resultReal = b.create(isZero, inf, resultReal); - resultImag = b.create(isZero, nan, resultImag); + resultReal = arith::SelectOp::create(b, isZero, inf, resultReal); + resultImag = arith::SelectOp::create(b, isZero, nan, resultImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -1021,9 +1021,9 @@ struct AngleOpConversion : public OpConversionPattern { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create(loc, type, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, type, adaptor.getComplex()); Value imag = - rewriter.create(loc, type, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, type, adaptor.getComplex()); rewriter.replaceOpWithNewOp(op, imag, real, fmf); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 88a8b7fb185c5..6a1f7c105b2c6 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create(rewriter.getUnknownLoc(), + abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(), "abort", abortFuncTy); } - rewriter.create(loc, abortFunc, ValueRange()); - rewriter.create(loc); + LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange()); + LLVM::UnreachableOp::create(rewriter, loc); } else { - rewriter.create(loc, ValueRange(), continuationBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), 
continuationBlock); } // Generate assertion test. diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index 9831dcaaaccc8..aa468c29536c7 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -33,7 +33,7 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( MutableArrayRef regions) { if (auto condBrOp = dyn_cast(controlFlowCondOp)) { assert(regions.size() == 2); - auto ifOp = builder.create(controlFlowCondOp->getLoc(), + auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(), resultTypes, condBrOp.getCondition()); ifOp.getThenRegion().takeBody(regions[0]); ifOp.getElseRegion().takeBody(regions[1]); @@ -43,7 +43,7 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( if (auto switchOp = dyn_cast(controlFlowCondOp)) { // `getCFGSwitchValue` returns an i32 that we need to convert to index // fist. 
- auto cast = builder.create( + auto cast = arith::IndexCastUIOp::create(builder, controlFlowCondOp->getLoc(), builder.getIndexType(), switchOp.getFlag()); SmallVector cases; @@ -55,7 +55,7 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( assert(regions.size() == cases.size() + 1); - auto indexSwitchOp = builder.create( + auto indexSwitchOp = scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); indexSwitchOp.getDefaultRegion().takeBody(regions[0]); @@ -75,7 +75,7 @@ LogicalResult ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) { - builder.create(loc, results); + scf::YieldOp::create(builder, loc, results); return success(); } @@ -84,7 +84,7 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { Location loc = replacedOp->getLoc(); - auto whileOp = builder.create(loc, loopVariablesInit.getTypes(), + auto whileOp = scf::WhileOp::create(builder, loc, loopVariablesInit.getTypes(), loopVariablesInit); whileOp.getBefore().takeBody(loopBody); @@ -92,15 +92,15 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( builder.setInsertionPointToEnd(&whileOp.getBefore().back()); // `getCFGSwitchValue` returns a i32. We therefore need to truncate the // condition to i1 first. It is guaranteed to be either 0 or 1 already. 
- builder.create( - loc, builder.create(loc, builder.getI1Type(), condition), + scf::ConditionOp::create(builder, + loc, arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition), loopVariablesNextIter); Block *afterBlock = builder.createBlock(&whileOp.getAfter()); afterBlock->addArguments( loopVariablesInit.getTypes(), SmallVector(loopVariablesInit.size(), loc)); - builder.create(loc, afterBlock->getArguments()); + scf::YieldOp::create(builder, loc, afterBlock->getArguments()); return whileOp.getOperation(); } @@ -108,7 +108,7 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned int value) { - return builder.create(loc, + return arith::ConstantOp::create(builder, loc, builder.getI32IntegerAttr(value)); } @@ -117,7 +117,7 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp( ArrayRef caseValues, BlockRange caseDestinations, ArrayRef caseArguments, Block *defaultDest, ValueRange defaultArgs) { - builder.create(loc, flag, defaultDest, defaultArgs, + cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs, llvm::to_vector_of(caseValues), caseDestinations, caseArguments); } @@ -125,7 +125,7 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp( Value ControlFlowToSCFTransformation::getUndefValue(Location loc, OpBuilder &builder, Type type) { - return builder.create(loc, type, nullptr); + return ub::PoisonOp::create(builder, loc, type, nullptr); } FailureOr diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index f8dc06f41ab87..20167f168185a 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -99,7 +99,7 @@ class FuncOpConversion final : public OpConversionPattern { } // Create the converted `emitc.func` op. 
- emitc::FuncOp newFuncOp = rewriter.create( + emitc::FuncOp newFuncOp = emitc::FuncOp::create(rewriter, funcOp.getLoc(), funcOp.getName(), FunctionType::get(rewriter.getContext(), signatureConverter.getConvertedTypes(), diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 36235636d6ba2..51dbe6cbb7de3 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -115,7 +115,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, SmallVector attributes; filterFuncAttributes(funcOp, attributes); - auto wrapperFuncOp = rewriter.create( + auto wrapperFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -129,13 +129,13 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, for (auto [index, argType] : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(index + argOffset); if (auto memrefType = dyn_cast(argType)) { - Value loaded = rewriter.create( + Value loaded = LLVM::LoadOp::create(rewriter, loc, typeConverter.convertType(memrefType), arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (isa(argType)) { - Value loaded = rewriter.create( + Value loaded = LLVM::LoadOp::create(rewriter, loc, typeConverter.convertType(argType), arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; @@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, args.push_back(arg); } - auto call = rewriter.create(loc, newFuncOp, args); + auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args); if (resultStructType) { - rewriter.create(loc, call.getResult(), + LLVM::StoreOp::create(rewriter, loc, call.getResult(), wrapperFuncOp.getArgument(0)); - 
rewriter.create(loc, ValueRange{}); + LLVM::ReturnOp::create(rewriter, loc, ValueRange{}); } else { - rewriter.create(loc, call.getResults()); + LLVM::ReturnOp::create(rewriter, loc, call.getResults()); } } @@ -182,7 +182,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, filterFuncAttributes(funcOp, attributes); // Create the auxiliary function. - auto wrapperFunc = builder.create( + auto wrapperFunc = LLVM::LLVMFuncOp::create(builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, if (resultStructType) { // Allocate the struct on the stack and pass the pointer. Type resultType = cast(wrapperType).getParamType(0); - Value one = builder.create( + Value one = LLVM::ConstantOp::create(builder, loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value result = - builder.create(loc, resultType, resultStructType, one); + LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one); args.push_back(result); } @@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); - Value one = builder.create( + Value one = LLVM::ConstantOp::create(builder, loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); - Value allocated = builder.create( + Value allocated = LLVM::AllocaOp::create(builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0); - builder.create(loc, packed, allocated); + LLVM::StoreOp::create(builder, loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; @@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, } 
assert(wrapperArgsRange.empty() && "did not map some of the arguments"); - auto call = builder.create(loc, wrapperFunc, args); + auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args); if (resultStructType) { Value result = - builder.create(loc, resultStructType, args.front()); - builder.create(loc, result); + LLVM::LoadOp::create(builder, loc, resultStructType, args.front()); + LLVM::ReturnOp::create(builder, loc, result); } else { - builder.create(loc, call.getResults()); + LLVM::ReturnOp::create(builder, loc, call.getResults()); } } @@ -283,7 +283,7 @@ static void restoreByValRefArgumentType( Type resTy = typeConverter.convertType( cast(byValRefAttr->getValue()).getValue()); - Value valueArg = rewriter.create(arg.getLoc(), resTy, arg); + Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); rewriter.replaceUsesOfBlockArgument(arg, valueArg); } } @@ -357,7 +357,7 @@ FailureOr mlir::convertFuncOpToLLVMFuncOp( symbolTable.remove(funcOp); } - auto newFuncOp = rewriter.create( + auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "failed to convert result type"); auto newOp = - rewriter.create(op.getLoc(), type, op.getValue()); + LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue()); for (const NamedAttribute &attr : op->getAttrs()) { if (attr.getName().strref() == "value") continue; @@ -556,7 +556,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { auto promoted = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), adaptor.getOperands(), rewriter, useBarePtrCallConv); - auto newOp = rewriter.create( + auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(), packedResult ? 
TypeRange(packedResult) : TypeRange(), promoted, callOp->getAttrs()); @@ -573,7 +573,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { // Extract individual results from the structure and return them as list. results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( + results.push_back(LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(), newOp->getResult(0), i)); } } @@ -726,9 +726,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 01ca5e99a9aff..f5d5460412da2 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, if (!(ret = moduleOp.template lookupSymbol(name))) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); - ret = b.create(loc, name, type, LLVM::Linkage::External); + ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } return ret; } @@ -68,7 +68,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); - return b.create(loc, globalType, + return LLVM::GlobalOp::create(b, loc, globalType, /*isConstant=*/true, LLVM::Linkage::Internal, name, 
attr, alignment, addrSpace); } @@ -151,7 +151,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, gpuFuncOp.getWorkgroupAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - auto globalOp = rewriter.create( + auto globalOp = LLVM::GlobalOp::create(rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, workgroupAddrSpace); @@ -220,7 +220,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, LLVM::CConv callingConvention = gpuFuncOp.isKernel() ? kernelCallingConvention : nonKernelCallingConvention; - auto llvmFuncOp = rewriter.create( + auto llvmFuncOp = LLVM::LLVMFuncOp::create(rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, /*comdat=*/nullptr, attributes); @@ -266,10 +266,10 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); - Value address = rewriter.create( + Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType, global.getSymNameAttr()); Value memory = - rewriter.create(loc, ptrType, global.getType(), + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(), address, ArrayRef{0, 0}); // Build a memref descriptor pointing to the buffer to plug with the @@ -298,14 +298,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, Type elementType = typeConverter->convertType(type.getElementType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); - Value numElements = rewriter.create( + Value numElements = LLVM::ConstantOp::create(rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = 
dyn_cast_or_null(gpuFuncOp.getPrivateAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - Value allocated = rewriter.create( + Value allocated = LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); @@ -418,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall - Value zeroI64 = rewriter.create(loc, llvmI64, 0); - auto printfBeginCall = rewriter.create(loc, ocklBegin, zeroI64); + Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0); + auto printfBeginCall = LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. @@ -427,20 +427,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() - Value globalPtr = rewriter.create( + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create(loc, ptrType, global.getGlobalType(), + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); - Value stringLen = rewriter.create( + Value stringLen = LLVM::ConstantOp::create(rewriter, loc, llvmI64, cast(global.getValueAttr()).size()); - Value oneI32 = rewriter.create(loc, llvmI32, 1); - Value zeroI32 = rewriter.create(loc, llvmI32, 0); + Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1); + Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0); - auto appendFormatCall = rewriter.create( + auto appendFormatCall = 
LLVM::CallOp::create(rewriter, loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.getArgs().empty() ? oneI32 : zeroI32}); @@ -456,17 +456,17 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( SmallVector arguments; arguments.push_back(printfDesc); arguments.push_back( - rewriter.create(loc, llvmI32, numArgsThisCall)); + LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; if (auto floatType = dyn_cast(arg.getType())) { if (!floatType.isF64()) - arg = rewriter.create( + arg = LLVM::FPExtOp::create(rewriter, loc, typeConverter->convertType(rewriter.getF64Type()), arg); - arg = rewriter.create(loc, llvmI64, arg); + arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) - arg = rewriter.create(loc, llvmI64, arg); + arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg); arguments.push_back(arg); } @@ -477,7 +477,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto isLast = (bound == nArgs) ? 
oneI32 : zeroI32; arguments.push_back(isLast); - auto call = rewriter.create(loc, ocklAppendArgs, arguments); + auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments); printfDesc = call.getResult(); } rewriter.eraseOp(gpuPrintfOp); @@ -510,12 +510,12 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create( + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create(loc, ptrType, global.getGlobalType(), + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); // Construct arguments and function call @@ -525,7 +525,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - rewriter.create(loc, printfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -559,9 +559,9 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create(loc, global); + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global); Value stringStart = - rewriter.create(loc, ptrType, global.getGlobalType(), + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); SmallVector types; SmallVector args; @@ -572,27 +572,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( assert(type.isIntOrFloat()); if (isa(type)) { type = rewriter.getF64Type(); - promotedArg = rewriter.create(loc, type, arg); + promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg); } types.push_back(type); 
args.push_back(promotedArg); } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); - Value one = rewriter.create(loc, rewriter.getI64Type(), + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), rewriter.getIndexAttr(1)); Value tempAlloc = - rewriter.create(loc, ptrType, structType, one, + LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one, /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { - Value ptr = rewriter.create( + Value ptr = LLVM::GEPOp::create(rewriter, loc, ptrType, structType, tempAlloc, ArrayRef{0, static_cast(index)}); - rewriter.create(loc, arg, ptr); + LLVM::StoreOp::create(rewriter, loc, arg, ptr); } std::array printfArgs = {stringStart, tempAlloc}; - rewriter.create(loc, vprintfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -607,22 +607,22 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, TypeRange operandTypes(operands); VectorType vectorType = cast(llvm1DVectorTy); Location loc = op->getLoc(); - Value result = rewriter.create(loc, vectorType); + Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType); Type indexType = converter.convertType(rewriter.getIndexType()); StringAttr name = op->getName().getIdentifier(); Type elementType = vectorType.getElementType(); for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { - Value index = rewriter.create(loc, indexType, i); + Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i); auto extractElement = [&](Value operand) -> Value { if (!isa(operand.getType())) return operand; - return rewriter.create(loc, operand, index); + return LLVM::ExtractElementOp::create(rewriter, loc, operand, index); }; auto scalarOperands = llvm::map_to_vector(operands, extractElement); Operation *scalarOp = rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); - result = 
rewriter.create( + result = LLVM::InsertElementOp::create(rewriter, loc, result, scalarOp->getResult(0), index); } return result; @@ -705,7 +705,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol( auto zeroSizedArrayType = LLVM::LLVMArrayType::get( typeConverter->convertType(memrefType.getElementType()), 0); - return rewriter.create( + return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, addressSpace.value()); @@ -732,12 +732,12 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( // Step 3. Get address of the global symbol OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - auto basePtr = rewriter.create(loc, shmemOp); + auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp); Type baseType = basePtr->getResultTypes().front(); // Step 4. Generate GEP using offsets SmallVector gepArgs = {0}; - Value shmemPtr = rewriter.create(loc, baseType, elementType, + Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType, basePtr, gepArgs); // Step 5. 
Create a memref descriptor SmallVector shape, strides; @@ -799,9 +799,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite( return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 167cabbc57db9..5e8254a0b6035 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -79,7 +79,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern { uint64_t rank = type.getRank(); Value numElements = desc.size(rewriter, loc, /*pos=*/0); for (unsigned i = 1; i < rank; i++) - numElements = rewriter.create( + numElements = LLVM::MulOp::create(rewriter, loc, numElements, desc.size(rewriter, loc, /*pos=*/i)); return numElements; } @@ -582,7 +582,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, return OpBuilder::atBlockEnd(module.getBody()) .create(loc, functionName, functionType); }(); - return builder.create(loc, function, arguments); + return LLVM::CallOp::create(builder, loc, function, arguments); } // Corresponding to cusparseIndexType_t defined in cusparse.h. @@ -780,12 +780,12 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto nullPtr = rewriter.create(loc, llvmPointerType); + auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? 
nullPtr : adaptor.getAsyncDependencies().front(); - auto isHostShared = rewriter.create( + auto isHostShared = mlir::LLVM::ConstantOp::create(rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); Value allocatedPtr = @@ -1012,7 +1012,7 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast(bitwidth / 8) * static_cast(memrefTy.getNumElements()); - Value sizeArg = rewriter.create( + Value sizeArg = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); @@ -1025,7 +1025,7 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - rewriter.create( + gpu::LaunchFuncOp::create(rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, @@ -1048,7 +1048,7 @@ static Value bitAndAddrspaceCast(Location loc, const LLVMTypeConverter &typeConverter) { auto sourceTy = cast(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) - sourcePtr = rewriter.create( + sourcePtr = LLVM::AddrSpaceCastOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), destinationType.getAddressSpace()), @@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType, typeConverter->convertType(memRefType.getElementType()), nullPtr, numElements); auto sizeBytes = - 
rewriter.create(loc, getIndexType(), gepPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcDesc.alignedPtr(rewriter, loc), @@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); auto value = - rewriter.create(loc, bitCastType, adaptor.getValue()); + LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue()); auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstDesc.alignedPtr(rewriter, loc), *getTypeConverter()); @@ -1150,14 +1150,14 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( template static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { Type llvmInt32Type = builder.getIntegerType(32); - return builder.create(loc, llvmInt32Type, + return LLVM::ConstantOp::create(builder, loc, llvmInt32Type, static_cast(tValue)); } template static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { Type llvmFloat32Type = builder.getF32Type(); - return builder.create( + return LLVM::ConstantOp::create(builder, loc, llvmFloat32Type, builder.getF32FloatAttr(static_cast(tValue))); } @@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( // the dnmat is used with spmat with 2:4 sparsity if (dims.size() == 2) { if (isSpMMCusparseLtOp(op.getDnTensor())) { - auto handleSz = rewriter.create( + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(11032)); - handle = rewriter.create( + handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create(loc, llvmPointerType, handle); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); createLtDnMatCallBuilder .create(loc, rewriter, @@ -1351,11 +1351,11 @@ 
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); // CUDA runner asserts the size is 44104 bytes. - auto handleSz = rewriter.create( + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(44104)); - Value handle = rewriter.create( + Value handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create(loc, llvmPointerType, handle); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); create2To4SpMatCallBuilder .create(loc, rewriter, @@ -1441,9 +1441,9 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); auto computeType = genConstInt32From( rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); - auto three = rewriter.create(loc, getIndexType(), + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(3)); - auto bufferSize = rewriter.create( + auto bufferSize = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); createCuSparseLtSpMMBufferSizeBuilder .create(loc, rewriter, @@ -1452,20 +1452,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( pruneFlag, stream}) .getResult(); - auto bufferSizePtr1 = rewriter.create( + auto bufferSizePtr1 = LLVM::GEPOp::create(rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create( + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(1))}); - auto bufferSizePtr2 = rewriter.create( + auto bufferSizePtr2 = LLVM::GEPOp::create(rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create( + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), 
rewriter.getIndexAttr(2))}); auto bufferSize0 = - rewriter.create(loc, llvmInt64Type, bufferSize); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize); auto bufferSize1 = - rewriter.create(loc, llvmInt64Type, bufferSizePtr1); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1); auto bufferSize2 = - rewriter.create(loc, llvmInt64Type, bufferSizePtr2); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2); rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); } else { @@ -1669,28 +1669,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op.getLoc(); auto stream = adaptor.getAsyncDependencies().front(); - auto three = rewriter.create(loc, getIndexType(), + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(3)); - auto buffer = rewriter.create( + auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); - auto rowsPtr = rewriter.create( + auto rowsPtr = LLVM::GEPOp::create(rewriter, loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create(loc, getIndexType(), + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(0))}); - auto colsPtr = rewriter.create( + auto colsPtr = LLVM::GEPOp::create(rewriter, loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create(loc, getIndexType(), + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(1))}); - auto nnzsPtr = rewriter.create( + auto nnzsPtr = LLVM::GEPOp::create(rewriter, loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create(loc, getIndexType(), + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(2))}); createSpMatGetSizeBuilder.create( loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); - auto rows = rewriter.create(loc, llvmInt64Type, 
rowsPtr); - auto cols = rewriter.create(loc, llvmInt64Type, colsPtr); - auto nnzs = rewriter.create(loc, llvmInt64Type, nnzsPtr); + auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr); + auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr); + auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr); rewriter.replaceOp(op, {rows, cols, nnzs, stream}); return success(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index aab2409ed6328..70964727351a9 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -59,13 +59,13 @@ struct OpLowering : public ConvertOpToLLVMPattern { Operation *newOp; switch (op.getDimension()) { case gpu::Dimension::x: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::y: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::z: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32)); break; } @@ -124,10 +124,10 @@ struct OpLowering : public ConvertOpToLLVMPattern { rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { - newOp = rewriter.create( + newOp = LLVM::SExtOp::create(rewriter, loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); } else if (indexBitwidth < 32) { - newOp = rewriter.create( + newOp = LLVM::TruncOp::create(rewriter, loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); } diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 64cf09e600b88..33b09b06649b9 100644 --- 
a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -103,7 +103,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); auto callOp = - rewriter.create(op->getLoc(), funcOp, castedOperands); + LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands); if (resultType == adaptor.getOperands().front().getType()) { rewriter.replaceOp(op, {callOp.getResult()}); @@ -115,10 +115,10 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { // there is no guarantee of a specific value being used to indicate true, // compare for inequality with zero (rather than truncate or shift). if (isResultBool) { - Value zero = rewriter.create( + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0)); - Value truncated = rewriter.create( + Value truncated = LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero); rewriter.replaceOp(op, {truncated}); return success(); @@ -126,7 +126,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { assert(callOp.getResult().getType().isF32() && "only f32 types are supposed to be truncated back"); - Value truncated = rewriter.create( + Value truncated = LLVM::FPTruncOp::create(rewriter, op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); rewriter.replaceOp(op, {truncated}); @@ -142,7 +142,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { if (!f16Func.empty() && isa(type)) return operand; - return rewriter.create( + return LLVM::FPExtOp::create(rewriter, operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); } @@ -169,7 +169,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { // location as debug info metadata inside of a function cannot be used // outside of that function. 
auto globalloc = op->getLoc()->findInstanceOfOrUnknown(); - return b.create(globalloc, funcName, funcType); + return LLVMFuncOp::create(b, globalloc, funcName, funcType); } StringRef getFunctionName(Type type, SourceOp op) const { diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 8b6b553f6eed0..fc469102a54d2 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -54,7 +54,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, SymbolTable::lookupSymbolIn(symbolTable, name)); if (!func) { OpBuilder b(symbolTable->getRegion(0)); - func = b.create( + func = LLVM::LLVMFuncOp::create(b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); @@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = rewriter.create(loc, func, args); + auto call = LLVM::CallOp::create(rewriter, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern { constexpr int64_t localMemFenceFlag = 1; Location loc = op->getLoc(); Value flag = - rewriter.create(loc, flagTy, localMemFenceFlag); + LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag); rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag)); return success(); } @@ -162,7 +162,7 @@ struct LaunchConfigConversion : ConvertToLLVMPattern { Location loc = op->getLoc(); gpu::Dimension dim = getDimension(op); - Value dimVal = rewriter.create(loc, dimTy, + Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy, static_cast(dim)); rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal)); 
return success(); @@ -291,12 +291,12 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) { return TypeSwitch(oldVal.getType()) .Case([&](BFloat16Type) { - return rewriter.create(loc, rewriter.getI16Type(), + return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(), oldVal); }) .Case([&](IntegerType intTy) -> Value { if (intTy.getWidth() == 1) - return rewriter.create(loc, rewriter.getI8Type(), + return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(), oldVal); return oldVal; }) @@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) { return TypeSwitch(newTy) .Case([&](BFloat16Type) { - return rewriter.create(loc, newTy, oldVal); + return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal); }) .Case([&](IntegerType intTy) -> Value { if (intTy.getWidth() == 1) - return rewriter.create(loc, newTy, oldVal); + return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal); return oldVal; }) .Default(oldVal); @@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter); Value trueVal = - rewriter.create(loc, rewriter.getI1Type(), true); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true); rewriter.replaceOp(op, {resultOrConversion, trueVal}); return success(); } @@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern { if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) { return failure(); } - result = rewriter.create(loc, indexTy, result); + result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 1ef6edea93c58..7a6601ea754e9 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ 
b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -118,9 +118,9 @@ struct GPUSubgroupReduceOpLowering Location loc = op->getLoc(); auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value offset = rewriter.create(loc, int32Type, -1); + Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); - auto reduxOp = rewriter.create(loc, int32Type, op.getValue(), + auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type, op.getValue(), mode.value(), offset); rewriter.replaceOp(op, reduxOp->getResult(0)); @@ -158,13 +158,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { auto int32Type = IntegerType::get(rewriter.getContext(), 32); auto predTy = IntegerType::get(rewriter.getContext(), 1); - Value one = rewriter.create(loc, int32Type, 1); - Value minusOne = rewriter.create(loc, int32Type, -1); - Value thirtyTwo = rewriter.create(loc, int32Type, 32); - Value numLeadInactiveLane = rewriter.create( + Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); + Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32); + Value numLeadInactiveLane = LLVM::SubOp::create(rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth()); // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. 
- Value activeMask = rewriter.create(loc, int32Type, minusOne, + Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne, numLeadInactiveLane); Value maskAndClamp; if (op.getMode() == gpu::ShuffleMode::UP) { @@ -173,7 +173,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { } else { // Clamp lane: `activeWidth - 1` maskAndClamp = - rewriter.create(loc, int32Type, adaptor.getWidth(), one); + LLVM::SubOp::create(rewriter, loc, int32Type, adaptor.getWidth(), one); } bool predIsUsed = !op->getResult(1).use_empty(); @@ -184,13 +184,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {valueTy, predTy}); } - Value shfl = rewriter.create( + Value shfl = NVVM::ShflOp::create(rewriter, loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr); if (predIsUsed) { - Value shflValue = rewriter.create(loc, shfl, 0); + Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0); Value isActiveSrcLane = - rewriter.create(loc, shfl, 1); + LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); } else { rewriter.replaceOp(op, {shfl, nullptr}); @@ -215,15 +215,15 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { bounds = rewriter.getAttr( /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize); Value newOp = - rewriter.create(loc, rewriter.getI32Type(), bounds); + NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. 
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - newOp = rewriter.create( + newOp = LLVM::SExtOp::create(rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { - newOp = rewriter.create( + newOp = LLVM::TruncOp::create(rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); @@ -271,10 +271,10 @@ struct AssertOpToAssertfailLowering Block *afterBlock = rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); rewriter.setInsertionPointToEnd(beforeBlock); - rewriter.create(loc, adaptor.getArg(), afterBlock, + cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock, assertBlock); rewriter.setInsertionPointToEnd(assertBlock); - rewriter.create(loc, afterBlock); + cf::BranchOp::create(rewriter, loc, afterBlock); // Continue cf.assert lowering. rewriter.setInsertionPoint(assertOp); @@ -301,11 +301,11 @@ struct AssertOpToAssertfailLowering // Create constants. auto getGlobal = [&](LLVM::GlobalOp global) { // Get a pointer to the format string's first element. - Value globalPtr = rewriter.create( + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), global.getSymNameAttr()); Value start = - rewriter.create(loc, ptrType, global.getGlobalType(), + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); return start; }; @@ -316,8 +316,8 @@ struct AssertOpToAssertfailLowering Value assertFunc = getGlobal(getOrCreateStringConstant( rewriter, loc, moduleOp, i8Type, "assert_func_", funcName)); Value assertLine = - rewriter.create(loc, i32Type, fileLine); - Value c1 = rewriter.create(loc, i64Type, 1); + LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1); // Insert function call to __assertfail. 
SmallVector arguments{assertMessage, assertFile, assertLine, diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 45fd933d58857..7d5f65c470386 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -126,7 +126,7 @@ struct WmmaLoadOpToNVVMLowering cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), adaptor.getSrcMemref(), adaptor.getIndices()); - Value leadingDim = rewriter.create( + Value leadingDim = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), subgroupMmaLoadMatrixOp.getLeadDimensionAttr()); rewriter.replaceOpWithNewOp( @@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering auto matrixType = cast(adaptor.getSrc().getType()); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { Value toUse = - rewriter.create(loc, adaptor.getSrc(), i); + LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i); storeOpOperands.push_back(toUse); } @@ -181,7 +181,7 @@ struct WmmaStoreOpToNVVMLowering rewriter, loc, cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()), adaptor.getDstMemref(), adaptor.getIndices()); - Value leadingDim = rewriter.create( + Value leadingDim = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), subgroupMmaStoreMatrixOp.getLeadDimensionAttr()); rewriter.replaceOpWithNewOp( @@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering auto unpackOp = [&](Value operand) { auto structType = cast(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { - Value toUse = rewriter.create(loc, operand, i); + Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); unpackedOps.push_back(toUse); } }; @@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering cast(subgroupMmaConstantOp.getType())); // If the element type is a vector create a vector from the operand. 
if (auto vecType = dyn_cast(type.getBody()[0])) { - Value vecCst = rewriter.create(loc, vecType); + Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { - Value idx = rewriter.create( + Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), vecEl); - vecCst = rewriter.create(loc, vecType, vecCst, + vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst, cst, idx); } cst = vecCst; } - Value matrixStruct = rewriter.create(loc, type); + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { matrixStruct = - rewriter.create(loc, matrixStruct, cst, i); + LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct); return success(); @@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Type i1Type = builder.getI1Type(); if (auto vecType = dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); - Value cmp = builder.create( + Value cmp = LLVM::FCmpOp::create(builder, loc, i1Type, isMin ? 
LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs); - Value sel = builder.create(loc, cmp, lhs, rhs); - Value isNan = builder.create( + Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs); + Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); - Value nan = builder.create( + Value nan = LLVM::ConstantOp::create(builder, loc, lhs.getType(), builder.getFloatAttr(floatType, APFloat::getQNaN(floatType.getFloatSemantics()))); - return builder.create(loc, isNan, nan, sel); + return LLVM::SelectOp::create(builder, loc, isNan, nan, sel); } static Value createScalarOp(OpBuilder &builder, Location loc, @@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc, ArrayRef operands) { switch (op) { case gpu::MMAElementwiseOp::ADDF: - return builder.create(loc, operands[0].getType(), operands); + return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MULF: - return builder.create(loc, operands[0].getType(), operands); + return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::DIVF: - return builder.create(loc, operands[0].getType(), operands); + return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MAXF: return createMinMaxF(builder, loc, operands[0], operands[1], /*isMin=*/false); @@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering size_t numOperands = adaptor.getOperands().size(); LLVM::LLVMStructType destType = convertMMAToLLVMType( cast(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = rewriter.create(loc, destType); + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { SmallVector extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { - extractedOperands.push_back(rewriter.create( + 
extractedOperands.push_back(LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getOperands()[opIdx], i)); } Value element = createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(), extractedOperands); matrixStruct = - rewriter.create(loc, matrixStruct, element, i); + LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i); } rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct); return success(); diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 456bfaba980ca..9b43d3ef602d1 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth()); // TODO: use <=> in C++20. if (indexBitwidth > intWidth) { - return rewriter.create(loc, indexBitwidthType, value); + return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value); } if (indexBitwidth < intWidth) { - return rewriter.create(loc, indexBitwidthType, value); + return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value); } return value; } @@ -82,11 +82,11 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, const unsigned indexBitwidth) { auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value zero = rewriter.create(loc, 0, 32); - Value minus1 = rewriter.create(loc, -1, 32); - Value mbcntLo = rewriter.create(loc, int32Type, + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); + Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, ValueRange{minus1, zero}); - Value laneId = rewriter.create(loc, int32Type, + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, 
int32Type, ValueRange{minus1, mbcntLo}); return laneId; } @@ -110,20 +110,20 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) Type intTy = IntegerType::get(context, 32); - Value zero = rewriter.create(loc, 0, 32); - Value minus1 = rewriter.create(loc, -1, 32); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); Value mbcntLo = - rewriter.create(loc, intTy, ValueRange{minus1, zero}); - Value laneId = rewriter.create( + ROCDL::MbcntLoOp::create(rewriter, loc, intTy, ValueRange{minus1, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy, ValueRange{minus1, mbcntLo}); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - laneId = rewriter.create( + laneId = LLVM::SExtOp::create(rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } else if (indexBitwidth < 32) { - laneId = rewriter.create( + laneId = LLVM::TruncOp::create(rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } rewriter.replaceOp(op, {laneId}); @@ -149,7 +149,7 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern { /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 
64 : 32, /*upper=*/op.getUpperBoundAttr().getInt() + 1); } - Value wavefrontOp = rewriter.create( + Value wavefrontOp = ROCDL::WavefrontSizeOp::create(rewriter, op.getLoc(), rewriter.getI32Type(), bounds); wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp, *getTypeConverter()); @@ -190,43 +190,43 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value width = adaptor.getWidth(); - Value zero = rewriter.create(loc, int32Type, 0); - Value negwidth = rewriter.create(loc, int32Type, zero, width); - Value add = rewriter.create(loc, int32Type, srcLaneId, width); + Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0); + Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width); + Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width); Value widthOrZeroIfOutside = - rewriter.create(loc, int32Type, add, negwidth); + LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth); Value dstLane; switch (op.getMode()) { case gpu::ShuffleMode::UP: - dstLane = rewriter.create(loc, int32Type, srcLaneId, + dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId, adaptor.getOffset()); break; case gpu::ShuffleMode::DOWN: - dstLane = rewriter.create(loc, int32Type, srcLaneId, + dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, adaptor.getOffset()); break; case gpu::ShuffleMode::XOR: - dstLane = rewriter.create(loc, int32Type, srcLaneId, + dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId, adaptor.getOffset()); break; case gpu::ShuffleMode::IDX: dstLane = adaptor.getOffset(); break; } - Value isActiveSrcLane = rewriter.create( + Value isActiveSrcLane = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside); - Value selectDstLane = rewriter.create(loc, isActiveSrcLane, + Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane, dstLane, 
srcLaneId); - Value two = rewriter.create(loc, int32Type, 2); + Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2); Value dwordAlignedDstLane = - rewriter.create(loc, int32Type, selectDstLane, two); + LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two); SmallVector decomposed = LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type); SmallVector swizzled; for (Value v : decomposed) { - Value res = rewriter.create(loc, int32Type, + Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, dwordAlignedDstLane, v); swizzled.emplace_back(res); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index b99ed261ecfa3..71618cbf8a2f2 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion::matchAndRewrite( Value vector = spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter); - Value dim = rewriter.create( + Value dim = spirv::CompositeExtractOp::create(rewriter, op.getLoc(), builtinType, vector, rewriter.getI32ArrayAttr({static_cast(op.getDimension())})); if (forShader && builtinType != indexType) - dim = rewriter.create(op.getLoc(), indexType, dim); + dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim); rewriter.replaceOp(op, dim); return success(); } @@ -198,7 +198,7 @@ SingleDimLaunchConfigConversion::matchAndRewrite( Value builtinValue = spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); if (i32Type != indexType) - builtinValue = rewriter.create(op.getLoc(), indexType, + builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, builtinValue); rewriter.replaceOp(op, builtinValue); return success(); @@ -257,7 +257,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, signatureConverter.addInputs(argType.index(), convertedType); } } - auto newFuncOp = rewriter.create( + 
auto newFuncOp = spirv::FuncOp::create(rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {})); for (const auto &namedAttr : funcOp->getAttrs()) { @@ -367,7 +367,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite( // Add a keyword to the module name to avoid symbolic conflict. std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str(); - auto spvModule = rewriter.create( + auto spvModule = spirv::ModuleOp::create(rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, StringRef(spvModuleName)); @@ -452,41 +452,41 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( switch (shuffleOp.getMode()) { case gpu::ShuffleMode::XOR: { - result = rewriter.create( + result = spirv::GroupNonUniformShuffleXorOp::create(rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::IDX: { - result = rewriter.create( + result = spirv::GroupNonUniformShuffleOp::create(rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::DOWN: { - result = rewriter.create( + result = spirv::GroupNonUniformShuffleDownOp::create(rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create(loc, laneId, adaptor.getOffset()); - validVal = rewriter.create(loc, arith::CmpIPredicate::ult, + arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset()); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, resultLaneId, adaptor.getWidth()); break; } case gpu::ShuffleMode::UP: { - result = rewriter.create( + result = spirv::GroupNonUniformShuffleUpOp::create(rewriter, loc, scope, 
adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create(loc, laneId, adaptor.getOffset()); + arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset()); auto i32Type = rewriter.getIntegerType(32); - validVal = rewriter.create( + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, resultLaneId, - rewriter.create( + arith::ConstantOp::create(rewriter, loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); break; } @@ -516,14 +516,14 @@ LogicalResult GPURotateConversion::matchAndRewrite( Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr(spirv::Scope::Subgroup); - Value rotateResult = rewriter.create( + Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth()); Value validVal; if (widthAttr.getValue().getZExtValue() == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { - Value laneId = rewriter.create(loc, widthAttr); - validVal = rewriter.create(loc, arith::CmpIPredicate::ult, + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, laneId, adaptor.getWidth()); } @@ -548,13 +548,13 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, ? 
spirv::GroupOperation::ClusteredReduce : spirv::GroupOperation::Reduce); if (isUniform) { - return builder.create(loc, type, scope, groupOp, arg) + return UniformOp::create(builder, loc, type, scope, groupOp, arg) .getResult(); } Value clusterSizeValue; if (clusterSize.has_value()) - clusterSizeValue = builder.create( + clusterSizeValue = spirv::ConstantOp::create(builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); @@ -740,7 +740,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstName = makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc"); - return rewriter.create( + return spirv::SpecConstantOp::create(rewriter, loc, rewriter.getStringAttr(specCstName), attr); }; { @@ -774,7 +774,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstCompositeName = (llvm::Twine(globalVarName) + "_scc").str(); - specCstComposite = rewriter.create( + specCstComposite = spirv::SpecConstantCompositeOp::create(rewriter, loc, TypeAttr::get(globalType), rewriter.getStringAttr(specCstCompositeName), rewriter.getArrayAttr(constituents)); @@ -785,15 +785,15 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( // Define a GlobalVarOp initialized using specialized constants // that is used to specify the printf format string // to be passed to the SPIRV CLPrintfOp. - globalVar = rewriter.create( + globalVar = spirv::GlobalVariableOp::create(rewriter, loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite)); globalVar->setAttr("Constant", rewriter.getUnitAttr()); } // Get SSA value of Global variable and create pointer to i8 to point to // the format string. 
- Value globalPtr = rewriter.create(loc, globalVar); - Value fmtStr = rewriter.create( + Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar); + Value fmtStr = spirv::BitcastOp::create(rewriter, loc, spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant), globalPtr); @@ -801,7 +801,7 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( // Get printf arguments. auto printfArgs = llvm::to_vector_of(adaptor.getArgs()); - rewriter.create(loc, i32Type, fmtStr, printfArgs); + spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs); // Need to erase the gpu.printf op as gpu.printf does not use result vs // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index 0b2c06a08db2d..da8ffe47a7d33 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -144,11 +144,11 @@ void GPUToSPIRVPass::runOnOperation() { if (targetEnvSupportsKernelCapability(moduleOp)) { moduleOp.walk([&](gpu::GPUFuncOp funcOp) { builder.setInsertionPoint(funcOp); - auto newFuncOp = builder.create( + auto newFuncOp = func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); auto entryBlock = newFuncOp.addEntryBlock(); builder.setInsertionPointToEnd(entryBlock); - builder.create(funcOp.getLoc()); + func::ReturnOp::create(builder, funcOp.getLoc()); newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr()); funcOp.erase(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index 7bb86b5ce1ddd..70646aa5f2e74 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -283,7 +283,7 @@ struct WmmaLoadOpToSPIRVLowering final int64_t stride = op.getLeadDimension().getSExtValue(); 
IntegerType i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create( + auto strideValue = spirv::ConstantOp::create(rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); @@ -315,7 +315,7 @@ struct WmmaStoreOpToSPIRVLowering final int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create( + auto strideValue = spirv::ConstantOp::create(rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp index 0473bb59fa6aa..0d5b318f739cc 100644 --- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp +++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp @@ -36,34 +36,34 @@ struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create(loc, n.getType(), 0); - Value posOne = rewriter.create(loc, n.getType(), 1); - Value negOne = rewriter.create(loc, n.getType(), -1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); + Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1); // Compute `x`. Value mPos = - rewriter.create(loc, LLVM::ICmpPredicate::sgt, m, zero); - Value x = rewriter.create(loc, mPos, negOne, posOne); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, m, zero); + Value x = LLVM::SelectOp::create(rewriter, loc, mPos, negOne, posOne); // Compute the positive result. 
- Value nPlusX = rewriter.create(loc, n, x); - Value nPlusXDivM = rewriter.create(loc, nPlusX, m); - Value posRes = rewriter.create(loc, nPlusXDivM, posOne); + Value nPlusX = LLVM::AddOp::create(rewriter, loc, n, x); + Value nPlusXDivM = LLVM::SDivOp::create(rewriter, loc, nPlusX, m); + Value posRes = LLVM::AddOp::create(rewriter, loc, nPlusXDivM, posOne); // Compute the negative result. - Value negN = rewriter.create(loc, zero, n); - Value negNDivM = rewriter.create(loc, negN, m); - Value negRes = rewriter.create(loc, zero, negNDivM); + Value negN = LLVM::SubOp::create(rewriter, loc, zero, n); + Value negNDivM = LLVM::SDivOp::create(rewriter, loc, negN, m); + Value negRes = LLVM::SubOp::create(rewriter, loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. Value nPos = - rewriter.create(loc, LLVM::ICmpPredicate::sgt, n, zero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, n, zero); Value sameSign = - rewriter.create(loc, LLVM::ICmpPredicate::eq, nPos, mPos); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, nPos, mPos); Value nNonZero = - rewriter.create(loc, LLVM::ICmpPredicate::ne, n, zero); - Value cmp = rewriter.create(loc, sameSign, nNonZero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero); + Value cmp = LLVM::AndOp::create(rewriter, loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } @@ -83,17 +83,17 @@ struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create(loc, n.getType(), 0); - Value one = rewriter.create(loc, n.getType(), 1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value one = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); // Compute the non-zero result. 
- Value minusOne = rewriter.create(loc, n, one); - Value quotient = rewriter.create(loc, minusOne, m); - Value plusOne = rewriter.create(loc, quotient, one); + Value minusOne = LLVM::SubOp::create(rewriter, loc, n, one); + Value quotient = LLVM::UDivOp::create(rewriter, loc, minusOne, m); + Value plusOne = LLVM::AddOp::create(rewriter, loc, quotient, one); // Pick the result. Value cmp = - rewriter.create(loc, LLVM::ICmpPredicate::eq, n, zero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, n, zero); rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); return success(); } @@ -114,32 +114,32 @@ struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create(loc, n.getType(), 0); - Value posOne = rewriter.create(loc, n.getType(), 1); - Value negOne = rewriter.create(loc, n.getType(), -1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); + Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1); // Compute `x`. Value mNeg = - rewriter.create(loc, LLVM::ICmpPredicate::slt, m, zero); - Value x = rewriter.create(loc, mNeg, posOne, negOne); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, m, zero); + Value x = LLVM::SelectOp::create(rewriter, loc, mNeg, posOne, negOne); // Compute the negative result. - Value xMinusN = rewriter.create(loc, x, n); - Value xMinusNDivM = rewriter.create(loc, xMinusN, m); - Value negRes = rewriter.create(loc, negOne, xMinusNDivM); + Value xMinusN = LLVM::SubOp::create(rewriter, loc, x, n); + Value xMinusNDivM = LLVM::SDivOp::create(rewriter, loc, xMinusN, m); + Value negRes = LLVM::SubOp::create(rewriter, loc, negOne, xMinusNDivM); // Compute the positive result. 
- Value posRes = rewriter.create(loc, n, m); + Value posRes = LLVM::SDivOp::create(rewriter, loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. Value nNeg = - rewriter.create(loc, LLVM::ICmpPredicate::slt, n, zero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, n, zero); Value diffSign = - rewriter.create(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); Value nNonZero = - rewriter.create(loc, LLVM::ICmpPredicate::ne, n, zero); - Value cmp = rewriter.create(loc, diffSign, nNonZero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero); + Value cmp = LLVM::AndOp::create(rewriter, loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, negRes, posRes); return success(); } diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp index 4821962f989e6..6fc097059c5e5 100644 --- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -111,33 +111,33 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create( + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, 0)); - Value posOne = rewriter.create( + Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, 1)); - Value negOne = rewriter.create( + Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, -1)); // Compute `x`. - Value mPos = rewriter.create(loc, m, zero); - Value x = rewriter.create(loc, mPos, negOne, posOne); + Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero); + Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne); // Compute the positive result. 
- Value nPlusX = rewriter.create(loc, n, x); - Value nPlusXDivM = rewriter.create(loc, nPlusX, m); - Value posRes = rewriter.create(loc, nPlusXDivM, posOne); + Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x); + Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m); + Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne); // Compute the negative result. - Value negN = rewriter.create(loc, zero, n); - Value negNDivM = rewriter.create(loc, negN, m); - Value negRes = rewriter.create(loc, zero, negNDivM); + Value negN = spirv::ISubOp::create(rewriter, loc, zero, n); + Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m); + Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. - Value nPos = rewriter.create(loc, n, zero); - Value sameSign = rewriter.create(loc, nPos, mPos); - Value nNonZero = rewriter.create(loc, n, zero); - Value cmp = rewriter.create(loc, sameSign, nNonZero); + Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero); + Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos); + Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero); + Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } @@ -161,18 +161,18 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create( + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, 0)); - Value one = rewriter.create(loc, n_type, + Value one = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, 1)); // Compute the non-zero result. 
- Value minusOne = rewriter.create(loc, n, one); - Value quotient = rewriter.create(loc, minusOne, m); - Value plusOne = rewriter.create(loc, quotient, one); + Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one); + Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m); + Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one); // Pick the result - Value cmp = rewriter.create(loc, n, zero); + Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero); rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); return success(); } @@ -197,32 +197,32 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create( + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, 0)); - Value posOne = rewriter.create( + Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, 1)); - Value negOne = rewriter.create( + Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, IntegerAttr::get(n_type, -1)); // Compute `x`. - Value mNeg = rewriter.create(loc, m, zero); - Value x = rewriter.create(loc, mNeg, posOne, negOne); + Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero); + Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne); // Compute the negative result - Value xMinusN = rewriter.create(loc, x, n); - Value xMinusNDivM = rewriter.create(loc, xMinusN, m); - Value negRes = rewriter.create(loc, negOne, xMinusNDivM); + Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n); + Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m); + Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM); // Compute the positive result. - Value posRes = rewriter.create(loc, n, m); + Value posRes = spirv::SDivOp::create(rewriter, loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. 
`(n < 0) != (m < 0) && n != 0`. - Value nNeg = rewriter.create(loc, n, zero); - Value diffSign = rewriter.create(loc, nNeg, mNeg); - Value nNonZero = rewriter.create(loc, n, zero); + Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero); + Value diffSign = spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg); + Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero); - Value cmp = rewriter.create(loc, diffSign, nNonZero); + Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index e34d5f74d232f..e8ad0d4644389 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -32,7 +32,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor) MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = builder.create(loc, descriptorType); + Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType); return MemRefDescriptor(descriptor); } @@ -99,20 +99,20 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, // integer attribute. static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { - return builder.create(loc, resultType, + return LLVM::ConstantOp::create(builder, loc, resultType, builder.getIndexAttr(value)); } /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { - return builder.create(loc, value, + return LLVM::ExtractValueOp::create(builder, loc, value, kOffsetPosInMemRefDescriptor); } /// Builds IR inserting the offset into the descriptor. 
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { - value = builder.create(loc, value, offset, + value = LLVM::InsertValueOp::create(builder, loc, value, offset, kOffsetPosInMemRefDescriptor); } @@ -125,7 +125,7 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, /// Builds IR extracting the pos-th size from the descriptor. Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( + return LLVM::ExtractValueOp::create(builder, loc, value, ArrayRef({kSizePosInMemRefDescriptor, pos})); } @@ -137,22 +137,22 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, // Copy size values to stack-allocated memory. auto one = createIndexAttrConstant(builder, loc, indexType, 1); - auto sizes = builder.create( + auto sizes = LLVM::ExtractValueOp::create(builder, loc, value, llvm::ArrayRef({kSizePosInMemRefDescriptor})); - auto sizesPtr = builder.create(loc, ptrTy, arrayTy, one, + auto sizesPtr = LLVM::AllocaOp::create(builder, loc, ptrTy, arrayTy, one, /*alignment=*/0); - builder.create(loc, sizes, sizesPtr); + LLVM::StoreOp::create(builder, loc, sizes, sizesPtr); // Load an return size value of interest. - auto resultPtr = builder.create(loc, ptrTy, arrayTy, sizesPtr, + auto resultPtr = LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, sizesPtr, ArrayRef{0, pos}); - return builder.create(loc, indexType, resultPtr); + return LLVM::LoadOp::create(builder, loc, indexType, resultPtr); } /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { - value = builder.create( + value = LLVM::InsertValueOp::create(builder, loc, value, size, ArrayRef({kSizePosInMemRefDescriptor, pos})); } @@ -164,14 +164,14 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, /// Builds IR extracting the pos-th stride from the descriptor. 
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( + return LLVM::ExtractValueOp::create(builder, loc, value, ArrayRef({kStridePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride) { - value = builder.create( + value = LLVM::InsertValueOp::create(builder, loc, value, stride, ArrayRef({kStridePosInMemRefDescriptor, pos})); } @@ -207,7 +207,7 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, ? offset(builder, loc) : createIndexAttrConstant(builder, loc, indexType, offsetCst); Type elementType = converter.convertType(type.getElementType()); - ptr = builder.create(loc, ptr.getType(), elementType, ptr, + ptr = LLVM::GEPOp::create(builder, loc, ptr.getType(), elementType, ptr, offsetVal); return ptr; } @@ -303,7 +303,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = builder.create(loc, descriptorType); + Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { @@ -380,18 +380,18 @@ void UnrankedMemRefDescriptor::computeSizes( builder, loc, indexType, llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); Value doublePointerSize = - builder.create(loc, indexType, two, pointerSize); + LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); // (1 + 2 * rank) * sizeof(index) Value rank = desc.rank(builder, loc); - Value doubleRank = builder.create(loc, indexType, two, rank); + Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); Value doubleRankIncremented = - builder.create(loc, indexType, doubleRank, one); - Value 
rankIndexSize = builder.create( + LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); + Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, doubleRankIncremented, indexSize); // Total allocation size. - Value allocationSize = builder.create( + Value allocationSize = LLVM::AddOp::create(builder, loc, indexType, doublePointerSize, rankIndexSize); sizes.push_back(allocationSize); } @@ -400,13 +400,13 @@ void UnrankedMemRefDescriptor::computeSizes( Value UnrankedMemRefDescriptor::allocatedPtr( OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { - return builder.create(loc, elemPtrType, memRefDescPtr); + return LLVM::LoadOp::create(builder, loc, elemPtrType, memRefDescPtr); } void UnrankedMemRefDescriptor::setAllocatedPtr( OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { - builder.create(loc, allocatedPtr, memRefDescPtr); + LLVM::StoreOp::create(builder, loc, allocatedPtr, memRefDescPtr); } static std::pair @@ -423,9 +423,9 @@ Value UnrankedMemRefDescriptor::alignedPtr( castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); Value alignedGep = - builder.create(loc, elemPtrPtrType, elemPtrType, + LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, elementPtrPtr, ArrayRef{1}); - return builder.create(loc, elemPtrType, alignedGep); + return LLVM::LoadOp::create(builder, loc, elemPtrType, alignedGep); } void UnrankedMemRefDescriptor::setAlignedPtr( @@ -435,9 +435,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr( castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); Value alignedGep = - builder.create(loc, elemPtrPtrType, elemPtrType, + LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, elementPtrPtr, ArrayRef{1}); - builder.create(loc, alignedPtr, alignedGep); + LLVM::StoreOp::create(builder, loc, alignedPtr, alignedGep); } Value UnrankedMemRefDescriptor::offsetBasePtr( @@ -446,7 +446,7 @@ Value 
UnrankedMemRefDescriptor::offsetBasePtr( auto [elementPtrPtr, elemPtrPtrType] = castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); - return builder.create(loc, elemPtrPtrType, elemPtrType, + return LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, elementPtrPtr, ArrayRef{2}); } @@ -456,7 +456,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, LLVM::LLVMPointerType elemPtrType) { Value offsetPtr = offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); - return builder.create(loc, typeConverter.getIndexType(), + return LLVM::LoadOp::create(builder, loc, typeConverter.getIndexType(), offsetPtr); } @@ -467,7 +467,7 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { Value offsetPtr = offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); - builder.create(loc, offset, offsetPtr); + LLVM::StoreOp::create(builder, loc, offset, offsetPtr); } Value UnrankedMemRefDescriptor::sizeBasePtr( @@ -477,7 +477,7 @@ Value UnrankedMemRefDescriptor::sizeBasePtr( Type structTy = LLVM::LLVMStructType::getLiteral( indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); auto resultType = LLVM::LLVMPointerType::get(builder.getContext()); - return builder.create(loc, resultType, structTy, memRefDescPtr, + return LLVM::GEPOp::create(builder, loc, resultType, structTy, memRefDescPtr, ArrayRef{0, 3}); } @@ -489,8 +489,8 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value sizeStoreGep = - builder.create(loc, ptrType, indexTy, sizeBasePtr, index); - return builder.create(loc, indexTy, sizeStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index); + return LLVM::LoadOp::create(builder, loc, indexTy, sizeStoreGep); } void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, @@ -501,8 +501,8 @@ void 
UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value sizeStoreGep = - builder.create(loc, ptrType, indexTy, sizeBasePtr, index); - builder.create(loc, size, sizeStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index); + LLVM::StoreOp::create(builder, loc, size, sizeStoreGep); } Value UnrankedMemRefDescriptor::strideBasePtr( @@ -511,7 +511,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr( Type indexTy = typeConverter.getIndexType(); auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - return builder.create(loc, ptrType, indexTy, sizeBasePtr, rank); + return LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, rank); } Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, @@ -522,8 +522,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value strideStoreGep = - builder.create(loc, ptrType, indexTy, strideBasePtr, index); - return builder.create(loc, indexTy, strideStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index); + return LLVM::LoadOp::create(builder, loc, indexTy, strideStoreGep); } void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, @@ -534,6 +534,6 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value strideStoreGep = - builder.create(loc, ptrType, indexTy, strideBasePtr, index); - builder.create(loc, stride, strideStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index); + LLVM::StoreOp::create(builder, loc, stride, strideStoreGep); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index c5f72f7e10b8c..57aed79874ad0 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ 
b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -57,7 +57,7 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { - return builder.create(loc, resultType, + return LLVM::ConstantOp::create(builder, loc, resultType, builder.getIndexAttr(value)); } @@ -123,7 +123,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( runningStride = sizes[i]; else if (stride == ShapedType::kDynamic) runningStride = - rewriter.create(loc, runningStride, sizes[i]); + LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]); else runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride); } @@ -131,10 +131,10 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( // Buffer size in bytes. Type elementType = typeConverter->convertType(memRefType.getElementType()); auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType, elementType, nullPtr, runningStride); - size = rewriter.create(loc, getIndexType(), gepPtr); + size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); } else { size = runningStride; } @@ -149,10 +149,10 @@ Value ConvertToLLVMPattern::getSizeInBytes( // which is a common pattern of getting the size of a type in bytes. 
Type llvmType = typeConverter->convertType(type); auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto nullPtr = rewriter.create(loc, convertedPtrType); - auto gep = rewriter.create(loc, convertedPtrType, llvmType, + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType); + auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType, nullPtr, ArrayRef{1}); - return rewriter.create(loc, getIndexType(), gep); + return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep); } Value ConvertToLLVMPattern::getNumElements( @@ -175,7 +175,7 @@ Value ConvertToLLVMPattern::getNumElements( staticSize == ShapedType::kDynamic ? dynamicSizes[dynamicIndex++] : createIndexAttrConstant(rewriter, loc, indexType, staticSize); - numElements = rewriter.create(loc, numElements, size); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, size); } else { numElements = staticSize == ShapedType::kDynamic @@ -276,14 +276,14 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( ? builder .create(loc, mallocFunc.value(), allocationSize) .getResult() - : builder.create(loc, getPtrType(), + : LLVM::AllocaOp::create(builder, loc, getPtrType(), IntegerType::get(getContext(), 8), allocationSize, /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); - builder.create(loc, memory, source, allocationSize, false); + LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) - builder.create(loc, freeFunc.value(), source); + LLVM::CallOp::create(builder, loc, freeFunc.value(), source); // Create a new descriptor. 
The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks @@ -349,7 +349,7 @@ LogicalResult LLVM::detail::oneToOneRewrite( SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( + results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(), newOp->getResult(0), i)); } rewriter.replaceOp(op, results); @@ -371,7 +371,7 @@ LogicalResult LLVM::detail::intrinsicRewrite( if (numResults != 0) resType = typeConverter.packOperationResults(op->getResultTypes()); - auto callIntrOp = rewriter.create( + auto callIntrOp = LLVM::CallIntrinsicOp::create(rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands); // Propagate attributes. callIntrOp->setAttrs(op->getAttrDictionary()); @@ -388,7 +388,7 @@ LogicalResult LLVM::detail::intrinsicRewrite( results.reserve(numResults); Value intrRes = callIntrOp.getResults(); for (unsigned i = 0; i < numResults; ++i) - results.push_back(rewriter.create(loc, intrRes, i)); + results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i)); rewriter.replaceOp(op, results); return success(); @@ -406,7 +406,7 @@ static unsigned getBitWidth(Type type) { static Value createI32Constant(OpBuilder &builder, Location loc, int32_t value) { Type i32 = builder.getI32Type(); - return builder.create(loc, i32, value); + return LLVM::ConstantOp::create(builder, loc, i32, value); } SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, @@ -418,17 +418,17 @@ SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, unsigned srcBitWidth = getBitWidth(srcType); unsigned dstBitWidth = getBitWidth(dstType); if (srcBitWidth == dstBitWidth) { - Value cast = builder.create(loc, dstType, src); + Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src); return {cast}; } if (dstBitWidth > srcBitWidth) { auto smallerInt = builder.getIntegerType(srcBitWidth); if (srcType != 
smallerInt) - src = builder.create(loc, smallerInt, src); + src = LLVM::BitcastOp::create(builder, loc, smallerInt, src); auto largerInt = builder.getIntegerType(dstBitWidth); - Value res = builder.create(loc, largerInt, src); + Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src); return {res}; } assert(srcBitWidth % dstBitWidth == 0 && @@ -436,12 +436,12 @@ SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, int64_t numElements = srcBitWidth / dstBitWidth; auto vecType = VectorType::get(numElements, dstType); - src = builder.create(loc, vecType, src); + src = LLVM::BitcastOp::create(builder, loc, vecType, src); SmallVector res; for (auto i : llvm::seq(numElements)) { Value idx = createI32Constant(builder, loc, i); - Value elem = builder.create(loc, src, idx); + Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx); res.emplace_back(elem); } @@ -461,28 +461,28 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, if (dstBitWidth < srcBitWidth) { auto largerInt = builder.getIntegerType(srcBitWidth); if (res.getType() != largerInt) - res = builder.create(loc, largerInt, res); + res = LLVM::BitcastOp::create(builder, loc, largerInt, res); auto smallerInt = builder.getIntegerType(dstBitWidth); - res = builder.create(loc, smallerInt, res); + res = LLVM::TruncOp::create(builder, loc, smallerInt, res); } if (res.getType() != dstType) - res = builder.create(loc, dstType, res); + res = LLVM::BitcastOp::create(builder, loc, dstType, res); return res; } int64_t numElements = src.size(); auto srcType = VectorType::get(numElements, src.front().getType()); - Value res = builder.create(loc, srcType); + Value res = LLVM::PoisonOp::create(builder, loc, srcType); for (auto &&[i, elem] : llvm::enumerate(src)) { Value idx = createI32Constant(builder, loc, i); - res = builder.create(loc, srcType, res, elem, idx); + res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx); } if (res.getType() 
!= dstType) - res = builder.create(loc, dstType, res); + res = LLVM::BitcastOp::create(builder, loc, dstType, res); return res; } @@ -518,18 +518,18 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, Value stride = ShapedType::isDynamic(strides[i]) ? memRefDescriptor.stride(builder, loc, i) - : builder.create( + : LLVM::ConstantOp::create(builder, loc, indexType, builder.getIndexAttr(strides[i])); increment = - builder.create(loc, increment, stride, intOverflowFlags); + LLVM::MulOp::create(builder, loc, increment, stride, intOverflowFlags); } - index = index ? builder.create(loc, index, increment, + index = index ? LLVM::AddOp::create(builder, loc, index, increment, intOverflowFlags) : increment; } Type elementPtrType = memRefDescriptor.getElementPtrType(); - return index ? builder.create( + return index ? LLVM::GEPOp::create(builder, loc, elementPtrType, converter.convertType(type.getElementType()), base, index, noWrapFlags) diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 49c73fbc9dd79..091fb7e4c2afb 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -66,7 +66,7 @@ LogicalResult mlir::LLVM::createPrintStrCall( DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); auto arrayTy = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); - auto globalOp = builder.create( + auto globalOp = LLVM::GlobalOp::create(builder, loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr); @@ -74,15 +74,15 @@ LogicalResult mlir::LLVM::createPrintStrCall( // Emit call to `printStr` in runtime library. 
builder.restoreInsertionPoint(ip); auto msgAddr = - builder.create(loc, ptrTy, globalOp.getName()); + LLVM::AddressOfOp::create(builder, loc, ptrTy, globalOp.getName()); SmallVector indices(1, 0); Value gep = - builder.create(loc, ptrTy, arrayTy, msgAddr, indices); + LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, msgAddr, indices); FailureOr printer = LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName); if (failed(printer)) return failure(); - builder.create(loc, TypeRange(), + LLVM::CallOp::create(builder, loc, TypeRange(), SymbolRefAttr::get(printer.value()), gep); return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp index 1cd0bd85f9894..13ed4628c3c9e 100644 --- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp @@ -24,10 +24,10 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) { Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) const { - return builder.create(loc, value, pos); + return LLVM::ExtractValueOp::create(builder, loc, value, pos); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr) { - value = builder.create(loc, value, ptr, pos); + value = LLVM::InsertValueOp::create(builder, loc, value, ptr, pos); } diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 7312594c761f7..77135188dabcf 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -91,7 +91,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder, packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter); if (!packed) return Value(); - return builder.create(loc, resultType, packed) + return UnrealizedConversionCastOp::create(builder, loc, resultType, packed) .getResult(0); } @@ -107,7 +107,7 @@ static 
Value rankedMemRefMaterialization(OpBuilder &builder, packRankedMemRefDesc(builder, resultType, inputs, loc, converter); if (!packed) return Value(); - return builder.create(loc, resultType, packed) + return UnrealizedConversionCastOp::create(builder, loc, resultType, packed) .getResult(0); } @@ -224,12 +224,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, // non-LLVM types persist after an LLVM conversion. addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); @@ -731,12 +731,12 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - Value one = builder.create(loc, builder.getI64Type(), + Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), builder.getIndexAttr(1)); Value allocated = - builder.create(loc, ptrType, operand.getType(), one); + LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one); // Store into the alloca'ed descriptor. 
- builder.create(loc, operand, allocated); + LLVM::StoreOp::create(builder, loc, operand, allocated); return allocated; } diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index bf3f31729c3da..4fc208e05eb9d 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -87,17 +87,17 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; auto loc = op->getLoc(); - Value desc = rewriter.create(loc, resultNDVectoryTy); + Value desc = LLVM::PoisonOp::create(rewriter, loc, resultNDVectoryTy); nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; for (const auto &operand : llvm::enumerate(operands)) { - extractedOperands.push_back(rewriter.create( + extractedOperands.push_back(LLVM::ExtractValueOp::create(rewriter, loc, operand.value(), position)); } Value newVal = createOperand(result1DVectorTy, extractedOperands); - desc = rewriter.create(loc, desc, newVal, position); + desc = LLVM::InsertValueOp::create(rewriter, loc, desc, newVal, position); }); rewriter.replaceOp(op, desc); return success(); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index c3f213147b7a7..498a5cfc40204 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -78,7 +78,7 @@ getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) { // Insert before module terminator. 
rewriter.setInsertionPoint(module.getBody(), std::prev(module.getBody()->end())); - func::FuncOp funcOp = rewriter.create( + func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(), fnNameAttr.getValue(), libFnType); // Insert a function attribute that will trigger the emission of the // corresponding `_mlir_ciface_xxx` interface so that external libraries see @@ -101,7 +101,7 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, continue; } Value cast = - b.create(loc, makeStridedLayoutDynamic(memrefType), op); + memref::CastOp::create(b, loc, makeStridedLayoutDynamic(memrefType), op); res.push_back(cast); } return res; diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index d4deff5b88070..30521d66a2e4c 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -54,18 +54,18 @@ std::pair getRawPtrAndSize(const Location loc, Value memRef, Type elType) { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); Value dataPtr = - rewriter.create(loc, ptrType, memRef, 1); - Value offset = rewriter.create( + LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1); + Value offset = LLVM::ExtractValueOp::create(rewriter, loc, rewriter.getI64Type(), memRef, 2); Value resPtr = - rewriter.create(loc, ptrType, elType, dataPtr, offset); + LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset); Value size; if (cast(memRef.getType()).getBody().size() > 3) { - size = rewriter.create(loc, memRef, + size = LLVM::ExtractValueOp::create(rewriter, loc, memRef, ArrayRef{3, 0}); - size = rewriter.create(loc, rewriter.getI32Type(), size); + size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size); } else { - size = rewriter.create(loc, 1, 32); + size = arith::ConstantIntOp::create(rewriter, loc, 1, 32); } return {resPtr, size}; } @@ -157,13 +157,13 @@ class MPICHImplTraits : public MPIImplTraits { Value 
getCommWorld(const Location loc, ConversionPatternRewriter &rewriter) override { static constexpr int MPI_COMM_WORLD = 0x44000000; - return rewriter.create(loc, rewriter.getI64Type(), + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), MPI_COMM_WORLD); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create(loc, rewriter.getI32Type(), comm); + return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm); } intptr_t getStatusIgnore() override { return 1; } @@ -195,7 +195,7 @@ class MPICHImplTraits : public MPIImplTraits { mtype = MPI_UINT8_T; else assert(false && "unsupported type"); - return rewriter.create(loc, rewriter.getI32Type(), mtype); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), mtype); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -245,7 +245,7 @@ class MPICHImplTraits : public MPIImplTraits { op = MPI_REPLACE; break; } - return rewriter.create(loc, rewriter.getI32Type(), op); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op); } }; @@ -281,15 +281,15 @@ class OMPIImplTraits : public MPIImplTraits { getOrDefineExternalStruct(loc, rewriter, name, commStructT); // get address of symbol - auto comm = rewriter.create( + auto comm = LLVM::AddressOfOp::create(rewriter, loc, LLVM::LLVMPointerType::get(context), SymbolRefAttr::get(context, name)); - return rewriter.create(loc, rewriter.getI64Type(), comm); + return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create( + return LLVM::IntToPtrOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); } @@ -330,7 +330,7 @@ class OMPIImplTraits : public MPIImplTraits { // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT); // get 
address of symbol - return rewriter.create( + return LLVM::AddressOfOp::create(rewriter, loc, LLVM::LLVMPointerType::get(context), SymbolRefAttr::get(context, mtype)); } @@ -389,7 +389,7 @@ class OMPIImplTraits : public MPIImplTraits { // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, op, opStructT); // get address of symbol - return rewriter.create( + return LLVM::AddressOfOp::create(rewriter, loc, LLVM::LLVMPointerType::get(context), SymbolRefAttr::get(context, op)); } @@ -424,7 +424,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` - auto nullPtrOp = rewriter.create(loc, ptrType); + auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType); Value llvmnull = nullPtrOp.getRes(); // grab a reference to the global module op: @@ -513,9 +513,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern { // get communicator Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - auto one = rewriter.create(loc, i32, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); auto outPtr = - rewriter.create(loc, ptrType, comm.getType(), one); + LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one); // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) auto funcType = @@ -524,14 +524,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern { LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Comm_split", funcType); - auto callOp = rewriter.create( + auto callOp = LLVM::CallOp::create(rewriter, loc, funcDecl, ValueRange{comm, adaptor.getColor(), adaptor.getKey(), outPtr.getRes()}); // load the communicator into a register - Value res = rewriter.create(loc, i32, outPtr.getResult()); - res = rewriter.create(loc, rewriter.getI64Type(), res); + Value res = LLVM::LoadOp::create(rewriter, loc, i32, 
outPtr.getResult()); + res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res); // if retval is checked, replace uses of retval with the results from the // call op @@ -580,14 +580,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern { moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType); // replace with function call - auto one = rewriter.create(loc, i32, 1); - auto rankptr = rewriter.create(loc, ptrType, i32, one); - auto callOp = rewriter.create( + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); + auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one); + auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl, ValueRange{comm, rankptr.getRes()}); // load the rank into a register auto loadedRank = - rewriter.create(loc, i32, rankptr.getResult()); + LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult()); // if retval is checked, replace uses of retval with the results from the // call op @@ -641,7 +641,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType); // replace op with function call - auto funcCall = rewriter.create( + auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), comm}); @@ -683,10 +683,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern { auto mpiTraits = MPIImplTraits::get(moduleOp); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - Value statusIgnore = rewriter.create( + Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64, mpiTraits->getStatusIgnore()); statusIgnore = - rewriter.create(loc, ptrType, statusIgnore); + LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore); // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst, // tag, comm)` @@ -698,7 +698,7 @@ struct RecvOpLowering : public 
ConvertOpToLLVMPattern { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType); // replace op with function call - auto funcCall = rewriter.create( + auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getSource(), adaptor.getTag(), comm, statusIgnore}); @@ -738,9 +738,9 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { // If input and output are the same, request in-place operation. if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { - sendPtr = rewriter.create( + sendPtr = LLVM::ConstantOp::create(rewriter, loc, i64, reinterpret_cast(mpiTraits->getInPlace())); - sendPtr = rewriter.create(loc, ptrType, sendPtr); + sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr); } Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); @@ -757,7 +757,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType); // replace op with function call - auto funcCall = rewriter.create( + auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl, ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld}); diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 7f4655e53609e..bb3511440e041 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -121,7 +121,7 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { initValueAttr = FloatAttr::get(resultElementType, 0.0); else initValueAttr = IntegerAttr::get(resultElementType, 0); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr)); SmallVector strides = computeStrides(shape); for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { @@ -129,11 +129,11 @@ VecOpToScalarOp::matchAndRewrite(Op op, 
PatternRewriter &rewriter) const { SmallVector operands; for (Value input : op->getOperands()) operands.push_back( - rewriter.create(loc, input, positions)); + vector::ExtractOp::create(rewriter, loc, input, positions)); Value scalarOp = - rewriter.create(loc, vecType.getElementType(), operands); + Op::create(rewriter, loc, vecType.getElementType(), operands); result = - rewriter.create(loc, scalarOp, result, positions); + vector::InsertOp::create(rewriter, loc, scalarOp, result, positions); } rewriter.replaceOp(op, result); return success(); @@ -195,7 +195,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { FunctionType funcType = FunctionType::get( builder.getContext(), {elementType, elementType}, elementType); - auto funcOp = builder.create(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); @@ -208,11 +208,11 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); - Value zeroValue = builder.create( + Value zeroValue = arith::ConstantOp::create(builder, elementType, builder.getIntegerAttr(elementType, 0)); - Value oneValue = builder.create( + Value oneValue = arith::ConstantOp::create(builder, elementType, builder.getIntegerAttr(elementType, 1)); - Value minusOneValue = builder.create( + Value minusOneValue = arith::ConstantOp::create(builder, elementType, builder.getIntegerAttr(elementType, APInt(elementType.getIntOrFloatBitWidth(), -1ULL, @@ -221,81 +221,81 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { // if (p == T(0)) // return T(1); auto pIsZero = - builder.create(arith::CmpIPredicate::eq, pArg, zeroValue); + arith::CmpIOp::create(builder, 
arith::CmpIPredicate::eq, pArg, zeroValue); Block *thenBlock = builder.createBlock(funcBody); - builder.create(oneValue); + func::ReturnOp::create(builder, oneValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(pIsZero->getBlock()); - builder.create(pIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock); // if (p < T(0)) { builder.setInsertionPointToEnd(fallthroughBlock); auto pIsNeg = - builder.create(arith::CmpIPredicate::sle, pArg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg, zeroValue); // if (b == T(0)) builder.createBlock(funcBody); auto bIsZero = - builder.create(arith::CmpIPredicate::eq, bArg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue); // return T(1) / T(0); thenBlock = builder.createBlock(funcBody); - builder.create( - builder.create(oneValue, zeroValue).getResult()); + func::ReturnOp::create(builder, + arith::DivSIOp::create(builder, oneValue, zeroValue).getResult()); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(0)). builder.setInsertionPointToEnd(bIsZero->getBlock()); - builder.create(bIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock); // if (b == T(1)) builder.setInsertionPointToEnd(fallthroughBlock); auto bIsOne = - builder.create(arith::CmpIPredicate::eq, bArg, oneValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue); // return T(1); thenBlock = builder.createBlock(funcBody); - builder.create(oneValue); + func::ReturnOp::create(builder, oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(1)). 
builder.setInsertionPointToEnd(bIsOne->getBlock()); - builder.create(bIsOne, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock); // if (b == T(-1)) { builder.setInsertionPointToEnd(fallthroughBlock); - auto bIsMinusOne = builder.create(arith::CmpIPredicate::eq, + auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, minusOneValue); // if (p & T(1)) builder.createBlock(funcBody); - auto pIsOdd = builder.create( - arith::CmpIPredicate::ne, builder.create(pArg, oneValue), + auto pIsOdd = arith::CmpIOp::create(builder, + arith::CmpIPredicate::ne, arith::AndIOp::create(builder, pArg, oneValue), zeroValue); // return T(-1); thenBlock = builder.createBlock(funcBody); - builder.create(minusOneValue); + func::ReturnOp::create(builder, minusOneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(pIsOdd->getBlock()); - builder.create(pIsOdd, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock); // return T(1); // } // b == T(-1) builder.setInsertionPointToEnd(fallthroughBlock); - builder.create(oneValue); + func::ReturnOp::create(builder, oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(-1)). builder.setInsertionPointToEnd(bIsMinusOne->getBlock()); - builder.create(bIsMinusOne, pIsOdd->getBlock(), + cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(), fallthroughBlock); // return T(0); // } // (p < T(0)) builder.setInsertionPointToEnd(fallthroughBlock); - builder.create(zeroValue); + func::ReturnOp::create(builder, zeroValue); Block *loopHeader = builder.createBlock( funcBody, funcBody->end(), {elementType, elementType, elementType}, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set up conditional branch for (p < T(0)). 
builder.setInsertionPointToEnd(pIsNeg->getBlock()); // Set initial values of 'result', 'b' and 'p' for the loop. - builder.create(pIsNeg, bIsZero->getBlock(), loopHeader, + cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader, ValueRange{oneValue, bArg, pArg}); // T result = T(1); @@ -313,44 +313,44 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { builder.setInsertionPointToEnd(loopHeader); // if (p & T(1)) - auto powerTmpIsOdd = builder.create( + auto powerTmpIsOdd = arith::CmpIOp::create(builder, arith::CmpIPredicate::ne, - builder.create(powerTmp, oneValue), zeroValue); + arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue); thenBlock = builder.createBlock(funcBody); // result *= b; - Value newResultTmp = builder.create(resultTmp, baseTmp); + Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); - builder.create(newResultTmp, fallthroughBlock); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); - builder.create(powerTmpIsOdd, thenBlock, fallthroughBlock, + cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock, resultTmp); // Merged 'result'. 
newResultTmp = fallthroughBlock->getArgument(0); // p >>= T(1); builder.setInsertionPointToEnd(fallthroughBlock); - Value newPowerTmp = builder.create(powerTmp, oneValue); + Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue); // if (p == T(0)) - auto newPowerIsZero = builder.create(arith::CmpIPredicate::eq, + auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, newPowerTmp, zeroValue); // return result; thenBlock = builder.createBlock(funcBody); - builder.create(newResultTmp); + func::ReturnOp::create(builder, newResultTmp); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); - builder.create(newPowerIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock, fallthroughBlock); // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); - Value newBaseTmp = builder.create(baseTmp, baseTmp); + Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. 
- builder.create( + cf::BranchOp::create(builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); return funcOp; } @@ -420,7 +420,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, llvm::raw_string_ostream nameOS(funcName); nameOS << '_' << baseType; nameOS << '_' << powType; - auto funcOp = builder.create(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); @@ -433,46 +433,46 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); - Value oneBValue = builder.create( + Value oneBValue = arith::ConstantOp::create(builder, baseType, builder.getFloatAttr(baseType, 1.0)); - Value zeroPValue = builder.create( + Value zeroPValue = arith::ConstantOp::create(builder, powType, builder.getIntegerAttr(powType, 0)); - Value onePValue = builder.create( + Value onePValue = arith::ConstantOp::create(builder, powType, builder.getIntegerAttr(powType, 1)); - Value minPValue = builder.create( + Value minPValue = arith::ConstantOp::create(builder, powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue( powType.getWidth()))); - Value maxPValue = builder.create( + Value maxPValue = arith::ConstantOp::create(builder, powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue( powType.getWidth()))); // if (p == Tp{0}) // return Tb{1}; auto pIsZero = - builder.create(arith::CmpIPredicate::eq, pArg, zeroPValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroPValue); Block *thenBlock = builder.createBlock(funcBody); - builder.create(oneBValue); + func::ReturnOp::create(builder, oneBValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == Tp{0}). 
builder.setInsertionPointToEnd(pIsZero->getBlock()); - builder.create(pIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock); builder.setInsertionPointToEnd(fallthroughBlock); // bool isNegativePower{p < Tp{0}} - auto pIsNeg = builder.create(arith::CmpIPredicate::sle, pArg, + auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg, zeroPValue); // bool isMin{p == std::numeric_limits::min()}; auto pIsMin = - builder.create(arith::CmpIPredicate::eq, pArg, minPValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue); // if (isMin) { // p = std::numeric_limits::max(); // } else if (isNegativePower) { // p = -p; // } - Value negP = builder.create(zeroPValue, pArg); - auto pInit = builder.create(pIsNeg, negP, pArg); - pInit = builder.create(pIsMin, maxPValue, pInit); + Value negP = arith::SubIOp::create(builder, zeroPValue, pArg); + auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg); + pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit); // Tb result = Tb{1}; // Tb origBase = Tb{b}; @@ -489,7 +489,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set initial values of 'result', 'b' and 'p' for the loop. builder.setInsertionPointToEnd(pInit->getBlock()); - builder.create(loopHeader, ValueRange{oneBValue, bArg, pInit}); + cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit}); // Create loop body. 
Value resultTmp = loopHeader->getArgument(0); @@ -498,29 +498,29 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, builder.setInsertionPointToEnd(loopHeader); // if (p & Tp{1}) - auto powerTmpIsOdd = builder.create( + auto powerTmpIsOdd = arith::CmpIOp::create(builder, arith::CmpIPredicate::ne, - builder.create(powerTmp, onePValue), zeroPValue); + arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue); thenBlock = builder.createBlock(funcBody); // result *= b; - Value newResultTmp = builder.create(resultTmp, baseTmp); + Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); - builder.create(newResultTmp, fallthroughBlock); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); // Set up conditional branch for (p & Tp{1}). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); - builder.create(powerTmpIsOdd, thenBlock, fallthroughBlock, + cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock, resultTmp); // Merged 'result'. newResultTmp = fallthroughBlock->getArgument(0); // p >>= Tp{1}; builder.setInsertionPointToEnd(fallthroughBlock); - Value newPowerTmp = builder.create(powerTmp, onePValue); + Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue); // if (p == Tp{0}) - auto newPowerIsZero = builder.create(arith::CmpIPredicate::eq, + auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, newPowerTmp, zeroPValue); // break; // @@ -531,9 +531,9 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); - Value newBaseTmp = builder.create(baseTmp, baseTmp); + Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. 
- builder.create( + cf::BranchOp::create(builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); // Set up conditional branch for early loop exit: @@ -542,7 +542,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); - builder.create(newPowerIsZero, loopExit, newResultTmp, + cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp, fallthroughBlock, ValueRange{}); // if (isMin) { @@ -553,11 +553,11 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(loopExit); - builder.create(pIsMin, thenBlock, fallthroughBlock, + cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock, newResultTmp); builder.setInsertionPointToEnd(thenBlock); - newResultTmp = builder.create(newResultTmp, bArg); - builder.create(newResultTmp, fallthroughBlock); + newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); /// if (isNegativePower) { /// result = Tb{1} / result; @@ -567,15 +567,15 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(fallthroughBlock); - builder.create(pIsNeg, thenBlock, returnBlock, + cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock, newResultTmp); builder.setInsertionPointToEnd(thenBlock); - newResultTmp = builder.create(oneBValue, newResultTmp); - builder.create(newResultTmp, returnBlock); + newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp); + cf::BranchOp::create(builder, newResultTmp, returnBlock); // return result; builder.setInsertionPointToEnd(returnBlock); - 
builder.create(returnBlock->getArgument(0)); + func::ReturnOp::create(builder, returnBlock->getArgument(0)); return funcOp; } @@ -667,7 +667,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { nameOS << '_' << elementType; FunctionType funcType = FunctionType::get(builder.getContext(), {elementType}, elementType); - auto funcOp = builder.create(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); // LinkonceODR ensures that there is only one implementation of this function // across all math.ctlz functions that are lowered in this way. @@ -683,32 +683,32 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { Value arg = funcOp.getArgument(0); Type indexType = builder.getIndexType(); - Value bitWidthValue = builder.create( + Value bitWidthValue = arith::ConstantOp::create(builder, elementType, builder.getIntegerAttr(elementType, bitWidth)); - Value zeroValue = builder.create( + Value zeroValue = arith::ConstantOp::create(builder, elementType, builder.getIntegerAttr(elementType, 0)); Value inputEqZero = - builder.create(arith::CmpIPredicate::eq, arg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue); // if input == 0, return bit width, else enter loop. 
- scf::IfOp ifOp = builder.create( + scf::IfOp ifOp = scf::IfOp::create(builder, elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); ifOp.getThenBodyBuilder().create(loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); - Value oneIndex = elseBuilder.create( + Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType, elseBuilder.getIndexAttr(1)); - Value oneValue = elseBuilder.create( + Value oneValue = arith::ConstantOp::create(elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1)); - Value bitWidthIndex = elseBuilder.create( + Value bitWidthIndex = arith::ConstantOp::create(elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth)); - Value nValue = elseBuilder.create( + Value nValue = arith::ConstantOp::create(elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0)); - auto loop = elseBuilder.create( + auto loop = scf::ForOp::create(elseBuilder, oneIndex, bitWidthIndex, oneIndex, // Initial values for two loop induction variables, the arg which is being // shifted left in each iteration, and the n value which tracks the count @@ -725,25 +725,25 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { Value argIter = args[0]; Value nIter = args[1]; - Value argIsNonNegative = b.create( + Value argIsNonNegative = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, argIter, zeroValue); - scf::IfOp ifOp = b.create( + scf::IfOp ifOp = scf::IfOp::create(b, loc, argIsNonNegative, [&](OpBuilder &b, Location loc) { // If arg is negative, continue (effectively, break) - b.create(loc, ValueRange{argIter, nIter}); + scf::YieldOp::create(b, loc, ValueRange{argIter, nIter}); }, [&](OpBuilder &b, Location loc) { // Otherwise, increment n and shift arg left. 
- Value nNext = b.create(loc, nIter, oneValue); - Value argNext = b.create(loc, argIter, oneValue); - b.create(loc, ValueRange{argNext, nNext}); + Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue); + Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue); + scf::YieldOp::create(b, loc, ValueRange{argNext, nNext}); }); - b.create(loc, ifOp.getResults()); + scf::YieldOp::create(b, loc, ifOp.getResults()); }); - elseBuilder.create(loop.getResult(1)); + scf::YieldOp::create(elseBuilder, loop.getResult(1)); - builder.create(ifOp.getResult(0)); + func::ReturnOp::create(builder, ifOp.getResult(0)); return funcOp; } diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index f4d69ce8235bb..26559f394c221 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -107,7 +107,7 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), typeConverter, [&](Type llvm1DVectorTy, ValueRange operands) { - return rewriter.create(loc, llvm1DVectorTy, operands[0], + return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0], false); }, rewriter); @@ -145,14 +145,14 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { if (!isa(llvmOperandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(llvmOperandType)) { - one = rewriter.create( + one = LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast(llvmOperandType), floatOne)); } else { - one = rewriter.create(loc, llvmOperandType, floatOne); + one = LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto exp = rewriter.create(loc, adaptor.getOperand(), + auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(), expAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs()); @@ 
-171,10 +171,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { {numElements.isScalable()}), floatOne); auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto exp = rewriter.create( + LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, splatAttr); + auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); - return rewriter.create( + return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); }, rewriter); @@ -205,14 +205,14 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { if (!isa(llvmOperandType)) { LLVM::ConstantOp one = isa(llvmOperandType) - ? rewriter.create( + ? LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast(llvmOperandType), floatOne)) - : rewriter.create(loc, llvmOperandType, + : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); - auto add = rewriter.create( + auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType, ValueRange{one, adaptor.getOperand()}, addAttrs.getAttrs()); rewriter.replaceOpWithNewOp( @@ -232,11 +232,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { {numElements.isScalable()}), floatOne); auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create(loc, llvm1DVectorTy, + LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, splatAttr); + auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy, ValueRange{one, operands[0]}, addAttrs.getAttrs()); - return rewriter.create( + return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); }, rewriter); @@ -267,14 +267,14 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { if (!isa(llvmOperandType)) { LLVM::ConstantOp one; if (isa(llvmOperandType)) { - one = rewriter.create( + one = LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast(llvmOperandType), floatOne)); } else { - one = 
rewriter.create(loc, llvmOperandType, floatOne); + one = LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto sqrt = rewriter.create(loc, adaptor.getOperand(), + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(), sqrtAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); @@ -293,10 +293,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { {numElements.isScalable()}), floatOne); auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto sqrt = rewriter.create( + LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, splatAttr); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); - return rewriter.create( + return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); }, rewriter); diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index a0ce7d3b75fc2..19303dc3e8b55 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -84,7 +84,7 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto shape = vecType.getShape(); int64_t numElements = vecType.getNumElements(); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); SmallVector strides = computeStrides(shape); @@ -93,11 +93,11 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { SmallVector operands; for (auto input : op->getOperands()) operands.push_back( - rewriter.create(loc, input, positions)); + vector::ExtractOp::create(rewriter, loc, input, positions)); Value scalarOp = - rewriter.create(loc, vecType.getElementType(), operands); + Op::create(rewriter, loc, vecType.getElementType(), operands); result = - 
rewriter.create(loc, scalarOp, result, positions); + vector::InsertOp::create(rewriter, loc, scalarOp, result, positions); } rewriter.replaceOp(op, {result}); return success(); @@ -114,9 +114,9 @@ PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto f32 = rewriter.getF32Type(); auto extendedOperands = llvm::to_vector( llvm::map_range(op->getOperands(), [&](Value operand) -> Value { - return rewriter.create(loc, f32, operand); + return arith::ExtFOp::create(rewriter, loc, f32, operand); })); - auto newOp = rewriter.create(loc, f32, extendedOperands); + auto newOp = Op::create(rewriter, loc, f32, extendedOperands); rewriter.replaceOpWithNewOp(op, opType, newOp); return success(); } @@ -139,7 +139,7 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create(rewriter.getUnknownLoc(), name, + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 59db14ed816be..26891aa7c2025 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -36,11 +36,11 @@ static Value getScalarOrVectorI32Constant(Type type, int value, if (!vectorType.getElementType().isInteger(32)) return nullptr; SmallVector values(vectorType.getNumElements(), value); - return builder.create(loc, type, + return spirv::ConstantOp::create(builder, loc, type, builder.getI32VectorAttr(values)); } if (type.isInteger(32)) - return builder.create(loc, type, + return spirv::ConstantOp::create(builder, loc, type, builder.getI32IntegerAttr(value)); return nullptr; @@ -144,9 +144,9 @@ struct CopySignPattern final : public OpConversionPattern { Type intType = 
rewriter.getIntegerType(bitwidth); uint64_t intValue = uint64_t(1) << (bitwidth - 1); - Value signMask = rewriter.create( + Value signMask = spirv::ConstantOp::create(rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue)); - Value valueMask = rewriter.create( + Value valueMask = spirv::ConstantOp::create(rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); if (auto vectorType = dyn_cast(type)) { @@ -156,24 +156,24 @@ struct CopySignPattern final : public OpConversionPattern { SmallVector signSplat(count, signMask); signMask = - rewriter.create(loc, intType, signSplat); + spirv::CompositeConstructOp::create(rewriter, loc, intType, signSplat); SmallVector valueSplat(count, valueMask); - valueMask = rewriter.create(loc, intType, + valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, valueSplat); } Value lhsCast = - rewriter.create(loc, intType, adaptor.getLhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs()); Value rhsCast = - rewriter.create(loc, intType, adaptor.getRhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs()); - Value value = rewriter.create( + Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType, ValueRange{lhsCast, valueMask}); - Value sign = rewriter.create( + Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType, ValueRange{rhsCast, signMask}); - Value result = rewriter.create(loc, intType, + Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType, ValueRange{value, sign}); rewriter.replaceOpWithNewOp(copySignOp, type, result); return success(); @@ -214,18 +214,18 @@ struct CountLeadingZerosPattern final Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); - Value msb = rewriter.create(loc, input); + Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input); // We need to subtract from 31 given that the index returned by GLSL // FindUMsb 
is counted from the least significant bit. Theoretically this // also gives the correct result even if the integer has all zero bits, in // which case GL FindUMsb would return -1. - Value subMsb = rewriter.create(loc, val31, msb); + Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb); // However, certain Vulkan implementations have driver bugs for the corner // case where the input is zero. And.. it can be smart to optimize a select // only involving the corner case. So separately compute the result when the // input is either zero or one. - Value subInput = rewriter.create(loc, val32, input); - Value cmp = rewriter.create(loc, input, val1); + Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input); + Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1); rewriter.replaceOpWithNewOp(countOp, cmp, subInput, subMsb); return success(); @@ -253,7 +253,7 @@ struct ExpM1OpPattern final : public OpConversionPattern { if (!type) return failure(); - Value exp = rewriter.create(loc, type, adaptor.getOperand()); + Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp(operation, exp, one); return success(); @@ -283,7 +283,7 @@ struct Log1pOpPattern final : public OpConversionPattern { auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); Value onePlus = - rewriter.create(loc, one, adaptor.getOperand()); + spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } @@ -321,14 +321,14 @@ struct Log2Log10OpPattern final : public OpConversionPattern { auto getConstantValue = [&](double value) { if (auto floatType = dyn_cast(type)) { - return rewriter.create( + return spirv::ConstantOp::create(rewriter, loc, type, rewriter.getFloatAttr(floatType, value)); } if (auto vectorType = dyn_cast(type)) { Type elemType = vectorType.getElementType(); 
if (isa(elemType)) { - return rewriter.create( + return spirv::ConstantOp::create(rewriter, loc, type, DenseFPElementsAttr::get( vectorType, FloatAttr::get(elemType, value).getValue())); @@ -341,7 +341,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern { Value constantValue = getConstantValue( std::is_same() ? log2Reciprocal : log10Reciprocal); - Value log = rewriter.create(loc, adaptor.getOperand()); + Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, log, constantValue); return success(); @@ -386,7 +386,7 @@ struct PowFOpPattern final : public OpConversionPattern { Location loc = powfOp.getLoc(); Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter); Value lessThan = - rewriter.create(loc, adaptor.getLhs(), zero); + spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero); // Per C/C++ spec: // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is @@ -394,11 +394,11 @@ struct PowFOpPattern final : public OpConversionPattern { // Calculate the reminder from the exponent and check whether it is zero. Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter); Value expRem = - rewriter.create(loc, adaptor.getRhs(), floatOne); + spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne); Value expRemNonZero = - rewriter.create(loc, expRem, zero); + spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero); Value cmpNegativeWithFractionalExp = - rewriter.create(loc, expRemNonZero, lessThan); + spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan); // Create NaN result and replace base value if conditions are met. 
const auto &floatSemantics = scalarFloatType.getFloatSemantics(); const auto nan = APFloat::getNaN(floatSemantics); @@ -407,10 +407,10 @@ struct PowFOpPattern final : public OpConversionPattern { nanAttr = DenseElementsAttr::get(vectorType, nan); Value NanValue = - rewriter.create(loc, operandType, nanAttr); - Value lhs = rewriter.create( + spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); + Value lhs = spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs()); - Value abs = rewriter.create(loc, lhs); + Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in // order to properly propagate the sign, assuming integer y cases. It @@ -418,18 +418,18 @@ struct PowFOpPattern final : public OpConversionPattern { // Cast exponent to integer and calculate exponent % 2 != 0. Value intRhs = - rewriter.create(loc, intType, adaptor.getRhs()); + spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs()); Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); Value bitwiseAndOne = - rewriter.create(loc, intRhs, intOne); - Value isOdd = rewriter.create(loc, bitwiseAndOne, intOne); + spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne); + Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne); // calculate pow based on abs(lhs)^rhs. - Value pow = rewriter.create(loc, abs, adaptor.getRhs()); - Value negate = rewriter.create(loc, pow); + Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs()); + Value negate = spirv::FNegateOp::create(rewriter, loc, pow); // if the exponent is odd and lhs < 0, negate the result. 
Value shouldNegate = - rewriter.create(loc, lessThan, isOdd); + spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd); rewriter.replaceOpWithNewOp(powfOp, shouldNegate, negate, pow); return success(); @@ -455,22 +455,22 @@ struct RoundOpPattern final : public OpConversionPattern { auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); Value half; if (VectorType vty = dyn_cast(ty)) { - half = rewriter.create( + half = spirv::ConstantOp::create(rewriter, loc, vty, DenseElementsAttr::get(vty, rewriter.getFloatAttr(ety, 0.5).getValue())); } else { - half = rewriter.create( + half = spirv::ConstantOp::create(rewriter, loc, ty, rewriter.getFloatAttr(ety, 0.5)); } - auto abs = rewriter.create(loc, operand); - auto floor = rewriter.create(loc, abs); - auto sub = rewriter.create(loc, abs, floor); + auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand); + auto floor = spirv::GLFloorOp::create(rewriter, loc, abs); + auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor); auto greater = - rewriter.create(loc, sub, half); - auto select = rewriter.create(loc, greater, one, zero); - auto add = rewriter.create(loc, floor, select); + spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half); + auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero); + auto add = spirv::FAddOp::create(rewriter, loc, floor, select); rewriter.replaceOpWithNewOp(roundOp, add, operand); return success(); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index db244d1d1cac8..db6ce600af76a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -161,7 +161,7 @@ struct ConvertLoad final : public OpConversionPattern { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create( + auto subscript = emitc::SubscriptOp::create(rewriter, op.getLoc(), arrayValue, operands.getIndices()); 
rewriter.replaceOpWithNewOp(op, resultTy, subscript); @@ -181,7 +181,7 @@ struct ConvertStore final : public OpConversionPattern { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create( + auto subscript = emitc::SubscriptOp::create(rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp(op, subscript, operands.getValue()); @@ -212,7 +212,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { if (inputs.size() != 1) return Value(); - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 00bbdcb12e326..1603b1ba7eb9e 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -87,12 +87,12 @@ getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, /// aligned = bumped - bumped % alignment static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { - Value one = rewriter.create(loc, alignment.getType(), + Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(), rewriter.getIndexAttr(1)); - Value bump = rewriter.create(loc, alignment, one); - Value bumped = rewriter.create(loc, input, bump); - Value mod = rewriter.create(loc, bumped, alignment); - return rewriter.create(loc, bumped, mod); + Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one); + Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump); + Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment); + return LLVM::SubOp::create(rewriter, loc, bumped, mod); } /// Computes the byte size for the MemRef element type. 
@@ -123,7 +123,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space"); unsigned memrefAddrSpace = *maybeMemrefAddrSpace; if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) - allocatedPtr = rewriter.create( + allocatedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), allocatedPtr); return allocatedPtr; @@ -168,14 +168,14 @@ class AllocOpLowering : public ConvertOpToLLVMPattern { Value alignment = getAlignment(rewriter, loc, op); if (alignment) { // Adjust the allocation size to consider alignment. - sizeBytes = rewriter.create(loc, sizeBytes, alignment); + sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment); } // Allocate the underlying buffer. Type elementPtrType = this->getElementPtrType(memRefType); assert(elementPtrType && "could not compute element ptr type"); auto results = - rewriter.create(loc, allocFuncOp.value(), sizeBytes); + LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes); Value allocatedPtr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -184,11 +184,11 @@ class AllocOpLowering : public ConvertOpToLLVMPattern { if (alignment) { // Compute the aligned pointer. Value allocatedInt = - rewriter.create(loc, getIndexType(), allocatedPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr); Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); alignedPtr = - rewriter.create(loc, elementPtrType, alignmentInt); + LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt); } // Create the MemRef descriptor. 
@@ -268,7 +268,7 @@ class AlignedAllocOpLowering : public ConvertOpToLLVMPattern { sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - auto results = rewriter.create( + auto results = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); Value ptr = @@ -360,7 +360,7 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern { auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace); - auto allocatedElementPtr = rewriter.create( + auto allocatedElementPtr = LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size, op.getAlignment().value_or(0)); // Create the MemRef descriptor. @@ -397,7 +397,7 @@ struct AllocaScopeOpLowering remainingOpsBlock, allocaScopeOp.getResultTypes(), SmallVector(allocaScopeOp->getNumResults(), allocaScopeOp.getLoc())); - rewriter.create(loc, ValueRange(), remainingOpsBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock); } // Inline body region. @@ -407,8 +407,8 @@ struct AllocaScopeOpLowering // Save stack and then branch into the body of the region. rewriter.setInsertionPointToEnd(currentBlock); - auto stackSaveOp = rewriter.create(loc, getPtrType()); - rewriter.create(loc, ValueRange(), beforeBody); + auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType()); + LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody); // Replace the alloca_scope return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -420,7 +420,7 @@ struct AllocaScopeOpLowering // Insert stack restore before jumping out the body of the region. rewriter.setInsertionPoint(branchOp); - rewriter.create(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); // Replace the op with values return from the body region. 
rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); @@ -451,10 +451,10 @@ struct AssumeAlignmentOpLowering // This is more direct than ptrtoint-based checks, is explicitly supported, // and works with non-integral address spaces. Value trueCond = - rewriter.create(loc, rewriter.getBoolAttr(true)); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true)); Value alignmentConst = createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); - rewriter.create(loc, trueCond, LLVM::AssumeAlignTag(), ptr, + LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr, alignmentConst); rewriter.replaceOp(op, memref); return success(); @@ -559,16 +559,16 @@ struct DimOpLowering : public ConvertOpToLLVMPattern { // Get pointer to offset field of memref descriptor. auto indexPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); - Value offsetPtr = rewriter.create( + Value offsetPtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType, underlyingRankedDesc, ArrayRef{0, 2}); // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. 
- Value idxPlusOne = rewriter.create( + Value idxPlusOne = LLVM::AddOp::create(rewriter, loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), adaptor.getIndex()); - Value sizePtr = rewriter.create( + Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); return rewriter @@ -674,9 +674,9 @@ struct GenericAtomicRMWOpLowering auto memRefType = cast(atomicOp.getMemref().getType()); auto dataPtr = getStridedElementPtr( rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices()); - Value init = rewriter.create( + Value init = LLVM::LoadOp::create(rewriter, loc, typeConverter->convertType(memRefType.getElementType()), dataPtr); - rewriter.create(loc, init, loopBlock); + LLVM::BrOp::create(rewriter, loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); @@ -696,14 +696,14 @@ struct GenericAtomicRMWOpLowering // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto cmpxchg = rewriter.create( + auto cmpxchg = LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument, result, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. - Value newLoaded = rewriter.create(loc, cmpxchg, 0); - Value ok = rewriter.create(loc, cmpxchg, 1); + Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0); + Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1); // Conditionally branch to the end or back to the loop depending on %ok. 
- rewriter.create(loc, ok, endBlock, ArrayRef(), + LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef(), loopBlock, newLoaded); rewriter.setInsertionPointToEnd(endBlock); @@ -800,8 +800,8 @@ class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern { if (!global.isExternal() && global.isUninitialized()) { rewriter.createBlock(&newGlobal.getInitializerRegion()); Value undef[] = { - rewriter.create(global.getLoc(), arrayTy)}; - rewriter.create(global.getLoc(), undef); + LLVM::UndefOp::create(rewriter, global.getLoc(), arrayTy)}; + LLVM::ReturnOp::create(rewriter, global.getLoc(), undef); } return success(); } @@ -846,11 +846,11 @@ struct GetGlobalMemrefOpLowering Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace); auto addressOf = - rewriter.create(loc, ptrTy, op.getName()); + LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. - auto gep = rewriter.create( + auto gep = LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf, SmallVector(type.getRank() + 1, 0)); @@ -861,7 +861,7 @@ struct GetGlobalMemrefOpLowering Value deadBeefConst = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); auto deadBeefPtr = - rewriter.create(loc, ptrTy, deadBeefConst); + LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst); // Both allocated and aligned pointers are same. We could potentially stash // a nullptr for the allocated pointer since we do not expect any dealloc. 
@@ -1013,7 +1013,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { loc, adaptor.getSource(), rewriter); // rank = ConstantOp srcRank - auto rankVal = rewriter.create( + auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(rank)); // poison = PoisonOp UnrankedMemRefDescriptor memRefDesc = @@ -1033,7 +1033,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // struct = LoadOp ptr - auto loadOp = rewriter.create(loc, targetStructType, ptr); + auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr); rewriter.replaceOp(memRefCastOp, loadOp.getResult()); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); @@ -1067,31 +1067,31 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { MemRefDescriptor srcDesc(adaptor.getSource()); // Compute number of elements. - Value numElements = rewriter.create( + Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(1)); for (int pos = 0; pos < srcType.getRank(); ++pos) { auto size = srcDesc.size(rewriter, loc, pos); - numElements = rewriter.create(loc, numElements, size); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, size); } // Get element size. auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter); // Compute total. 
Value totalSize = - rewriter.create(loc, numElements, sizeInBytes); + LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes); Type elementType = typeConverter->convertType(srcType.getElementType()); Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); Value srcOffset = srcDesc.offset(rewriter, loc); - Value srcPtr = rewriter.create( + Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset); MemRefDescriptor targetDesc(adaptor.getTarget()); Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); Value targetOffset = targetDesc.offset(rewriter, loc); - Value targetPtr = rewriter.create( + Value targetPtr = LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset); - rewriter.create(loc, targetPtr, srcPtr, totalSize, + LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize, /*isVolatile=*/false); rewriter.eraseOp(op); @@ -1107,7 +1107,7 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { // First make sure we have an unranked memref descriptor representation. auto makeUnranked = [&, this](Value ranked, MemRefType type) { - auto rank = rewriter.create(loc, getIndexType(), + auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), type.getRank()); auto *typeConverter = getTypeConverter(); auto ptr = @@ -1120,7 +1120,7 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { }; // Save stack position before promoting descriptors - auto stackSaveOp = rewriter.create(loc, getPtrType()); + auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType()); auto srcMemRefType = dyn_cast(srcType); Value unrankedSource = @@ -1132,13 +1132,13 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { : adaptor.getTarget(); // Now promote the unranked descriptors to the stack. 
- auto one = rewriter.create(loc, getIndexType(), + auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(1)); auto promote = [&](Value desc) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); auto allocated = - rewriter.create(loc, ptrType, desc.getType(), one); - rewriter.create(loc, desc, allocated); + LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one); + LLVM::StoreOp::create(rewriter, loc, desc, allocated); return allocated; }; @@ -1153,11 +1153,11 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern { sourcePtr.getType(), symbolTables); if (failed(copyFn)) return failure(); - rewriter.create(loc, copyFn.value(), + LLVM::CallOp::create(rewriter, loc, copyFn.value(), ValueRange{elemSize, sourcePtr, targetPtr}); // Restore stack used for descriptors - rewriter.create(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); rewriter.eraseOp(op); @@ -1208,9 +1208,9 @@ struct MemorySpaceCastOpLowering MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR, descVals); descVals[0] = - rewriter.create(loc, newPtrType, descVals[0]); + LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]); descVals[1] = - rewriter.create(loc, newPtrType, descVals[1]); + LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]); Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), resultTypeR, descVals); rewriter.replaceOp(op, result); @@ -1245,7 +1245,7 @@ struct MemorySpaceCastOpLowering UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), result, resultAddrSpace, sizes); Value resultUnderlyingSize = sizes.front(); - Value resultUnderlyingDesc = rewriter.create( + Value resultUnderlyingDesc = LLVM::AllocaOp::create(rewriter, loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize); result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc); @@ -1260,9 +1260,9 @@ struct 
MemorySpaceCastOpLowering Value alignedPtr = sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(), sourceUnderlyingDesc, sourceElemPtrType); - allocatedPtr = rewriter.create( + allocatedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc, resultElemPtrType, allocatedPtr); - alignedPtr = rewriter.create( + alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc, resultElemPtrType, alignedPtr); result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc, @@ -1281,11 +1281,11 @@ struct MemorySpaceCastOpLowering int64_t bytesToSkip = 2 * llvm::divideCeil( getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8); - Value bytesToSkipConst = rewriter.create( + Value bytesToSkipConst = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); - Value copySize = rewriter.create( + Value copySize = LLVM::SubOp::create(rewriter, loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst); - rewriter.create(loc, resultIndexVals, sourceIndexVals, + LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals, copySize, /*isVolatile=*/false); rewriter.replaceOp(op, ValueRange{result}); @@ -1489,7 +1489,7 @@ struct MemRefReshapeOpLowering } else { Value shapeOp = reshapeOp.getShape(); Value index = createIndexAttrConstant(rewriter, loc, indexType, i); - dimSize = rewriter.create(loc, shapeOp, index); + dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index); Type indexType = getIndexType(); if (dimSize.getType() != indexType) dimSize = typeConverter->materializeTargetConversion( @@ -1501,7 +1501,7 @@ struct MemRefReshapeOpLowering desc.setStride(rewriter, loc, i, stride); // Prepare the stride value for the next dimension. 
- stride = rewriter.create(loc, stride, dimSize); + stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize); } *descriptor = desc; @@ -1526,7 +1526,7 @@ struct MemRefReshapeOpLowering SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), targetDesc, addressSpace, sizes); - Value underlyingDescPtr = rewriter.create( + Value underlyingDescPtr = LLVM::AllocaOp::create(rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8), sizes.front()); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); @@ -1558,7 +1558,7 @@ struct MemRefReshapeOpLowering Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); Value resultRankMinusOne = - rewriter.create(loc, resultRank, oneIndex); + LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); Type indexType = getTypeConverter()->getIndexType(); @@ -1572,14 +1572,14 @@ struct MemRefReshapeOpLowering rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); rewriter.setInsertionPointToEnd(initBlock); - rewriter.create(loc, ValueRange({resultRankMinusOne, oneIndex}), + LLVM::BrOp::create(rewriter, loc, ValueRange({resultRankMinusOne, oneIndex}), condBlock); rewriter.setInsertionPointToStart(condBlock); Value indexArg = condBlock->getArgument(0); Value strideArg = condBlock->getArgument(1); Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0); - Value pred = rewriter.create( + Value pred = LLVM::ICmpOp::create(rewriter, loc, IntegerType::get(rewriter.getContext(), 1), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); @@ -1589,22 +1589,22 @@ struct MemRefReshapeOpLowering // Copy size from shape to descriptor. 
auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value sizeLoadGep = rewriter.create( + Value sizeLoadGep = LLVM::GEPOp::create(rewriter, loc, llvmIndexPtrType, typeConverter->convertType(shapeMemRefType.getElementType()), shapeOperandPtr, indexArg); - Value size = rewriter.create(loc, indexType, sizeLoadGep); + Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep); UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), targetSizesBase, indexArg, size); // Write stride value and compute next one. UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), targetStridesBase, indexArg, strideArg); - Value nextStride = rewriter.create(loc, strideArg, size); + Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size); // Decrement loop counter and branch back. - Value decrement = rewriter.create(loc, indexArg, oneIndex); - rewriter.create(loc, ValueRange({decrement, nextStride}), + Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex); + LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}), condBlock); Block *remainder = @@ -1612,7 +1612,7 @@ struct MemRefReshapeOpLowering // Hook up the cond exit to the remainder. rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, pred, bodyBlock, ValueRange(), + LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(), remainder, ValueRange()); // Reset position to beginning of new remainder block. @@ -1742,7 +1742,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); if (nextSize) return runningStride - ? rewriter.create(loc, runningStride, nextSize) + ? 
LLVM::MulOp::create(rewriter, loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexAttrConstant(rewriter, loc, indexType, 1); @@ -1787,7 +1787,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { // Field 2: Copy the actual aligned pointer to payload. Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); - alignedPtr = rewriter.create( + alignedPtr = LLVM::GEPOp::create(rewriter, loc, alignedPtr.getType(), typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr, adaptor.getByteShift()); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index b866afbce98b0..7b7ad9b95b0ec 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -79,7 +79,7 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, assert(indices.size() == 2); indices.back() = builder.createOrFold(loc, lastDim, idx); Type t = typeConverter.convertType(op.getComponentPtr().getType()); - return builder.create(loc, t, op.getBasePtr(), indices); + return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(), indices); } /// Casts the given `srcBool` into an integer of `dstType`. 
@@ -107,7 +107,7 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask, value = castBoolToIntN(loc, value, dstType, builder); } else { if (valueBits < targetBits) { - value = builder.create( + value = spirv::UConvertOp::create(builder, loc, builder.getIntegerType(targetBits), value); } @@ -372,7 +372,7 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); - varOp = rewriter.create(loc, spirvType, varName, + varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName, /*initializer=*/nullptr); } @@ -572,7 +572,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value loadVal = rewriter.create(loc, accessChain, + Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain, memoryAccess, alignment); if (isBool) loadVal = castIntNToBool(loc, loadVal, rewriter); @@ -601,7 +601,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value spvLoadOp = rewriter.create(loc, dstType, adjustedPtr, + Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr, memoryAccess, alignment); // Shift the bits to the rightmost. 
@@ -770,10 +770,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, if (!scope) return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); - Value result = rewriter.create( + Value result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, clearBitsMask); - result = rewriter.create( + result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, storeVal); @@ -851,11 +851,11 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( } if (sourceSc != spirv::StorageClass::Generic) { result = - rewriter.create(loc, genericPtrType, result); + spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType, result); } if (resultSc != spirv::StorageClass::Generic) { result = - rewriter.create(loc, resultPtrType, result); + spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result); } rewriter.replaceOp(addrCastOp, result); return success(); diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index b93128441f2b5..f16b92a287d0a 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -65,7 +65,7 @@ static SmallVector getMixedAsValues(OpBuilder b, const Location &loc, values.emplace_back(*(dyn++)); } else { TypedAttr val = type == i64 ? 
b.getI64IntegerAttr(s) : b.getIndexAttr(s); - values.emplace_back(b.create(loc, type, val)); + values.emplace_back(arith::ConstantOp::create(b, loc, type, val)); } } return values; @@ -79,9 +79,9 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, SmallVector multiIndex(n); for (int i = n - 1; i >= 0; --i) { - multiIndex[i] = b.create(loc, linearIndex, dimensions[i]); + multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]); if (i > 0) - linearIndex = b.create(loc, linearIndex, dimensions[i]); + linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]); } return multiIndex; @@ -91,13 +91,13 @@ static SmallVector linearToMultiIndex(Location loc, OpBuilder b, Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) { - Value linearIndex = b.create(loc, 0); - Value stride = b.create(loc, 1); + Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0); + Value stride = arith::ConstantIndexOp::create(b, loc, 1); for (int i = multiIndex.size() - 1; i >= 0; --i) { - Value off = b.create(loc, multiIndex[i], stride); - linearIndex = b.create(loc, linearIndex, off); - stride = b.create(loc, stride, dimensions[i]); + Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride); + linearIndex = arith::AddIOp::create(b, loc, linearIndex, off); + stride = arith::MulIOp::create(b, loc, stride, dimensions[i]); } return linearIndex; @@ -144,10 +144,10 @@ struct ConvertShardingOp : public OpConversionPattern { auto i64 = rewriter.getI64Type(); std::array shape = {static_cast(splitAxes.size()), maxNAxes}; - Value resSplitAxes = rewriter.create(loc, shape, i16); + Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16); auto attr = IntegerAttr::get(i16, -1); - Value fillValue = rewriter.create(loc, i16, attr); - resSplitAxes = rewriter.create(loc, fillValue, resSplitAxes) + Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr); + resSplitAxes = 
linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes) .getResult(0); // explicitly write values into tensor row by row @@ -162,8 +162,8 @@ struct ConvertShardingOp : public OpConversionPattern { std::array sizes = {1, size}; auto tensorType = RankedTensorType::get({size}, i16); auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef()); - auto vals = rewriter.create(loc, tensorType, attrs); - resSplitAxes = rewriter.create( + auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs); + resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides); } @@ -179,7 +179,7 @@ struct ConvertShardingOp : public OpConversionPattern { .create(loc, std::array{0, 0}, i64) .getResult() - : rewriter.create(loc, type, haloSizes) + : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); // To hold sharded dims offsets, create Tensor with shape {nSplits, @@ -189,7 +189,7 @@ struct ConvertShardingOp : public OpConversionPattern { // MeshOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { - resOffsets = rewriter.create( + resOffsets = tensor::EmptyOp::create(rewriter, loc, std::array{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; @@ -204,12 +204,12 @@ struct ConvertShardingOp : public OpConversionPattern { assert(maxSplitSize); ++maxSplitSize; // add one for the total size - resOffsets = rewriter.create( + resOffsets = tensor::EmptyOp::create(rewriter, loc, std::array{nSplits, maxSplitSize}, i64); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic)); resOffsets = - rewriter.create(loc, zero, resOffsets).getResult(0); + linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0); SmallVector offsets = getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(), adaptor.getDynamicShardedDimsOffsets()); @@ -220,10 +220,10 
@@ struct ConvertShardingOp : public OpConversionPattern { assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef values(&offsets[curr], splitSize); - Value vals = rewriter.create(loc, values); + Value vals = tensor::FromElementsOp::create(rewriter, loc, values); std::array offs = {static_cast(i), 0}; std::array sizes = {1, splitSize}; - resOffsets = rewriter.create( + resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides); curr += splitSize; } @@ -236,10 +236,10 @@ struct ConvertShardingOp : public OpConversionPattern { return failure(); resSplitAxes = - rewriter.create(loc, resTypes[0], resSplitAxes); + tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes); resHaloSizes = - rewriter.create(loc, resTypes[1], resHaloSizes); - resOffsets = rewriter.create(loc, resTypes[2], resOffsets); + tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes); + resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets); rewriter.replaceOpWithNewOp( op, TupleType::get(op.getContext(), resTypes), @@ -269,9 +269,9 @@ struct ConvertProcessMultiIndexOp SmallVector dims; llvm::transform( meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create(loc, i).getResult(); + return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = rewriter.create(op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); // optionally extract subset of mesh axes @@ -302,7 +302,7 @@ class ConvertProcessLinearIndexOp Location loc = op.getLoc(); auto ctx = op.getContext(); Value commWorld = - rewriter.create(loc, mpi::CommType::get(ctx)); + mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); auto rank = rewriter .create( @@ -341,40 +341,40 @@ struct ConvertNeighborsLinearIndicesOp 
SmallVector dims; llvm::transform( meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create(loc, i).getResult(); + return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; - Value one = rewriter.create(loc, 1); - Value minus1 = rewriter.create(loc, -1); - Value atBorder = rewriter.create( + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1); + Value atBorder = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx, - rewriter.create(loc, 0)); - auto down = rewriter.create( + arith::ConstantIndexOp::create(rewriter, loc, 0)); + auto down = scf::IfOp::create(rewriter, loc, atBorder, [&](OpBuilder &builder, Location loc) { - builder.create(loc, minus1); + scf::YieldOp::create(builder, loc, minus1); }, [&](OpBuilder &builder, Location loc) { SmallVector tmp = mIdx; tmp[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, one) + arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one) .getResult(); - builder.create( + scf::YieldOp::create(builder, loc, multiToLinearIndex(loc, rewriter, tmp, dims)); }); - atBorder = rewriter.create( + atBorder = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, orgIdx, - rewriter.create(loc, dimSz, one).getResult()); - auto up = rewriter.create( + arith::SubIOp::create(rewriter, loc, dimSz, one).getResult()); + auto up = scf::IfOp::create(rewriter, loc, atBorder, [&](OpBuilder &builder, Location loc) { - builder.create(loc, minus1); + scf::YieldOp::create(builder, loc, minus1); }, [&](OpBuilder &builder, Location loc) { SmallVector tmp = mIdx; tmp[axes[0]] = - rewriter.create(op.getLoc(), orgIdx, one); - builder.create( + arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one); + scf::YieldOp::create(builder, loc, multiToLinearIndex(loc, rewriter, tmp, dims)); }); rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); @@ -447,7 +447,7 @@ 
struct ConvertShardShapeOp : public OpConversionPattern { rewriter, loc, sharding.getStaticShardedDimsOffsets(), sharding.getDynamicShardedDimsOffsets(), index); if (!tmp.empty()) - shardedDimsOffs = rewriter.create( + shardedDimsOffs = tensor::FromElementsOp::create(rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp); } @@ -457,9 +457,9 @@ struct ConvertShardShapeOp : public OpConversionPattern { int64_t pos = 0; SmallVector shardShape; Value zero = - rewriter.create(loc, rewriter.getZeroAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index)); Value one = - rewriter.create(loc, rewriter.getOneAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index)); // Iterate over the dimensions of the tensor shape, get their split Axes, // and compute the sharded shape. @@ -470,7 +470,7 @@ struct ConvertShardShapeOp : public OpConversionPattern { // The current dimension might not be sharded. // Create a value from the static position in shardDimsOffsets. Value posVal = - rewriter.create(loc, rewriter.getIndexAttr(pos)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(pos)); // Get the index of the local shard in the mesh axis. Value idx = multiIdx[axes[0]]; auto numShards = @@ -482,29 +482,29 @@ struct ConvertShardShapeOp : public OpConversionPattern { return op->emitError() << "Only single axis sharding is " << "supported for each dimension."; } - idx = rewriter.create(loc, posVal, idx); + idx = arith::AddIOp::create(rewriter, loc, posVal, idx); // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx]. 
Value off = - rewriter.create(loc, shardedDimsOffs, idx); - idx = rewriter.create(loc, idx, one); + tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); + idx = arith::AddIOp::create(rewriter, loc, idx, one); Value nextOff = - rewriter.create(loc, shardedDimsOffs, idx); - Value sz = rewriter.create(loc, nextOff, off); + tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); + Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off); shardShape.emplace_back(sz); } else { - Value numShardsVal = rewriter.create( + Value numShardsVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(numShards)); // Compute shard dim size by distributing odd elements to trailing // shards: // sz = dim / numShards // + (idx >= (numShards - (dim % numShards)) ? 1 : 0) - Value sz = rewriter.create(loc, dim, numShardsVal); - Value sz1 = rewriter.create(loc, dim, numShardsVal); - sz1 = rewriter.create(loc, numShardsVal, sz1); - auto cond = rewriter.create( + Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal); + Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal); + sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1); + auto cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, idx, sz1); - Value odd = rewriter.create(loc, cond, one, zero); - sz = rewriter.create(loc, sz, odd); + Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero); + sz = arith::AddIOp::create(rewriter, loc, sz, odd); shardShape.emplace_back(sz); } pos += numShards + 1; // add one for the total size. 
@@ -568,7 +568,7 @@ struct ConvertAllReduceOp : public OpConversionPattern { if (isa(input.getType())) { auto memrefType = MemRefType::get( inputShape, cast(input.getType()).getElementType()); - input = iBuilder.create(memrefType, input); + input = bufferization::ToBufferOp::create(iBuilder, memrefType, input); } MemRefType inType = cast(input.getType()); @@ -577,15 +577,15 @@ struct ConvertAllReduceOp : public OpConversionPattern { for (auto i = 0; i < inType.getRank(); ++i) { auto s = inputShape[i]; if (ShapedType::isDynamic(s)) - shape[i] = iBuilder.create(input, s).getResult(); + shape[i] = memref::DimOp::create(iBuilder, input, s).getResult(); else shape[i] = iBuilder.getIndexAttr(s); } // Allocate buffer and copy input to buffer. - Value buffer = iBuilder.create( + Value buffer = memref::AllocOp::create(iBuilder, shape, cast(op.getType()).getElementType()); - iBuilder.create(input, buffer); + linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. 
// The color is the linear index of the process in the mesh along the @@ -594,9 +594,9 @@ struct ConvertAllReduceOp : public OpConversionPattern { SmallVector indexResultTypes(meshOp.getShape().size(), iBuilder.getIndexType()); SmallVector myMultiIndex = - iBuilder.create(indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh) .getResult(); - Value zero = iBuilder.create(0); + Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector multiKey(myMultiIndex.size(), zero); auto redAxes = adaptor.getMeshAxes(); @@ -607,15 +607,15 @@ struct ConvertAllReduceOp : public OpConversionPattern { Value color = createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); - color = iBuilder.create(iBuilder.getI32Type(), color); + color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); - key = iBuilder.create(iBuilder.getI32Type(), key); + key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator auto commType = mpi::CommType::get(op->getContext()); - Value commWorld = iBuilder.create(commType); + Value commWorld = mpi::CommWorldOp::create(iBuilder, commType); auto comm = - iBuilder.create(commType, commWorld, color, key) + mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key) .getNewcomm(); Value buffer1d = buffer; @@ -623,18 +623,18 @@ struct ConvertAllReduceOp : public OpConversionPattern { if (inType.getRank() > 1) { ReassociationIndices reassociation(inType.getRank()); std::iota(reassociation.begin(), reassociation.end(), 0); - buffer1d = iBuilder.create( + buffer1d = memref::CollapseShapeOp::create(iBuilder, buffer, ArrayRef(reassociation)); } // Create the MPI AllReduce operation. 
- iBuilder.create( + mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d, getMPIReductionOp(adaptor.getReductionAttr()), comm); // If the destination is a memref, cast it to a tensor if (isa(op.getType())) - buffer = iBuilder.create(op.getType(), buffer, + buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer, true); rewriter.replaceOp(op, buffer); @@ -676,7 +676,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { if (auto value = dyn_cast(v)) return value; - return rewriter.create( + return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr( cast(cast(v)).getInt())); }; @@ -689,7 +689,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { auto mmemrefType = MemRefType::get( dstShape, cast(array.getType()).getElementType()); array = - rewriter.create(loc, mmemrefType, array); + bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array); } auto rank = cast(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); @@ -713,7 +713,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { for (auto i = 0; i < rank; ++i) { auto s = dstShape[i]; if (ShapedType::isDynamic(s)) - shape[i] = rewriter.create(loc, array, s).getResult(); + shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult(); else shape[i] = rewriter.getIndexAttr(s); @@ -723,12 +723,12 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { offsets[i] = haloSizes[currHaloDim * 2]; // prepare shape and offsets of highest dim's halo exchange - Value _haloSz = rewriter.create( + Value _haloSz = arith::AddIOp::create(rewriter, loc, toValue(haloSizes[currHaloDim * 2]), toValue(haloSizes[currHaloDim * 2 + 1])); // the halo shape of lower dims exlude the halos dimSizes[i] = - rewriter.create(loc, toValue(shape[i]), _haloSz) + arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz) .getResult(); } else { dimSizes[i] = shape[i]; 
@@ -736,14 +736,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { } auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something - auto tag = rewriter.create(loc, tagAttr); + auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr); auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 - auto zero = rewriter.create(loc, zeroAttr); + auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); SmallVector indexResultTypes(meshOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - rewriter.create(loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -758,19 +758,19 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { splitAxes) .getResults(); // MPI operates on i32... - Value neighbourIDs[2] = {rewriter.create( + Value neighbourIDs[2] = {arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), tmp[0]), - rewriter.create( + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), tmp[1])}; auto lowerRecvOffset = rewriter.getIndexAttr(0); auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); - auto upperRecvOffset = rewriter.create( + auto upperRecvOffset = arith::SubIOp::create(rewriter, loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1])); - auto upperSendOffset = rewriter.create( + auto upperSendOffset = arith::SubIOp::create(rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); - Value commWorld = rewriter.create( + Value commWorld = mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(op->getContext())); // Make sure we send/recv in a way that does not lead to a dead-lock. @@ -787,37 +787,37 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { // Processes on the mesh borders have only one neighbor auto to = upperHalo ? 
neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; - auto hasFrom = rewriter.create( + auto hasFrom = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, from, zero); - auto hasTo = rewriter.create( + auto hasTo = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, to, zero); - auto buffer = rewriter.create( + auto buffer = memref::AllocOp::create(rewriter, loc, dimSizes, cast(array.getType()).getElementType()); // if has neighbor: copy halo data from array to buffer and send - rewriter.create( + scf::IfOp::create(rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) : OpFoldResult(upperSendOffset); - auto subview = builder.create( + auto subview = memref::SubViewOp::create(builder, loc, array, offsets, dimSizes, strides); - builder.create(loc, subview, buffer); - builder.create(loc, TypeRange{}, buffer, tag, to, + memref::CopyOp::create(builder, loc, subview, buffer); + mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to, commWorld); - builder.create(loc); + scf::YieldOp::create(builder, loc); }); // if has neighbor: receive halo data into buffer and copy to array - rewriter.create( + scf::IfOp::create(rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? 
OpFoldResult(upperRecvOffset) : OpFoldResult(lowerRecvOffset); - builder.create(loc, TypeRange{}, buffer, tag, from, + mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from, commWorld); - auto subview = builder.create( + auto subview = memref::SubViewOp::create(builder, loc, array, offsets, dimSizes, strides); - builder.create(loc, buffer, subview); - builder.create(loc); + memref::CopyOp::create(builder, loc, buffer, subview); + scf::YieldOp::create(builder, loc); }); - rewriter.create(loc, buffer); + memref::DeallocOp::create(rewriter, loc, buffer); offsets[dim] = orgOffset; }; @@ -825,15 +825,15 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown]; Value haloSz = dyn_cast(v); if (!haloSz) - haloSz = rewriter.create( + haloSz = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr( cast(cast(v)).getInt())); - auto hasSize = rewriter.create( + auto hasSize = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero); - rewriter.create(loc, hasSize, + scf::IfOp::create(rewriter, loc, hasSize, [&](OpBuilder &builder, Location loc) { genSendRecv(upOrDown > 0); - builder.create(loc); + scf::YieldOp::create(builder, loc); }); }; @@ -852,7 +852,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern { rewriter.replaceOp(op, array); } else { assert(isa(op.getResult().getType())); - rewriter.replaceOp(op, rewriter.create( + rewriter.replaceOp(op, bufferization::ToTensorOp::create(rewriter, loc, op.getResult().getType(), array, /*restrict=*/true, /*writable=*/true)); } diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 80b3d85488495..cd92e1dbba090 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include 
"mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" @@ -53,7 +54,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { assert(llvm::isa(type) && "expected an integer Value"); if (type.getIntOrFloatBitWidth() <= 32) return value; - return b.create(b.getI32Type(), value); + return LLVM::TruncOp::create(b, b.getI32Type(), value); } /// Returns the type for the intrinsic given the vectorResultType of the @@ -113,8 +114,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type f32x1Ty = VectorType::get(1, f32Ty); auto makeConst = [&](int32_t index) -> Value { - return rewriter.create(loc, IntegerType::get(ctx, 32), - rewriter.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32), + rewriter.getI32IntegerAttr(index)); }; if (arrayType) { @@ -126,7 +127,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, arrayType.getElementType() == f32x1Ty) { for (unsigned i = 0; i < structType.getBody().size(); i++) { Value el = - rewriter.create(loc, intrinsicResult, i); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i); el = rewriter.createOrFold( loc, arrayType.getElementType(), el); elements.push_back(el); @@ -143,24 +144,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { Value vec = - rewriter.create(loc, arrayType.getElementType()); + LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType()); Value x1 = - rewriter.create(loc, intrinsicResult, i * 2); - Value x2 = rewriter.create(loc, intrinsicResult, - i * 2 + 1); - vec = rewriter.create(loc, vec.getType(), vec, - x1, makeConst(0)); - vec = rewriter.create(loc, vec.getType(), vec, - x2, makeConst(1)); + LLVM::ExtractValueOp::create(rewriter, loc, 
intrinsicResult, i * 2); + Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, + i * 2 + 1); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x1, makeConst(0)); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x2, makeConst(1)); elements.push_back(vec); } } // Create the final vectorized result. - Value result = rewriter.create(loc, arrayType); + Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType); for (const auto &el : llvm::enumerate(elements)) { - result = rewriter.create(loc, result, el.value(), - el.index()); + result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(), + el.index()); } return result; } @@ -187,7 +188,7 @@ static SmallVector unpackOperandVector(ImplicitLocOpBuilder &b, auto arrayTy = cast(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { - Value toUse = b.create(operand, i); + Value toUse = LLVM::ExtractValueOp::create(b, operand, i); // For 4xi8 vectors, the intrinsic expects these to be provided as i32 // scalar types. 
@@ -195,7 +196,7 @@ static SmallVector unpackOperandVector(ImplicitLocOpBuilder &b, arrayTy.getElementType() == i4x8Ty || (arrayTy.getElementType() == f32x1Ty && operandPtxType == NVVM::MMATypes::tf32)) { - result.push_back(b.create(i32Ty, toUse)); + result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse)); continue; } @@ -208,9 +209,9 @@ static SmallVector unpackOperandVector(ImplicitLocOpBuilder &b, innerArrayTy.getElementType() == f32Ty)) { for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); idx < innerSize; idx++) { - result.push_back(b.create( - toUse, - b.create(i64Ty, b.getI64IntegerAttr(idx)))); + result.push_back(LLVM::ExtractElementOp::create( + b, toUse, + LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx)))); } continue; } @@ -285,8 +286,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); - Value ldMatrixResult = b.create( - ldMatrixResultType, srcPtr, + Value ldMatrixResult = NVVM::LdMatrixOp::create( + b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row); @@ -296,13 +297,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { // actual vector type (still of width 32b) and repack them into a result // struct. Type finalResultType = typeConverter->convertType(vectorResultType); - Value result = b.create(finalResultType); + Value result = LLVM::PoisonOp::create(b, finalResultType); for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { Value i32Register = - num32BitRegs > 1 ? b.create(ldMatrixResult, i) + num32BitRegs > 1 ? 
LLVM::ExtractValueOp::create(b, ldMatrixResult, i) : ldMatrixResult; - Value casted = b.create(innerVectorType, i32Register); - result = b.create(result, casted, i); + Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register); + result = LLVM::InsertValueOp::create(b, result, casted, i); } rewriter.replaceOp(op, result); @@ -375,16 +376,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); Type intrinsicResTy = inferIntrinsicResultType( typeConverter->convertType(op->getResultTypes()[0])); - Value intrinsicResult = b.create( - intrinsicResTy, matA, matB, matC, - /*shape=*/gemmShape, - /*b1Op=*/std::nullopt, - /*intOverflow=*/overflow, - /*multiplicandPtxTypes=*/ - std::array{*ptxTypeA, *ptxTypeB}, - /*multiplicandLayouts=*/ - std::array{NVVM::MMALayout::row, - NVVM::MMALayout::col}); + Value intrinsicResult = + NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC, + /*shape=*/gemmShape, + /*b1Op=*/std::nullopt, + /*intOverflow=*/overflow, + /*multiplicandPtxTypes=*/ + std::array{*ptxTypeA, *ptxTypeB}, + /*multiplicandLayouts=*/ + std::array{ + NVVM::MMALayout::row, NVVM::MMALayout::col}); rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, intrinsicResult, rewriter)); @@ -565,15 +566,16 @@ static FailureOr emitMmaSparseSyncOpAsm( llvm::append_range(asmVals, args); asmVals.push_back(indexData); - return b.create( - /*resultTypes=*/intrinsicResultType, - /*operands=*/asmVals, - /*asm_string=*/asmStr, - /*constraints=*/constraintStr, - /*has_side_effects=*/true, - /*is_align_stack=*/false, LLVM::TailCallKind::None, - /*asm_dialect=*/asmDialectAttr, - /*operand_attrs=*/ArrayAttr()); + return LLVM::InlineAsmOp::create(b, + /*resultTypes=*/intrinsicResultType, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/constraintStr, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::TailCallKind::None, + 
/*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); } /// Lowers `nvgpu.mma.sp.sync` to inline assembly. @@ -631,7 +633,7 @@ struct NVGPUMmaSparseSyncLowering return op->emitOpError() << "Expected metadata type to be LLVM " "VectorType of 2 i16 elements"; sparseMetadata = - b.create(rewriter.getI32Type(), sparseMetadata); + LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata); FailureOr intrinsicResult = emitMmaSparseSyncOpAsm( b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, @@ -682,7 +684,7 @@ struct NVGPUAsyncCopyLowering // Intrinsics takes a global pointer so we need an address space cast. auto srcPointerGlobalType = LLVM::LLVMPointerType::get( op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); - scrPtr = b.create(srcPointerGlobalType, scrPtr); + scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr); int64_t dstElements = adaptor.getDstElements().getZExtValue(); int64_t sizeInBytes = (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; @@ -697,13 +699,13 @@ struct NVGPUAsyncCopyLowering // The rest of the DstElements in the destination (shared memory) are // filled with zeros. Value c3I32 = - b.create(b.getI32Type(), b.getI32IntegerAttr(3)); - Value bitwidth = b.create( - b.getI32Type(), + LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3)); + Value bitwidth = LLVM::ConstantOp::create( + b, b.getI32Type(), b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); - Value srcElementsI32 = b.create(b.getI32Type(), srcBytes); - srcBytes = b.create( - b.create(bitwidth, srcElementsI32), c3I32); + Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes); + srcBytes = LLVM::LShrOp::create( + b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32); } // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than // 16 dst bytes. @@ -712,14 +714,15 @@ struct NVGPUAsyncCopyLowering ? 
NVVM::LoadCacheModifierKind::CG : NVVM::LoadCacheModifierKind::CA; - b.create( - dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), + NVVM::CpAsyncOp::create( + b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), srcBytes); // Drop the result token. - Value zero = b.create( - IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); + Value zero = + LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -733,11 +736,11 @@ struct NVGPUAsyncCreateGroupLowering LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.create(op.getLoc()); + NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc()); // Drop the result token. - Value zero = rewriter.create( - op->getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -753,7 +756,7 @@ struct NVGPUAsyncWaitLowering ConversionPatternRewriter &rewriter) const override { // If numGroup is not present pick 0 as a conservative correct value. 
int32_t numGroups = adaptor.getNumGroups().value_or(0); - rewriter.create(op.getLoc(), numGroups); + NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups); rewriter.eraseOp(op); return success(); } @@ -771,8 +774,8 @@ struct NVGPUMBarrierCreateLowering SymbolTable symbolTable(moduleOp); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(&moduleOp.front()); - auto global = rewriter.create( - funcOp->getLoc(), "__mbarrier", + auto global = memref::GlobalOp::create( + rewriter, funcOp->getLoc(), "__mbarrier", /*sym_visibility=*/rewriter.getStringAttr("private"), /*type=*/barrierType, /*initial_value=*/ElementsAttr(), @@ -974,7 +977,7 @@ struct NVGPUMBarrierTryWaitParityLowering adaptor.getMbarId(), rewriter); Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = - b.create(b.getI32Type(), adaptor.getPhaseParity()); + LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); if (isMbarrierShared(op.getBarriers().getType())) { rewriter.replaceOpWithNewOp( @@ -1063,16 +1066,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering auto ti64 = b.getIntegerType(64); auto makeConst = [&](uint64_t index) -> Value { - return b.create(ti64, b.getI64IntegerAttr(index)); + return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index)); }; auto shiftLeft = [&](Value value, unsigned shift) -> Value { - return b.create(ti64, value, makeConst(shift)); + return LLVM::ShlOp::create(b, ti64, value, makeConst(shift)); }; auto shiftRight = [&](Value value, unsigned shift) -> Value { - return b.create(ti64, value, makeConst(shift)); + return LLVM::LShrOp::create(b, ti64, value, makeConst(shift)); }; auto insertBit = [&](Value desc, Value val, int startBit) { - return b.create(ti64, desc, shiftLeft(val, startBit)); + return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit)); }; int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); @@ -1086,7 +1089,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering Value 
baseAddr = getStridedElementPtr( rewriter, op->getLoc(), cast(op.getTensor().getType()), adaptor.getTensor(), {}); - Value basePtr = b.create(ti64, baseAddr); + Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr); // Just use 14 bits for base address Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); @@ -1118,8 +1121,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering }; static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { - return b.create(b.getIntegerType(64), - b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, b.getIntegerType(64), + b.getI32IntegerAttr(index)); } /// Returns a Value that holds data type enum that is expected by CUDA driver. @@ -1182,12 +1185,12 @@ struct NVGPUTmaCreateDescriptorOpLowering auto promotedOperands = getTypeConverter()->promoteOperands( b.getLoc(), op->getOperands(), adaptor.getOperands(), b); - Value boxArrayPtr = b.create(llvmPointerType, llvmInt64Type, - makeI64Const(b, 5)); + Value boxArrayPtr = LLVM::AllocaOp::create( + b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5)); for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { - Value gep = b.create(llvmPointerType, llvmPointerType, - boxArrayPtr, makeI64Const(b, index)); - b.create(value, gep); + Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType, + boxArrayPtr, makeI64Const(b, index)); + LLVM::StoreOp::create(b, value, gep); } nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); @@ -1337,7 +1340,7 @@ struct NVGPUWarpgroupMmaOpLowering /// Basic function to generate Add Value makeAdd(Value lhs, Value rhs) { - return b.create(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. 
@@ -1430,29 +1433,30 @@ struct NVGPUWarpgroupMmaOpLowering auto overflow = NVVM::MMAIntOverflowAttr::get( op->getContext(), NVVM::MMAIntOverflow::wrapped); - return b.create( - matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, - itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, + return NVVM::WgmmaMmaAsyncOp::create( + b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape, + itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); } /// Generates multiple wgmma instructions to complete the given GEMM shape Value generateWgmmaGroup() { Value wgmmaResult = - b.create(adaptor.getMatrixC().getType()); + LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType()); // Perform GEMM SmallVector wgmmaResults; for (int i = 0; i < iterationM; ++i) { - Value matrixC = b.create(adaptor.getMatrixC(), i); + Value matrixC = + LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i); for (int j = 0; j < iterationN; ++j) for (int k = 0; k < iterationK; ++k) matrixC = generateWgmma(i, j, k, matrixC); wgmmaResults.push_back(matrixC); } for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { - wgmmaResult = b.create(wgmmaResult.getType(), - wgmmaResult, matrix, idx); + wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(), + wgmmaResult, matrix, idx); } return wgmmaResult; } @@ -1486,10 +1490,10 @@ struct NVGPUWarpgroupMmaOpLowering /// (WgmmaGroupSyncAlignedOp) for group synchronization /// (WgmmaWaitGroupSyncOp) after the instructions. 
Value generateWarpgroupMma() { - b.create(); + NVVM::WgmmaFenceAlignedOp::create(b); Value wgmmaResult = generateWgmmaGroup(); - b.create(); - b.create(op.getWaitGroup()); + NVVM::WgmmaGroupSyncAlignedOp::create(b); + NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup()); return wgmmaResult; } }; @@ -1557,7 +1561,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering Type i32 = b.getI32Type(); auto makeConst = [&](int32_t index) -> Value { - return b.create(i32, b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index)); }; Value c1 = makeConst(1); Value c2 = makeConst(2); @@ -1567,29 +1571,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering Value warpSize = makeConst(kWarpSize); auto makeMul = [&](Value lhs, Value rhs) -> Value { - return b.create(lhs.getType(), lhs, rhs); + return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs); }; auto makeAdd = [&](Value lhs, Value rhs) -> Value { - return b.create(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, TypedValue<::mlir::MemRefType> memref) { Type it = b.getIndexType(); - Value idx = b.create(it, x); - Value idy0 = b.create(it, y); - Value idy1 = b.create(it, makeAdd(y, c1)); - Value d0 = b.create(wgmmaResult, i); - Value d1 = b.create(wgmmaResult, i + 1); - b.create(d0, memref, ValueRange{idx, idy0}); - b.create(d1, memref, ValueRange{idx, idy1}); + Value idx = arith::IndexCastOp::create(b, it, x); + Value idy0 = arith::IndexCastOp::create(b, it, y); + Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1)); + Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i); + Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1); + memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0}); + memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1}); }; - Value tidx = b.create(i32); - Value laneId = b.create(i32, tidx, warpSize); - Value warpId = b.create(i32, 
tidx, warpSize); - Value lane4Id = b.create(i32, laneId, c4); - Value lane4modId = b.create(i32, laneId, c4); + Value tidx = NVVM::ThreadIdXOp::create(b, i32); + Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize); + Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize); + Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4); + Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4); Value tj = makeMul(lane4modId, c2); Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); @@ -1626,7 +1630,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering auto stype = cast(matriDValue.getType()); for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { auto structType = cast(matrixD); - Value innerStructValue = b.create(matriDValue, idx); + Value innerStructValue = + LLVM::ExtractValueOp::create(b, matriDValue, idx); storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); offset += structType.getBody().size(); } @@ -1648,23 +1653,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering Type elemType = cast(packStructType.getBody().front()) .getBody() .front(); - Value zero = b.create(elemType, b.getZeroAttr(elemType)); - Value packStruct = b.create(packStructType); + Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType)); + Value packStruct = LLVM::PoisonOp::create(b, packStructType); SmallVector innerStructs; // Unpack the structs and set all values to zero for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { auto structType = cast(s); - Value structValue = b.create(packStruct, idx); + Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx); for (unsigned i = 0; i < structType.getBody().size(); ++i) { - structValue = b.create( - structType, structValue, zero, ArrayRef({i})); + structValue = LLVM::InsertValueOp::create(b, structType, structValue, + zero, ArrayRef({i})); } innerStructs.push_back(structValue); } // Pack the inner structs into a single struct for (auto [idx, matrix] : 
llvm::enumerate(innerStructs)) { - packStruct = b.create(packStruct.getType(), - packStruct, matrix, idx); + packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(), + packStruct, matrix, idx); } rewriter.replaceOp(op, packStruct); return success(); @@ -1681,7 +1686,7 @@ struct NVGPUTmaFenceOpLowering ImplicitLocOpBuilder b(op->getLoc(), rewriter); auto i32Ty = b.getI32Type(); Value tensormapSize = - b.create(i32Ty, rewriter.getI32IntegerAttr(128)); + LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128)); auto memscope = NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS); @@ -1716,13 +1721,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern { VectorType inTy = op.getIn().getType(); // apply rcp.approx.ftz.f on each element in vector. auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { - Value ret1DVec = b.create(llvm1DVectorTy); + Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy); int numElems = llvm::cast(llvm1DVectorTy).getNumElements(); for (int i = 0; i < numElems; i++) { - Value idx = b.create(i64Ty, b.getI64IntegerAttr(i)); - Value elem = b.create(inVec, idx); - Value dst = b.create(f32Ty, elem); - ret1DVec = b.create(ret1DVec, dst, idx); + Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i)); + Value elem = LLVM::ExtractElementOp::create(b, inVec, idx); + Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem); + ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx); } return ret1DVec; }; diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp index 479725aae8afd..a6125fc07dcd5 100644 --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -39,7 +39,7 @@ class ExpandIfCondition : public OpRewritePattern { IntegerAttr constAttr; if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) { - auto ifOp = rewriter.create(op.getLoc(), TypeRange(), + 
auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), TypeRange(), op.getIfCond(), false); rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener()); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 77a2708653576..836a193b7fb4a 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -85,7 +85,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern { } // Create new operation. - auto newOp = rewriter.create(op.getLoc(), resTypes, convertedOperands, + auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands, convertedAttrs); // Translate regions. diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 7d20109b3db59..75cd35c0d7876 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -196,7 +196,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion, // finalize. if (isa(node)) { builder.setInsertionPointToEnd(block); - builder.create(matcherFunc.getLoc()); + pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc()); return block; } @@ -272,7 +272,7 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { auto *operationPos = cast(pos); if (operationPos->isOperandDefiningOp()) // Standard (downward) traversal which directly follows the defining op. - value = builder.create( + value = pdl_interp::GetDefiningOpOp::create(builder, loc, builder.getType(), parentVal); else // A passthrough operation position. @@ -287,23 +287,23 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { // requested to use a representative value (e.g., upward traversal). 
if (isa(parentVal.getType()) && usersPos->useRepresentative()) - value = builder.create(loc, parentVal, 0); + value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0); else value = parentVal; // The second operation retrieves the users. - value = builder.create(loc, value); + value = pdl_interp::GetUsersOp::create(builder, loc, value); break; } case Predicates::ForEachPos: { assert(!failureBlockStack.empty() && "expected valid failure block"); - auto foreach = builder.create( + auto foreach = pdl_interp::ForEachOp::create(builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); value = foreach.getLoopVariable(); // Create the continuation block. Block *continueBlock = builder.createBlock(&foreach.getRegion()); - builder.create(loc); + pdl_interp::ContinueOp::create(builder, loc); failureBlockStack.push_back(continueBlock); currentBlock = &foreach.getRegion().front(); @@ -311,7 +311,7 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { } case Predicates::OperandPos: { auto *operandPos = cast(pos); - value = builder.create( + value = pdl_interp::GetOperandOp::create(builder, loc, builder.getType(), parentVal, operandPos->getOperandNumber()); break; @@ -319,28 +319,28 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { case Predicates::OperandGroupPos: { auto *operandPos = cast(pos); Type valueTy = builder.getType(); - value = builder.create( + value = pdl_interp::GetOperandsOp::create(builder, loc, operandPos->isVariadic() ? 
pdl::RangeType::get(valueTy) : valueTy, parentVal, operandPos->getOperandGroupNumber()); break; } case Predicates::AttributePos: { auto *attrPos = cast(pos); - value = builder.create( + value = pdl_interp::GetAttributeOp::create(builder, loc, builder.getType(), parentVal, attrPos->getName().strref()); break; } case Predicates::TypePos: { if (isa(parentVal.getType())) - value = builder.create(loc, parentVal); + value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal); else - value = builder.create(loc, parentVal); + value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal); break; } case Predicates::ResultPos: { auto *resPos = cast(pos); - value = builder.create( + value = pdl_interp::GetResultOp::create(builder, loc, builder.getType(), parentVal, resPos->getResultNumber()); break; @@ -348,7 +348,7 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { case Predicates::ResultGroupPos: { auto *resPos = cast(pos); Type valueTy = builder.getType(); - value = builder.create( + value = pdl_interp::GetResultsOp::create(builder, loc, resPos->isVariadic() ? 
pdl::RangeType::get(valueTy) : valueTy, parentVal, resPos->getResultGroupNumber()); break; @@ -356,16 +356,16 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { case Predicates::AttributeLiteralPos: { auto *attrPos = cast(pos); value = - builder.create(loc, attrPos->getValue()); + pdl_interp::CreateAttributeOp::create(builder, loc, attrPos->getValue()); break; } case Predicates::TypeLiteralPos: { auto *typePos = cast(pos); Attribute rawTypeAttr = typePos->getValue(); if (TypeAttr typeAttr = dyn_cast(rawTypeAttr)) - value = builder.create(loc, typeAttr); + value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr); else - value = builder.create( + value = pdl_interp::CreateTypesOp::create(builder, loc, cast(rawTypeAttr)); break; } @@ -413,54 +413,54 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, Predicates::Kind kind = question->getKind(); switch (kind) { case Predicates::IsNotNullQuestion: - builder.create(loc, val, success, failure); + pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure); break; case Predicates::OperationNameQuestion: { auto *opNameAnswer = cast(answer); - builder.create( + pdl_interp::CheckOperationNameOp::create(builder, loc, val, opNameAnswer->getValue().getStringRef(), success, failure); break; } case Predicates::TypeQuestion: { auto *ans = cast(answer); if (isa(val.getType())) - builder.create( + pdl_interp::CheckTypesOp::create(builder, loc, val, llvm::cast(ans->getValue()), success, failure); else - builder.create( + pdl_interp::CheckTypeOp::create(builder, loc, val, llvm::cast(ans->getValue()), success, failure); break; } case Predicates::AttributeQuestion: { auto *ans = cast(answer); - builder.create(loc, val, ans->getValue(), + pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(), success, failure); break; } case Predicates::OperandCountAtLeastQuestion: case Predicates::OperandCountQuestion: - builder.create( + 
pdl_interp::CheckOperandCountOp::create(builder, loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, success, failure); break; case Predicates::ResultCountAtLeastQuestion: case Predicates::ResultCountQuestion: - builder.create( + pdl_interp::CheckResultCountOp::create(builder, loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, success, failure); break; case Predicates::EqualToQuestion: { bool trueAnswer = isa(answer); - builder.create(loc, val, args.front(), + pdl_interp::AreEqualOp::create(builder, loc, val, args.front(), trueAnswer ? success : failure, trueAnswer ? failure : success); break; } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - auto applyConstraintOp = builder.create( + auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, cstQuestion->getIsNegated(), success, failure); @@ -487,7 +487,7 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, blocks.push_back(it.second); values.push_back(cast(it.first)->getValue()); } - builder.create(val.getLoc(), val, values, defaultDest, blocks); + OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks); } void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, @@ -536,11 +536,11 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, unsigned ans = cast(child.first)->getValue(); switch (kind) { case Predicates::OperandCountAtLeastQuestion: - builder.create( + pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); break; case Predicates::ResultCountAtLeastQuestion: - builder.create( + pdl_interp::CheckResultCountOp::create(builder, loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); break; default: @@ -619,7 +619,7 @@ void 
PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); - auto matchOp = builder.create( + auto matchOp = pdl_interp::RecordMatchOp::create(builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), failureBlockStack.back()); @@ -632,7 +632,7 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { SymbolRefAttr PatternLowering::generateRewriter( pdl::PatternOp pattern, SmallVectorImpl &usedMatchValues) { builder.setInsertionPointToEnd(rewriterModule.getBody()); - auto rewriterFunc = builder.create( + auto rewriterFunc = pdl_interp::FuncOp::create(builder, pattern.getLoc(), "pdl_generated_rewriter", builder.getFunctionType({}, {})); rewriterSymbolTable.insert(rewriterFunc); @@ -651,17 +651,17 @@ SymbolRefAttr PatternLowering::generateRewriter( Operation *oldOp = oldValue.getDefiningOp(); if (pdl::AttributeOp attrOp = dyn_cast(oldOp)) { if (Attribute value = attrOp.getValueAttr()) { - return newValue = builder.create( + return newValue = pdl_interp::CreateAttributeOp::create(builder, attrOp.getLoc(), value); } } else if (pdl::TypeOp typeOp = dyn_cast(oldOp)) { if (TypeAttr type = typeOp.getConstantTypeAttr()) { - return newValue = builder.create( + return newValue = pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), type); } } else if (pdl::TypesOp typeOp = dyn_cast(oldOp)) { if (ArrayAttr type = typeOp.getConstantTypesAttr()) { - return newValue = builder.create( + return newValue = pdl_interp::CreateTypesOp::create(builder, typeOp.getLoc(), typeOp.getType(), type); } } @@ -684,7 +684,7 @@ SymbolRefAttr PatternLowering::generateRewriter( auto mappedArgs = llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); - builder.create( + pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(), 
/*resultTypes=*/TypeRange(), rewriteName, args); } else { // Otherwise this is a dag rewriter defined using PDL operations. @@ -703,7 +703,7 @@ SymbolRefAttr PatternLowering::generateRewriter( llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), /*results=*/{})); - builder.create(rewriter.getLoc()); + pdl_interp::FinalizeOp::create(builder, rewriter.getLoc()); return SymbolRefAttr::get( builder.getContext(), pdl_interp::PDLInterpDialect::getRewriterModuleName(), @@ -716,7 +716,7 @@ void PatternLowering::generateRewriter( SmallVector arguments; for (Value argument : rewriteOp.getArgs()) arguments.push_back(mapRewriteValue(argument)); - auto interpOp = builder.create( + auto interpOp = pdl_interp::ApplyRewriteOp::create(builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), arguments); for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) @@ -726,7 +726,7 @@ void PatternLowering::generateRewriter( void PatternLowering::generateRewriter( pdl::AttributeOp attrOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - Value newAttr = builder.create( + Value newAttr = pdl_interp::CreateAttributeOp::create(builder, attrOp.getLoc(), attrOp.getValueAttr()); rewriteValues[attrOp] = newAttr; } @@ -734,7 +734,7 @@ void PatternLowering::generateRewriter( void PatternLowering::generateRewriter( pdl::EraseOp eraseOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - builder.create(eraseOp.getLoc(), + pdl_interp::EraseOp::create(builder, eraseOp.getLoc(), mapRewriteValue(eraseOp.getOpValue())); } @@ -756,7 +756,7 @@ void PatternLowering::generateRewriter( // Create the new operation. 
Location loc = operationOp.getLoc(); - Value createdOp = builder.create( + Value createdOp = pdl_interp::CreateOperationOp::create(builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, attributes, operationOp.getAttributeValueNames()); rewriteValues[operationOp.getOp()] = createdOp; @@ -768,8 +768,8 @@ void PatternLowering::generateRewriter( if (resultTys.size() == 1 && isa(resultTys[0].getType())) { Value &type = rewriteValues[resultTys[0]]; if (!type) { - auto results = builder.create(loc, createdOp); - type = builder.create(loc, results); + auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp); + type = pdl_interp::GetValueTypeOp::create(builder, loc, results); } return; } @@ -789,12 +789,12 @@ void PatternLowering::generateRewriter( // groups because the exact index of the result is not statically known. Value resultVal; if (seenVariableLength) - resultVal = builder.create( + resultVal = pdl_interp::GetResultsOp::create(builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); else - resultVal = builder.create( + resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy, createdOp, it.index()); - type = builder.create(loc, resultVal); + type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal); } } @@ -804,7 +804,7 @@ void PatternLowering::generateRewriter( SmallVector replOperands; for (Value operand : rangeOp.getArguments()) replOperands.push_back(mapRewriteValue(operand)); - rewriteValues[rangeOp] = builder.create( + rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(builder, rangeOp.getLoc(), rangeOp.getType(), replOperands); } @@ -820,7 +820,7 @@ void PatternLowering::generateRewriter( // Don't use replace if we know the replaced operation has no results. 
auto opOp = replaceOp.getOpValue().getDefiningOp(); if (!opOp || !opOp.getTypeValues().empty()) { - replOperands.push_back(builder.create( + replOperands.push_back(pdl_interp::GetResultsOp::create(builder, replOp.getLoc(), mapRewriteValue(replOp))); } } else { @@ -830,12 +830,12 @@ void PatternLowering::generateRewriter( // If there are no replacement values, just create an erase instead. if (replOperands.empty()) { - builder.create( + pdl_interp::EraseOp::create(builder, replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); return; } - builder.create(replaceOp.getLoc(), + pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()), replOperands); } @@ -843,7 +843,7 @@ void PatternLowering::generateRewriter( void PatternLowering::generateRewriter( pdl::ResultOp resultOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - rewriteValues[resultOp] = builder.create( + rewriteValues[resultOp] = pdl_interp::GetResultOp::create(builder, resultOp.getLoc(), builder.getType(), mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); } @@ -851,7 +851,7 @@ void PatternLowering::generateRewriter( void PatternLowering::generateRewriter( pdl::ResultsOp resultOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { - rewriteValues[resultOp] = builder.create( + rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(builder, resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); } @@ -863,7 +863,7 @@ void PatternLowering::generateRewriter( // type. if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { rewriteValues[typeOp] = - builder.create(typeOp.getLoc(), typeAttr); + pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr); } } @@ -873,7 +873,7 @@ void PatternLowering::generateRewriter( // If the type isn't constant, the users (e.g. OperationOp) will resolve this // type. 
if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { - rewriteValues[typeOp] = builder.create( + rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(builder, typeOp.getLoc(), typeOp.getType(), typeAttr); } } @@ -939,9 +939,9 @@ void PatternLowering::generateOperationResultTypeRewriter( !replacedOp->isBeforeInBlock(op)) continue; - Value replacedOpResults = builder.create( + Value replacedOpResults = pdl_interp::GetResultsOp::create(builder, replacedOp->getLoc(), mapRewriteValue(replOpVal)); - types.push_back(builder.create( + types.push_back(pdl_interp::GetValueTypeOp::create(builder, replacedOp->getLoc(), replacedOpResults)); return; } @@ -985,7 +985,7 @@ void PDLToPDLInterpPass::runOnOperation() { // Create the main matcher function This function contains all of the match // related functionality from patterns in the module. OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); - auto matcherFunc = builder.create( + auto matcherFunc = pdl_interp::FuncOp::create(builder, module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), builder.getFunctionType(builder.getType(), /*results=*/{}), @@ -993,7 +993,7 @@ void PDLToPDLInterpPass::runOnOperation() { // Create a nested module to hold the functions invoked for rewriting the IR // after a successful match. - ModuleOp rewriterModule = builder.create( + ModuleOp rewriterModule = ModuleOp::create(builder, module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); // Generate the code for the patterns within the module. 
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index ac6dcf2513e99..ffef20ce67818 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -340,14 +340,14 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.getStep(); - auto stepped = rewriter.create(loc, iv, step).getResult(); + auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult(); if (!stepped) return failure(); SmallVector loopCarried; loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); - rewriter.create(loc, conditionBlock, loopCarried); + cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. @@ -362,14 +362,14 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, SmallVector destOperands; destOperands.push_back(lowerBound); llvm::append_range(destOperands, forOp.getInitArgs()); - rewriter.create(loc, conditionBlock, destOperands); + cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. 
rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = rewriter.create( + auto comparison = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound); - auto condBranchOp = rewriter.create( + auto condBranchOp = cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock, ArrayRef(), endBlock, ArrayRef()); @@ -404,7 +404,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, continueBlock = rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), SmallVector(ifOp.getNumResults(), loc)); - rewriter.create(loc, remainingOpsBlock); + cf::BranchOp::create(rewriter, loc, remainingOpsBlock); } // Move blocks from the "then" region to the region containing 'scf.if', @@ -414,7 +414,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, Operation *thenTerminator = thenRegion.back().getTerminator(); ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, continueBlock, thenTerminatorOperands); + cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands); rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); @@ -428,13 +428,13 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, Operation *elseTerminator = elseRegion.back().getTerminator(); ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, continueBlock, elseTerminatorOperands); + cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands); rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, ifOp.getCondition(), thenBlock, + cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock, /*trueArgs=*/ArrayRef(), elseBlock, /*falseArgs=*/ArrayRef()); @@ -454,13 +454,13 @@ 
ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, auto ®ion = op.getRegion(); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, ®ion.front()); + cf::BranchOp::create(rewriter, loc, ®ion.front()); for (Block &block : region) { if (auto terminator = dyn_cast(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); - rewriter.create(loc, remainingOpsBlock, terminatorOperands); + cf::BranchOp::create(rewriter, loc, remainingOpsBlock, terminatorOperands); rewriter.eraseOp(terminator); } } @@ -498,7 +498,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, for (auto [iv, lower, upper, step] : llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep())) { - ForOp forOp = rewriter.create(loc, lower, upper, step, iterArgs); + ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs); ivs.push_back(forOp.getInductionVar()); auto iterRange = forOp.getRegionIterArgs(); iterArgs.assign(iterRange.begin(), iterRange.end()); @@ -512,7 +512,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, // A loop is constructed with an empty "yield" terminator if there are // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.create(loc, forOp.getResults()); + scf::YieldOp::create(rewriter, loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); @@ -544,7 +544,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, // has been already created in loop construction). if (!yieldOperands.empty()) { rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.create(loc, yieldOperands); + scf::YieldOp::create(rewriter, loc, yieldOperands); } rewriter.replaceOp(parallelOp, loopResults); @@ -570,7 +570,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Branch to the "before" region. 
rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, before, whileOp.getInits()); + cf::BranchOp::create(rewriter, loc, before, whileOp.getInits()); // Replace terminators with branches. Assuming bodies are SESE, which holds // given only the patterns from this file, we only need to look at the last @@ -620,7 +620,7 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(whileOp.getLoc(), before, whileOp.getInits()); + cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits()); // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(before); @@ -688,10 +688,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, SmallVector caseOperands(caseSuccessors.size(), {}); // Cast switch index to integer case value. - Value caseValue = rewriter.create( + Value caseValue = arith::IndexCastOp::create(rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg()); - rewriter.create( + cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock, ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands); rewriter.replaceOp(op, continueBlock->getArguments()); diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index dcb48529a74e6..37134fe9d992a 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -91,7 +91,7 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, Type varType = emitc::LValueType::get(resultType); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = - rewriter.create(loc, varType, noInit); + emitc::VariableOp::create(rewriter, loc, varType, noInit); resultVariables.push_back(var); } @@ -103,14 +103,14 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, static void 
assignValues(ValueRange values, ValueRange variables, ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) - rewriter.create(loc, var, value); + emitc::AssignOp::create(rewriter, loc, var, value); } SmallVector loadValues(const SmallVector &variables, PatternRewriter &rewriter, Location loc) { return llvm::map_to_vector<>(variables, [&](Value var) { Type type = cast(var.getType()).getValueType(); - return rewriter.create(loc, type, var).getResult(); + return emitc::LoadOp::create(rewriter, loc, type, var).getResult(); }); } @@ -129,7 +129,7 @@ static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, assignValues(yieldOperands, resultVariables, rewriter, loc); - rewriter.create(loc); + emitc::YieldOp::create(rewriter, loc); rewriter.eraseOp(yield); return success(); @@ -164,7 +164,7 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); - emitc::ForOp loweredFor = rewriter.create( + emitc::ForOp loweredFor = emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -257,7 +257,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create(loc, adaptor.getCondition(), false, false); + emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); auto result = lowerRegion(thenRegion, loweredThenRegion); @@ -304,7 +304,7 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite( "create variables for results failed"); } - auto loweredSwitch = rewriter.create( + auto loweredSwitch = emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); // Lowering all case regions. 
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 844e66e927c4d..bfdb5a688b286 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -84,7 +84,7 @@ static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { - return builder.create(forOp.getLoc(), + return arith::ConstantIndexOp::create(builder, forOp.getLoc(), forOp.getStepAsInt()); } @@ -190,12 +190,12 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) { return std::nullopt; } - Value range = builder.create(currentLoop.getLoc(), + Value range = arith::SubIOp::create(builder, currentLoop.getLoc(), upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); if (getConstantIntValue(step) != static_cast(1)) range = - builder.create(currentLoop.getLoc(), range, step); + arith::CeilDivSIOp::create(builder, currentLoop.getLoc(), range, step); dims.push_back(range); lbs.push_back(lowerBound); @@ -221,7 +221,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, // no loop mapped to a specific dimension, use constant "1" as its size. Value constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) + ? arith::ConstantIndexOp::create(builder, rootForOp.getLoc(), 1) : nullptr; Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne; Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; @@ -232,7 +232,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, // Create a launch op and move the body region of the innermost loop to the // launch op. 
- auto launchOp = builder.create( + auto launchOp = gpu::LaunchOp::create(builder, rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ); @@ -244,7 +244,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, Location terminatorLoc = terminator.getLoc(); terminator.erase(); builder.setInsertionPointToEnd(innermostForOp.getBody()); - builder.create(terminatorLoc, TypeRange()); + gpu::TerminatorOp::create(builder, terminatorLoc, TypeRange()); launchOp.getBody().front().getOperations().splice( launchOp.getBody().front().begin(), innermostForOp.getBody()->getOperations()); @@ -263,10 +263,10 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; if (getConstantIntValue(step) != static_cast(1)) - id = builder.create(rootForOp.getLoc(), step, id); + id = arith::MulIOp::create(builder, rootForOp.getLoc(), step, id); Value ivReplacement = - builder.create(rootForOp.getLoc(), *lbArgumentIt, id); + arith::AddIOp::create(builder, rootForOp.getLoc(), *lbArgumentIt, id); en.value().replaceAllUsesWith(ivReplacement); std::advance(lbArgumentIt, 1); std::advance(stepArgumentIt, 1); @@ -319,7 +319,7 @@ static Value deriveStaticUpperBound(Value upperBound, if (auto minOp = upperBound.getDefiningOp()) { for (const AffineExpr &result : minOp.getMap().getResults()) { if (auto constExpr = dyn_cast(result)) { - return rewriter.create(minOp.getLoc(), + return arith::ConstantIndexOp::create(rewriter, minOp.getLoc(), constExpr.getValue()); } } @@ -344,7 +344,7 @@ static Value deriveStaticUpperBound(Value upperBound, if ((lhs.value() < 0) != (rhs.value() < 0)) return {}; - return rewriter.create( + return arith::ConstantIndexOp::create(rewriter, multiplyOp.getLoc(), lhs.value() * rhs.value()); } } @@ -422,7 +422,7 @@ static LogicalResult processParallelLoop( if (launchIndependent(val)) return val; if (auto constOp = 
val.getDefiningOp()) - return rewriter.create(constOp.getLoc(), + return arith::ConstantOp::create(rewriter, constOp.getLoc(), constOp.getValue()); return {}; }; @@ -453,7 +453,7 @@ static LogicalResult processParallelLoop( 1, 2, rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1)); - newIndex = rewriter.create( + newIndex = AffineApplyOp::create(rewriter, loc, annotation.getMap().compose(lowerAndStep), ValueRange{operand, ensureLaunchIndependent(step), ensureLaunchIndependent(lowerBound)}); @@ -498,7 +498,7 @@ static LogicalResult processParallelLoop( 1, 2, ((rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0)) .ceilDiv(rewriter.getAffineSymbolExpr(1)))); - Value launchBound = rewriter.create( + Value launchBound = AffineApplyOp::create(rewriter, loc, annotation.getBound().compose(stepMap), ValueRange{ ensureLaunchIndependent( @@ -517,10 +517,10 @@ static LogicalResult processParallelLoop( if (!boundIsPrecise) { // We are using an approximation, create a surrounding conditional. Value originalBound = std::get<3>(config); - arith::CmpIOp pred = rewriter.create( + arith::CmpIOp pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, newIndex, cloningMap.lookupOrDefault(originalBound)); - scf::IfOp ifOp = rewriter.create(loc, pred, false); + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, pred, false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); // Put a sentinel into the worklist so we know when to pop out of the // if body again. We use the launchOp here, as that cannot be part of @@ -530,7 +530,7 @@ static LogicalResult processParallelLoop( } } else { // Create a sequential for loop. 
- auto loopOp = rewriter.create( + auto loopOp = scf::ForOp::create(rewriter, loc, cloningMap.lookupOrDefault(lowerBound), cloningMap.lookupOrDefault(upperBound), cloningMap.lookupOrDefault(step)); @@ -608,12 +608,12 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, // sizes. Those will be refined later as we discover them from mappings. Location loc = parallelOp.getLoc(); Value constantOne = - rewriter.create(parallelOp.getLoc(), 1); - gpu::LaunchOp launchOp = rewriter.create( + arith::ConstantIndexOp::create(rewriter, parallelOp.getLoc(), 1); + gpu::LaunchOp launchOp = gpu::LaunchOp::create(rewriter, parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne, constantOne, constantOne); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); - rewriter.create(loc); + gpu::TerminatorOp::create(rewriter, loc); rewriter.setInsertionPointToStart(&launchOp.getBody().front()); IRMapping cloningMap; @@ -667,7 +667,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, if (externalValues.size()) return failure(); // Replace by gpu.all_reduce. - auto gpuRedOp = rewriter.create(loc, newValue); + auto gpuRedOp = gpu::AllReduceOp::create(rewriter, loc, newValue); cloningMap.map(parentLoop->getResult(0), gpuRedOp.getResult()); // Copy region. 
rewriter.inlineRegionBefore(reduceOp.getRegion(0), gpuRedOp.getRegion(), diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 584ac2f11b670..7b4c37d44824d 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -187,7 +187,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) { OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperands()[reductionIndex].getType(); - auto decl = builder.create(reduce.getLoc(), + auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(), "__scf_reduction", type); symbolTable.insert(decl); @@ -196,8 +196,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, {reduce.getOperands()[reductionIndex].getLoc()}); builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); Value init = - builder.create(reduce.getLoc(), type, initValue); - builder.create(reduce.getLoc(), init); + LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue); + omp::YieldOp::create(builder, reduce.getLoc(), init); Operation *terminator = &reduce.getReductions()[reductionIndex].front().back(); @@ -227,12 +227,12 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, {reduceOperandLoc, reduceOperandLoc}); Block *atomicBlock = &decl.getAtomicReductionRegion().back(); builder.setInsertionPointToEnd(atomicBlock); - Value loaded = builder.create(reduce.getLoc(), decl.getType(), + Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(), atomicBlock->getArgument(1)); - builder.create(reduce.getLoc(), atomicKind, + LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind, atomicBlock->getArgument(0), loaded, LLVM::AtomicOrdering::monotonic); - builder.create(reduce.getLoc(), ArrayRef()); + omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef()); return decl; } @@ -380,7 
+380,7 @@ struct ParallelOpLowering : public OpRewritePattern { // Allocate reduction variables. Make sure the we don't overflow the stack // with local `alloca`s by saving and restoring the stack pointer. Location loc = parallelOp.getLoc(); - Value one = rewriter.create( + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); @@ -391,8 +391,8 @@ struct ParallelOpLowering : public OpRewritePattern { "cannot create a reduction variable if the type is not an LLVM " "pointer element"); Value storage = - rewriter.create(loc, ptrType, init.getType(), one, 0); - rewriter.create(loc, init, storage); + LLVM::AllocaOp::create(rewriter, loc, ptrType, init.getType(), one, 0); + LLVM::StoreOp::create(rewriter, loc, init, storage); reductionVariables.push_back(storage); } @@ -411,7 +411,7 @@ struct ParallelOpLowering : public OpRewritePattern { assert(redRegion.hasOneBlock() && "expect reduction region to have one block"); Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc); - Value pvtRedVal = rewriter.create(reduce.getLoc(), + Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(), rD.getType(), pvtRedVar); // Make a copy of the reduction combiner region in the body mlir::OpBuilder builder(rewriter.getContext()); @@ -427,7 +427,7 @@ struct ParallelOpLowering : public OpRewritePattern { assert(yieldOp && yieldOp.getResults().size() == 1 && "expect YieldOp in reduction region to return one result"); Value redVal = yieldOp.getResults()[0]; - rewriter.create(loc, redVal, pvtRedVar); + LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar); rewriter.eraseOp(yieldOp); break; } @@ -437,11 +437,11 @@ struct ParallelOpLowering : public OpRewritePattern { Value numThreadsVar; if (numThreads > 0) { - numThreadsVar = rewriter.create( + numThreadsVar = LLVM::ConstantOp::create(rewriter, loc, 
rewriter.getI32IntegerAttr(numThreads)); } // Create the parallel wrapper. - auto ompParallel = rewriter.create( + auto ompParallel = omp::ParallelOp::create(rewriter, loc, /* allocate_vars = */ llvm::SmallVector{}, /* allocator_vars = */ llvm::SmallVector{}, @@ -464,7 +464,7 @@ struct ParallelOpLowering : public OpRewritePattern { { OpBuilder::InsertionGuard allocaGuard(rewriter); // Create worksharing loop wrapper. - auto wsloopOp = rewriter.create(parallelOp.getLoc()); + auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc()); if (!reductionVariables.empty()) { wsloopOp.setReductionSymsAttr( ArrayAttr::get(rewriter.getContext(), reductionSyms)); @@ -476,7 +476,7 @@ struct ParallelOpLowering : public OpRewritePattern { wsloopOp.setReductionByref( DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef)); } - rewriter.create(loc); // omp.parallel terminator. + omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator. // The wrapper's entry block arguments will define the reduction // variables. @@ -490,7 +490,7 @@ struct ParallelOpLowering : public OpRewritePattern { parallelOp.getLoc())); // Create loop nest and populate region with contents of scf.parallel. 
- auto loopOp = rewriter.create( + auto loopOp = omp::LoopNestOp::create(rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep()); @@ -511,13 +511,13 @@ struct ParallelOpLowering : public OpRewritePattern { rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin()); rewriter.setInsertionPointToStart(&loopOpEntryBlock); - auto scope = rewriter.create(parallelOp.getLoc(), + auto scope = memref::AllocaScopeOp::create(rewriter, parallelOp.getLoc(), TypeRange()); - rewriter.create(loc, ValueRange()); + omp::YieldOp::create(rewriter, loc, ValueRange()); Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); rewriter.mergeBlocks(ops, scopeBlock); rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); - rewriter.create(loc, ValueRange()); + memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange()); } } @@ -526,7 +526,7 @@ struct ParallelOpLowering : public OpRewritePattern { results.reserve(reductionVariables.size()); for (auto [variable, type] : llvm::zip(reductionVariables, parallelOp.getResultTypes())) { - Value res = rewriter.create(loc, type, variable); + Value res = LLVM::LoadOp::create(rewriter, loc, type, variable); results.push_back(res); } rewriter.replaceOp(parallelOp, results); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 78d13278fef53..b494fa98f1362 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, auto pointerType = spirv::PointerType::get(convertedType, spirv::StorageClass::Function); rewriter.setInsertionPoint(newOp); - auto alloc = rewriter.create( + auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType, spirv::StorageClass::Function, /*initializer=*/nullptr); allocas.push_back(alloc); rewriter.setInsertionPointAfter(newOp); - Value loadResult = 
rewriter.create(loc, alloc); + Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc); resultValue.push_back(loadResult); } rewriter.replaceOp(scfOp, resultValue); @@ -135,7 +135,7 @@ struct ForOpConversion final : SCFToSPIRVPattern { // a single back edge from the continue to header block, and a single exit // from header to merge. auto loc = forOp.getLoc(); - auto loopOp = rewriter.create(loc, spirv::LoopControl::None); + auto loopOp = spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); OpBuilder::InsertionGuard guard(rewriter); @@ -172,15 +172,15 @@ struct ForOpConversion final : SCFToSPIRVPattern { args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); // Branch into it from the entry. rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); - rewriter.create(loc, header, args); + spirv::BranchOp::create(rewriter, loc, header, args); // Generate the rest of the loop header. rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = rewriter.create( + auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); - rewriter.create( + spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); // Generate instructions to increment the step of the induction variable and @@ -189,9 +189,9 @@ struct ForOpConversion final : SCFToSPIRVPattern { rewriter.setInsertionPointToEnd(continueBlock); // Add the step to the induction variable and branch to the header. - Value updatedIndVar = rewriter.create( + Value updatedIndVar = spirv::IAddOp::create(rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep()); - rewriter.create(loc, header, updatedIndVar); + spirv::BranchOp::create(rewriter, loc, header, updatedIndVar); // Infer the return types from the init operands. 
Vector type may get // converted to CooperativeMatrix or to Vector type, to avoid having complex @@ -238,10 +238,10 @@ struct IfOpConversion : SCFToSPIRVPattern { // Create `spirv.selection` operation, selection header block and merge // block. auto selectionOp = - rewriter.create(loc, spirv::SelectionControl::None); + spirv::SelectionOp::create(rewriter, loc, spirv::SelectionControl::None); auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end()); - rewriter.create(loc); + spirv::MergeOp::create(rewriter, loc); OpBuilder::InsertionGuard guard(rewriter); auto *selectionHeaderBlock = @@ -251,7 +251,7 @@ struct IfOpConversion : SCFToSPIRVPattern { auto &thenRegion = ifOp.getThenRegion(); auto *thenBlock = &thenRegion.front(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(thenRegion, mergeBlock); auto *elseBlock = mergeBlock; @@ -261,13 +261,13 @@ struct IfOpConversion : SCFToSPIRVPattern { auto &elseRegion = ifOp.getElseRegion(); elseBlock = &elseRegion.front(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(elseRegion, mergeBlock); } // Create a `spirv.BranchConditional` operation for selection header block. 
rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create(loc, adaptor.getCondition(), + spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(), thenBlock, ArrayRef(), elseBlock, ArrayRef()); @@ -310,7 +310,7 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern { auto loc = terminatorOp.getLoc(); for (unsigned i = 0, e = operands.size(); i < e; i++) - rewriter.create(loc, allocas[i], operands[i]); + spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]); if (isa(parent)) { // For loops we also need to update the branch jumping back to the // header. @@ -319,7 +319,7 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern { SmallVector args(br.getBlockArguments()); args.append(operands.begin(), operands.end()); rewriter.setInsertionPoint(br); - rewriter.create(terminatorOp.getLoc(), br.getTarget(), + spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(), args); rewriter.eraseOp(br); } @@ -340,7 +340,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern { matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = whileOp.getLoc(); - auto loopOp = rewriter.create(loc, spirv::LoopControl::None); + auto loopOp = spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); Region &beforeRegion = whileOp.getBefore(); @@ -382,7 +382,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern { // Jump from the loop entry block to the loop header block. rewriter.setInsertionPointToEnd(&entryBlock); - rewriter.create(loc, &beforeBlock, adaptor.getInits()); + spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits()); auto condLoc = cond.getLoc(); @@ -403,18 +403,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern { // Create local variables before the scf.while op. 
rewriter.setInsertionPoint(loopOp); - auto alloc = rewriter.create( + auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType, spirv::StorageClass::Function, /*initializer=*/nullptr); // Load the final result values after the scf.while op. rewriter.setInsertionPointAfter(loopOp); - auto loadResult = rewriter.create(condLoc, alloc); + auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc); resultValues[i] = loadResult; // Store the current iteration's result value. rewriter.setInsertionPointToEnd(&beforeBlock); - rewriter.create(condLoc, alloc, res); + spirv::StoreOp::create(rewriter, condLoc, alloc, res); } rewriter.setInsertionPointToEnd(&beforeBlock); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index d7ae9f0e94fe8..934bb0d1c37fd 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -68,7 +68,7 @@ static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { /// Copies the given number of bytes from src to dst pointers. static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder) { - builder.create(loc, dst, src, size, /*isVolatile=*/false); + LLVM::MemcpyOp::create(builder, loc, dst, src, size, /*isVolatile=*/false); } /// Encodes the binding and descriptor set numbers into a new symbolic name. 
@@ -194,7 +194,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { if (!kernelFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - kernelFunc = rewriter.create( + kernelFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(), newKernelFuncName, LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), ArrayRef())); @@ -245,7 +245,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { if (!dstGlobal) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - dstGlobal = rewriter.create( + dstGlobal = LLVM::GlobalOp::create(rewriter, loc, dstGlobalType, /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), /*alignment=*/0); @@ -255,7 +255,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { // Copy the data from src operand pointer to dst global variable. Save // src, dst and size so that we can copy data back after emulating the // kernel call. 
- Value dst = rewriter.create( + Value dst = LLVM::AddressOfOp::create(rewriter, loc, typeConverter->convertType(spirvGlobal.getType()), dstGlobal.getSymName()); copy(loc, dst, src, sizeBytes, rewriter); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index cea9d1fdec809..4e8468078b9ad 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -94,12 +94,12 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { if (isa(srcType)) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, dstType, SplatElementsAttr::get(cast(srcType), minusOneIntegerAttribute(srcType, rewriter))); } - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); } @@ -108,13 +108,13 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value) { if (auto vecType = dyn_cast(srcType)) { auto floatType = cast(vecType.getElementType()); - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, dstType, SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value))); } auto floatType = cast(srcType); - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, dstType, rewriter.getFloatAttr(floatType, value)); } @@ -134,13 +134,13 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) - return rewriter.create(loc, llvmType, value); + return LLVM::ZExtOp::create(rewriter, loc, llvmType, value); // If the bit widths of `Count` and `Offset` are greater than the bit width // of the target type, they are truncated. 
Truncation is safe since `Count` // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, // both values can be expressed in 8 bits. if (valueBitWidth > targetBitWidth) - return rewriter.create(loc, llvmType, value); + return LLVM::TruncOp::create(rewriter, loc, llvmType, value); return value; } @@ -151,11 +151,11 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); - Value broadcasted = rewriter.create(loc, llvmVectorType); + Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType); for (unsigned i = 0; i < numElements; ++i) { - auto index = rewriter.create( + auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); - broadcasted = rewriter.create( + broadcasted = LLVM::InsertElementOp::create(rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index); } return broadcasted; @@ -217,7 +217,7 @@ static Type convertStructTypePacked(spirv::StructType type, /// Creates LLVM dialect constant with the given value. 
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } @@ -322,7 +322,7 @@ class AccessChainPattern : public SPIRVToLLVMConversion { auto llvmIndexType = getTypeConverter()->convertType(indexType); if (!llvmIndexType) return rewriter.notifyMatchFailure(op, "type conversion failed"); - Value zero = rewriter.create( + Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); @@ -375,20 +375,20 @@ class BitFieldInsertPattern // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create(loc, dstType, minusOne, count); - Value negated = rewriter.create(loc, dstType, + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value negated = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = - rewriter.create(loc, dstType, negated, offset); - Value mask = rewriter.create( + LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 
Value baseAndMask = - rewriter.create(loc, dstType, op.getBase(), mask); + LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask); Value insertShiftedByOffset = - rewriter.create(loc, dstType, op.getInsert(), offset); + LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset); rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, insertShiftedByOffset); return success(); @@ -470,23 +470,23 @@ class BitFieldSExtractPattern auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = isa(srcType) - ? rewriter.create( + ? LLVM::ConstantOp::create(rewriter, loc, dstType, SplatElementsAttr::get(cast(srcType), baseSize)) - : rewriter.create(loc, dstType, baseSize); + : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit // at Offset + Count - 1 is the most significant bit now. Value countPlusOffset = - rewriter.create(loc, dstType, count, offset); + LLVM::AddOp::create(rewriter, loc, dstType, count, offset); Value amountToShiftLeft = - rewriter.create(loc, dstType, size, countPlusOffset); - Value baseShiftedLeft = rewriter.create( + LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset); + Value baseShiftedLeft = LLVM::ShlOp::create(rewriter, loc, dstType, op.getBase(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = - rewriter.create(loc, dstType, offset, amountToShiftLeft); + LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft); rewriter.replaceOpWithNewOp(op, dstType, baseShiftedLeft, amountToShiftRight); return success(); @@ -516,13 +516,13 @@ class BitFieldUExtractPattern // Create a mask with bits set at [0, Count - 1]. 
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create(loc, dstType, minusOne, count); - Value mask = rewriter.create(loc, dstType, maskShiftedByCount, + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount, minusOne); // Shift `Base` by `Offset` and apply the mask on it. Value shiftedBase = - rewriter.create(loc, dstType, op.getBase(), offset); + LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset); rewriter.replaceOpWithNewOp(op, dstType, shiftedBase, mask); return success(); } @@ -694,7 +694,7 @@ class ExecutionModePattern auto structType = LLVM::LLVMStructType::getLiteral(context, fields); // Create `llvm.mlir.global` with initializer region containing one block. - auto global = rewriter.create( + auto global = LLVM::GlobalOp::create(rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true, LLVM::Linkage::External, executionModeInfoName, Attribute(), /*alignment=*/0); @@ -704,8 +704,8 @@ class ExecutionModePattern // Initialize the struct and set the execution mode value. rewriter.setInsertionPointToStart(block); - Value structValue = rewriter.create(loc, structType); - Value executionMode = rewriter.create( + Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType); + Value executionMode = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, rewriter.getI32IntegerAttr( static_cast(executionModeAttr.getValue()))); @@ -715,11 +715,11 @@ class ExecutionModePattern // Insert extra operands if they exist into execution mode info struct. 
for (unsigned i = 0, e = values.size(); i < e; ++i) { auto attr = values.getValue()[i]; - Value entry = rewriter.create(loc, llvmI32Type, attr); - structValue = rewriter.create( + Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr); + structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue, entry, ArrayRef({1, i})); } - rewriter.create(loc, ArrayRef({structValue})); + LLVM::ReturnOp::create(rewriter, loc, ArrayRef({structValue})); rewriter.eraseOp(op); return success(); } @@ -911,7 +911,7 @@ class InverseSqrtPattern Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value sqrt = rewriter.create(loc, dstType, op.getOperand()); + Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand()); rewriter.replaceOpWithNewOp(op, dstType, one, sqrt); return success(); } @@ -971,10 +971,10 @@ class NotPattern : public SPIRVToLLVMConversion { IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = isa(srcType) - ? rewriter.create( + ? 
LLVM::ConstantOp::create(rewriter, loc, dstType, SplatElementsAttr::get(cast(srcType), minusOne)) - : rewriter.create(loc, dstType, minusOne); + : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.getOperand(), mask); return success(); @@ -1032,7 +1032,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, return func; OpBuilder b(symbolTable->getRegion(0)); - func = b.create( + func = LLVM::LLVMFuncOp::create(b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); @@ -1045,7 +1045,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = builder.create(loc, func, args); + auto call = LLVM::CallOp::create(builder, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -1076,11 +1076,11 @@ class ControlBarrierPattern : public SPIRVToLLVMConversion { lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); Location loc = controlBarrierOp->getLoc(); - Value execution = rewriter.create( + Value execution = LLVM::ConstantOp::create(rewriter, loc, i32, static_cast(adaptor.getExecutionScope())); - Value memory = rewriter.create( + Value memory = LLVM::ConstantOp::create(rewriter, loc, i32, static_cast(adaptor.getMemoryScope())); - Value semantics = rewriter.create( + Value semantics = LLVM::ConstantOp::create(rewriter, loc, i32, static_cast(adaptor.getMemorySemantics())); auto call = createSPIRVBuiltinCall(loc, rewriter, func, @@ -1253,9 +1253,9 @@ class GroupReducePattern : public SPIRVToLLVMConversion { lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); Location loc = op.getLoc(); - Value scope = rewriter.create( + Value scope = 
LLVM::ConstantOp::create(rewriter, loc, i32Ty, static_cast(adaptor.getExecutionScope())); - Value groupOp = rewriter.create( + Value groupOp = LLVM::ConstantOp::create(rewriter, loc, i32Ty, static_cast(adaptor.getGroupOperation())); SmallVector operands{scope, groupOp}; operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); @@ -1366,7 +1366,7 @@ class LoopPattern : public SPIRVToLLVMConversion { return failure(); Block *headerBlock = loopOp.getHeaderBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, brOp.getBlockArguments(), headerBlock); + LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock); rewriter.eraseBlock(entryBlock); // Branch from merge block to end block. @@ -1374,7 +1374,7 @@ class LoopPattern : public SPIRVToLLVMConversion { Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create(loc, terminatorOperands, endBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock); rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); rewriter.replaceOp(loopOp, endBlock->getArguments()); @@ -1433,13 +1433,13 @@ class SelectionPattern : public SPIRVToLLVMConversion { Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create(loc, terminatorOperands, continueBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock); // Link current block to `true` and `false` blocks within the selection. 
Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, condBrOp.getCondition(), trueBlock, + LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock, condBrOp.getTrueTargetOperands(), falseBlock, condBrOp.getFalseTargetOperands()); @@ -1519,8 +1519,8 @@ class TanPattern : public SPIRVToLLVMConversion { return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); Location loc = tanOp.getLoc(); - Value sin = rewriter.create(loc, dstType, tanOp.getOperand()); - Value cos = rewriter.create(loc, dstType, tanOp.getOperand()); + Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); + Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); rewriter.replaceOpWithNewOp(tanOp, dstType, sin, cos); return success(); } @@ -1547,13 +1547,13 @@ class TanhPattern : public SPIRVToLLVMConversion { Location loc = tanhOp.getLoc(); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); Value multiplied = - rewriter.create(loc, dstType, two, tanhOp.getOperand()); - Value exponential = rewriter.create(loc, dstType, multiplied); + LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); + Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value numerator = - rewriter.create(loc, dstType, exponential, one); + LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); Value denominator = - rewriter.create(loc, dstType, exponential, one); + LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); rewriter.replaceOpWithNewOp(tanhOp, dstType, numerator, denominator); return success(); @@ -1592,8 +1592,8 @@ class VariablePattern : public SPIRVToLLVMConversion { if (!elementType) return rewriter.notifyMatchFailure(varOp, "type conversion failed"); Value allocated = - rewriter.create(loc, 
dstType, elementType, size); - rewriter.create(loc, adaptor.getInitializer(), allocated); + LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); } @@ -1654,7 +1654,7 @@ class FuncConversionPattern : public SPIRVToLLVMConversion { // Create a new `LLVMFuncOp` Location loc = funcOp.getLoc(); StringRef name = funcOp.getName(); - auto newFuncOp = rewriter.create(loc, name, llvmType); + auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType); // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); @@ -1708,7 +1708,7 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion { ConversionPatternRewriter &rewriter) const override { auto newModuleOp = - rewriter.create(spvModuleOp.getLoc(), spvModuleOp.getName()); + ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName()); rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder @@ -1749,7 +1749,7 @@ class VectorShufflePattern auto componentsArray = components.getValue(); auto *context = rewriter.getContext(); auto llvmI32Type = IntegerType::get(context, 32); - Value targetOp = rewriter.create(loc, dstType); + Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { if (!isa(componentsArray[i])) return op.emitError("unable to support non-constant component"); @@ -1765,15 +1765,15 @@ class VectorShufflePattern baseVector = vector2; } - Value dstIndex = rewriter.create( + Value dstIndex = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); - Value index = rewriter.create( + Value index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, 
rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); - auto extractOp = rewriter.create( + auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType, baseVector, index); - targetOp = rewriter.create(loc, dstType, targetOp, + targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp, extractOp, dstIndex); } rewriter.replaceOp(op, targetOp); diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp index da9ad3dd67328..245e60b04ec31 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -32,7 +32,7 @@ class ConvertCstrRequireOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrRequireOp op, PatternRewriter &rewriter) const override { - rewriter.create(op.getLoc(), op.getPred(), op.getMsgAttr()); + cf::AssertOp::create(rewriter, op.getLoc(), op.getPred(), op.getMsgAttr()); rewriter.replaceOpWithNewOp(op, true); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index bbe1490137bf8..58cffa6119d93 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -82,19 +82,19 @@ struct BroadcastOpConverter : public OpConversionPattern { // number of extent tensors and shifted offsets into them. 
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, ValueRange rankDiffs, Value outputDimension) { - Value one = lb.create(1); + Value one = arith::ConstantIndexOp::create(lb, 1); Value broadcastedDim = one; for (auto tup : llvm::zip(extentTensors, rankDiffs)) { Value shape = std::get<0>(tup); Value rankDiff = std::get<1>(tup); - Value outOfBounds = lb.create(arith::CmpIPredicate::ult, + Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult, outputDimension, rankDiff); Type indexTy = lb.getIndexType(); broadcastedDim = - lb.create( + IfOp::create(lb, outOfBounds, [&](OpBuilder &b, Location loc) { - b.create(loc, broadcastedDim); + scf::YieldOp::create(b, loc, broadcastedDim); }, [&](OpBuilder &b, Location loc) { // The broadcasting logic is: @@ -104,17 +104,17 @@ Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, // - otherwise, take the extent as-is. // Note that this logic remains correct in the presence // of dimensions of zero extent. - Value lesserRankOperandDimension = b.create( + Value lesserRankOperandDimension = arith::SubIOp::create(b, loc, indexTy, outputDimension, rankDiff); - Value lesserRankOperandExtent = b.create( + Value lesserRankOperandExtent = tensor::ExtractOp::create(b, loc, shape, ValueRange{lesserRankOperandDimension}); Value dimIsOne = - b.create(loc, arith::CmpIPredicate::eq, + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, lesserRankOperandExtent, one); - Value dim = b.create( + Value dim = arith::SelectOp::create(b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); - b.create(loc, dim); + scf::YieldOp::create(b, loc, dim); }) .getResult(0); } @@ -133,7 +133,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); + Value zero = arith::ConstantIndexOp::create(lb, 0); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. 
Because this is a tensor @@ -141,31 +141,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( // dimension in the tensor. SmallVector ranks, rankDiffs; llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { - return lb.create(v, zero); + return tensor::DimOp::create(lb, v, zero); })); // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - maxRank = lb.create(v, maxRank); + maxRank = arith::MaxUIOp::create(lb, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return arith::SubIOp::create(lb, indexTy, maxRank, v); })); - Value replacement = lb.create( + Value replacement = tensor::GenerateOp::create(lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value broadcastedDim = getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, args[0]); - b.create(loc, broadcastedDim); + tensor::YieldOp::create(b, loc, broadcastedDim); }); if (replacement.getType() != op.getType()) - replacement = lb.create(op.getType(), replacement); + replacement = tensor::CastOp::create(lb, op.getType(), replacement); rewriter.replaceOp(op, replacement); return success(); } @@ -194,12 +194,12 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite( SmallVector extentOperands; for (auto extent : op.getShape()) { extentOperands.push_back( - rewriter.create(loc, extent.getLimitedValue())); + arith::ConstantIndexOp::create(rewriter, loc, extent.getLimitedValue())); } Type resultTy = RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType()); Value tensor = - rewriter.create(loc, resultTy, extentOperands); + tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands); rewriter.replaceOpWithNewOp(op, resultTy, tensor); return success(); } @@ -245,8 +245,8 @@ 
LogicalResult IsBroadcastableOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); - Value one = lb.create(1); + Value zero = arith::ConstantIndexOp::create(lb, 0); + Value one = arith::ConstantIndexOp::create(lb, 1); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -254,25 +254,25 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( // dimension in the tensor. SmallVector ranks, rankDiffs; llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { - return lb.create(v, zero); + return tensor::DimOp::create(lb, v, zero); })); // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - maxRank = lb.create(v, maxRank); + maxRank = arith::MaxUIOp::create(lb, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return arith::SubIOp::create(lb, indexTy, maxRank, v); })); Type i1Ty = rewriter.getI1Type(); Value trueVal = - rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); + arith::ConstantOp::create(rewriter, loc, i1Ty, rewriter.getBoolAttr(true)); - auto reduceResult = lb.create( + auto reduceResult = ForOp::create(lb, loc, zero, maxRank, one, ValueRange{trueVal}, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { // Find a non-1 dim, if it exists. 
Note that the first part of this @@ -285,38 +285,38 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; - Value outOfBounds = b.create( + Value outOfBounds = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = - b.create( + IfOp::create(b, loc, outOfBounds, [&](OpBuilder &b, Location loc) { // Non existent dimensions are always broadcastable - b.create(loc, broadcastable); + scf::YieldOp::create(b, loc, broadcastable); }, [&](OpBuilder &b, Location loc) { // Every value needs to be either 1, or the same non-1 // value to be broadcastable in this dim. Value operandDimension = - b.create(loc, indexTy, iv, rankDiff); - Value dimensionExtent = b.create( + arith::SubIOp::create(b, loc, indexTy, iv, rankDiff); + Value dimensionExtent = tensor::ExtractOp::create(b, loc, shape, ValueRange{operandDimension}); - Value equalOne = b.create( + Value equalOne = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, dimensionExtent, one); - Value equalBroadcasted = b.create( + Value equalBroadcasted = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, dimensionExtent, broadcastedDim); - Value result = b.create( + Value result = arith::AndIOp::create(b, loc, broadcastable, - b.create(loc, equalOne, + arith::OrIOp::create(b, loc, equalOne, equalBroadcasted)); - b.create(loc, result); + scf::YieldOp::create(b, loc, result); }) .getResult(0); } - b.create(loc, broadcastable); + scf::YieldOp::create(b, loc, broadcastable); }); rewriter.replaceOp(op, reduceResult.getResults().front()); @@ -339,7 +339,7 @@ DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor, // Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further // lowerings. This can be further optimized if needed to avoid intermediate // steps. 
- auto shapeOf = rewriter.create(op.getLoc(), op.getValue()); + auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue()); rewriter.replaceOpWithNewOp(op, op.getType(), shapeOf, op.getIndex()); return success(); @@ -421,16 +421,16 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, auto loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = - rewriter.create(loc, indexTy, adaptor.getShape(), zero); + tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero); - auto loop = rewriter.create( + auto loop = scf::ForOp::create(rewriter, loc, zero, rank, one, op.getInitVals(), [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - Value extent = b.create(loc, adaptor.getShape(), iv); + Value extent = tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv); SmallVector mappedValues{iv, extent}; mappedValues.append(args.begin(), args.end()); @@ -444,7 +444,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, SmallVector mappedResults; for (auto result : reduceBody->getTerminator()->getOperands()) mappedResults.push_back(mapping.lookup(result)); - b.create(loc, mappedResults); + scf::YieldOp::create(b, loc, mappedResults); }); rewriter.replaceOp(op, loop.getResults()); @@ -507,43 +507,43 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value firstShape = adaptor.getShapes().front(); Value firstRank = - rewriter.create(loc, indexTy, firstShape, zero); + tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero); Value result = nullptr; // Generate a linear sequence of 
compares, all with firstShape as lhs. for (Value shape : adaptor.getShapes().drop_front(1)) { - Value rank = rewriter.create(loc, indexTy, shape, zero); - Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, + Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero); + Value eqRank = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank); - auto same = rewriter.create( + auto same = IfOp::create(rewriter, loc, eqRank, [&](OpBuilder &b, Location loc) { - Value one = b.create(loc, 1); + Value one = arith::ConstantIndexOp::create(b, loc, 1); Value init = - b.create(loc, i1Ty, b.getBoolAttr(true)); - auto loop = b.create( + arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true)); + auto loop = scf::ForOp::create(b, loc, zero, firstRank, one, ValueRange{init}, [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { Value conj = args[0]; Value lhsExtent = - b.create(loc, firstShape, iv); - Value rhsExtent = b.create(loc, shape, iv); - Value eqExtent = b.create( + tensor::ExtractOp::create(b, loc, firstShape, iv); + Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv); + Value eqExtent = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); - Value conjNext = b.create(loc, conj, eqExtent); - b.create(loc, ValueRange({conjNext})); + Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent); + scf::YieldOp::create(b, loc, ValueRange({conjNext})); }); - b.create(loc, loop.getResults()); + scf::YieldOp::create(b, loc, loop.getResults()); }, [&](OpBuilder &b, Location loc) { Value result = - b.create(loc, i1Ty, b.getBoolAttr(false)); - b.create(loc, result); + arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false)); + scf::YieldOp::create(b, loc, result); }); result = !result ? 
same.getResult(0) - : rewriter.create(loc, result, + : arith::AndIOp::create(rewriter, loc, result, same.getResult(0)); } rewriter.replaceOp(op, result); @@ -581,17 +581,17 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( int64_t rank = rankedTensorTy.getRank(); for (int64_t i = 0; i < rank; i++) { if (rankedTensorTy.isDynamicDim(i)) { - Value extent = rewriter.create(loc, tensor, i); + Value extent = tensor::DimOp::create(rewriter, loc, tensor, i); extentValues.push_back(extent); } else { - Value extent = rewriter.create( + Value extent = arith::ConstantIndexOp::create(rewriter, loc, rankedTensorTy.getDimSize(i)); extentValues.push_back(extent); } } // Materialize extent tensor. - Value staticExtentTensor = rewriter.create( + Value staticExtentTensor = tensor::FromElementsOp::create(rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()), extentValues); rewriter.replaceOpWithNewOp(op, op.getType(), @@ -601,13 +601,13 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); - Value rank = rewriter.create(loc, tensor); + Value rank = tensor::RankOp::create(rewriter, loc, tensor); rewriter.replaceOpWithNewOp( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value dim = args.front(); - Value extent = b.create(loc, tensor, dim); - b.create(loc, extent); + Value extent = tensor::DimOp::create(b, loc, tensor, dim); + tensor::YieldOp::create(b, loc, extent); }); return success(); @@ -634,21 +634,21 @@ LogicalResult SplitAtOpConversion::matchAndRewrite( return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create(0); - Value rank = b.create(adaptor.getOperand(), zero); + Value zero = arith::ConstantIndexOp::create(b, 0); + Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero); // index < 0 ? 
index + rank : index Value originalIndex = adaptor.getIndex(); - Value add = b.create(originalIndex, rank); + Value add = arith::AddIOp::create(b, originalIndex, rank); Value indexIsNegative = - b.create(arith::CmpIPredicate::slt, originalIndex, zero); - Value index = b.create(indexIsNegative, add, originalIndex); + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero); + Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex); - Value one = b.create(1); + Value one = arith::ConstantIndexOp::create(b, 1); Value head = - b.create(adaptor.getOperand(), zero, index, one); - Value tailSize = b.create(rank, index); - Value tail = b.create(adaptor.getOperand(), index, + tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one); + Value tailSize = arith::SubIOp::create(b, rank, index); + Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index, tailSize, one); rewriter.replaceOp(op, {head, tail}); return success(); diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp index 2c4d27502a521..f06c3e1f7cb18 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -68,10 +68,10 @@ class TensorExtractPattern final // We could use the initializer directly; but certain driver compilers // have bugs dealing with that. So for now, use spirv.Store for // initialization. - varOp = rewriter.create(loc, varType, + varOp = spirv::VariableOp::create(rewriter, loc, varType, spirv::StorageClass::Function, /*initializer=*/nullptr); - rewriter.create(loc, varOp, adaptor.getTensor()); + spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor()); } else { // Need to store the value to the local variable. It's questionable // whether we want to support such case though. 
@@ -83,7 +83,7 @@ class TensorExtractPattern final Value index = spirv::linearizeIndex(adaptor.getIndices(), strides, /*offset=*/0, indexType, loc, rewriter); - auto acOp = rewriter.create(loc, varOp, index); + auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 40ad63610e23f..44872c8d7540d 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -51,7 +51,7 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { Value getConstantValue(Location loc, Type type, int64_t value, PatternRewriter &rewriter) { - return rewriter.create( + return arith::ConstantOp::create(rewriter, loc, getConstantAttr(type, value, rewriter)); } @@ -82,41 +82,41 @@ class ApplyScaleGenericOpConverter Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); - Value shift32 = rewriter.create(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Compute the multiplication in 64-bits then select the high / low parts. Value value64 = value; if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) - value64 = rewriter.create(loc, i64Ty, value); + value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value); Value multiplier64 = - rewriter.create(loc, i64Ty, multiplier32); + arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32); Value multiply64 = - rewriter.create(loc, value64, multiplier64); + arith::MulIOp::create(rewriter, loc, value64, multiplier64); // Apply normal rounding. 
- Value shift64 = rewriter.create(loc, i64Ty, shift32); - Value round = rewriter.create(loc, one64, shift64); - round = rewriter.create(loc, round, one64); - multiply64 = rewriter.create(loc, multiply64, round); + Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32); + Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64); + round = arith::ShRUIOp::create(rewriter, loc, round, one64); + multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round); // Apply double rounding if necessary. if (op.getRoundingMode() == "DOUBLE_ROUND") { int64_t roundInt = 1 << 30; Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); - Value positive = rewriter.create( + Value positive = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, value, zero); Value dir = - rewriter.create(loc, positive, roundUp, roundDown); - Value val = rewriter.create(loc, dir, multiply64); - Value valid = rewriter.create( + arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown); + Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64); + Value valid = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); multiply64 = - rewriter.create(loc, valid, val, multiply64); + arith::SelectOp::create(rewriter, loc, valid, val, multiply64); } - Value result64 = rewriter.create(loc, multiply64, shift64); - Value result32 = rewriter.create(loc, i32Ty, result64); + Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64); + Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64); rewriter.replaceOp(op, result32); return success(); @@ -146,7 +146,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { Value value32 = op.getValue(); Value multiplier32 = op.getMultiplier(); - Value shift32 = rewriter.create(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, 
i32Ty, op.getShift()); // Constants used during the scaling operation. Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); @@ -158,86 +158,86 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { // Compute the multiplication in 64-bits then select the high / low parts. // Grab out the high/low of the computation auto value64 = - rewriter.create(loc, value32, multiplier32); + arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32); Value low32 = value64.getLow(); Value high32 = value64.getHigh(); // Determine the direction and amount to shift the high bits. - Value shiftOver32 = rewriter.create( + Value shiftOver32 = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); - Value roundHighBits = rewriter.create( + Value roundHighBits = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); Value shiftHighL = - rewriter.create(loc, thirtyTwo32, shift32); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32); Value shiftHighR = - rewriter.create(loc, shift32, thirtyTwo32); + arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32); shiftHighL = - rewriter.create(loc, shiftOver32, zero32, shiftHighL); + arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL); shiftHighR = - rewriter.create(loc, shiftOver32, shiftHighR, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32); // Conditionally perform our double round. 
if (op.getRoundingMode() == "DOUBLE_ROUND") { Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); - Value valuePositive = rewriter.create( + Value valuePositive = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, value32, zero32); Value roundDir = - rewriter.create(loc, valuePositive, one32, negOne32); + arith::SelectOp::create(rewriter, loc, valuePositive, one32, negOne32); roundDir = - rewriter.create(loc, shiftOver32, roundDir, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32); - Value shiftLow = rewriter.create(loc, low32, thirty32); - Value rounded = rewriter.create(loc, shiftLow, roundDir); - Value carry = rewriter.create(loc, rounded, two32); + Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32); + Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir); + Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32); Value shiftRound = - rewriter.create(loc, roundDir, thirty32); + arith::ShLIOp::create(rewriter, loc, roundDir, thirty32); - low32 = rewriter.create(loc, low32, shiftRound); - high32 = rewriter.create(loc, high32, carry); + low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound); + high32 = arith::AddIOp::create(rewriter, loc, high32, carry); } // Conditionally apply rounding in the low bits. 
{ - Value shiftSubOne = rewriter.create(loc, shift32, one32); - Value roundBit = rewriter.create(loc, one32, shiftSubOne); - roundBit = rewriter.create(loc, roundHighBits, zero32, + Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32, roundBit); - Value newLow32 = rewriter.create(loc, low32, roundBit); - Value wasRounded = rewriter.create( + Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit); + Value wasRounded = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32); low32 = newLow32; - Value rounded32 = rewriter.create(loc, i32Ty, wasRounded); - high32 = rewriter.create(loc, high32, rounded32); + Value rounded32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded); + high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32); } // Conditionally apply rounding in the high bits. { Value shiftSubOne = - rewriter.create(loc, shiftHighR, one32); - Value roundBit = rewriter.create(loc, one32, shiftSubOne); - roundBit = rewriter.create(loc, roundHighBits, roundBit, + arith::SubIOp::create(rewriter, loc, shiftHighR, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit, zero32); - high32 = rewriter.create(loc, high32, roundBit); + high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit); } // Combine the correct high/low bits into the final rescale result. 
- high32 = rewriter.create(loc, high32, shiftHighL); - high32 = rewriter.create(loc, high32, shiftHighR); - low32 = rewriter.create(loc, low32, shift32); - low32 = rewriter.create(loc, shiftOver32, zero32, low32); + high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL); + high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR); + low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32); + low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32); // Apply the rounding behavior and shift to the final alignment. - Value result = rewriter.create(loc, low32, high32); + Value result = arith::AddIOp::create(rewriter, loc, low32, high32); // Truncate if necessary. if (!getElementTypeOrSelf(resultTy).isInteger(32)) { - result = rewriter.create(loc, resultTy, result); + result = arith::TruncIOp::create(rewriter, loc, resultTy, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 2f608bbd637b4..2aec35ed9d8ca 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -70,13 +70,13 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, return result; // Unordered comparison of NaN against itself will always return true. 
- Value lhsIsNaN = rewriter.create( + Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs); - Value rhsIsNaN = rewriter.create( + Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs); Value rhsOrResult = - rewriter.create(op.getLoc(), lhsIsNaN, rhs, result); - return rewriter.create(op.getLoc(), rhsIsNaN, lhs, + arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result); + return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs, rhsOrResult); } @@ -89,38 +89,38 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::AbsOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return math::AbsFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) { - auto zero = rewriter.create( + auto zero = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(elementTy)); - auto neg = rewriter.create(loc, zero, args[0]); - return rewriter.create(loc, args[0], neg); + auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]); + return arith::MaxSIOp::create(rewriter, loc, args[0], neg); } // tosa::AddOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::AddFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::AddIOp::create(rewriter, loc, resultTypes, args); // tosa::SubOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::SubFOp::create(rewriter, loc, resultTypes, args); if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::SubIOp::create(rewriter, loc, resultTypes, args); // tosa::IntDivOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::DivSIOp::create(rewriter, loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) 
&& isa(elementTy)) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - return rewriter.create(loc, resultTypes, one, args[0]); + arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); + return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]); } // tosa::MulOp @@ -140,7 +140,7 @@ static Value createLinalgBodyCalculationForElementwiseOp( "Cannot have shift value for float"); return nullptr; } - return rewriter.create(loc, resultTypes, args[0], args[1]); + return arith::MulFOp::create(rewriter, loc, resultTypes, args[0], args[1]); } if (isa(elementTy)) { @@ -149,21 +149,21 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (shift > 0) { auto shiftConst = - rewriter.create(loc, shift, /*bitwidth=*/8); + arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) - a = rewriter.create(loc, rewriter.getI32Type(), a); + a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) - b = rewriter.create(loc, rewriter.getI32Type(), b); + b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b); - auto result = rewriter.create( + auto result = tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a, b, shiftConst, rewriter.getStringAttr("SINGLE_ROUND")); if (elementTy.isInteger(32)) return result; - return rewriter.create(loc, elementTy, result); + return arith::TruncIOp::create(rewriter, loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -171,11 +171,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) - a = rewriter.create(loc, resultTypes[0], a); + a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a); if (bWidth < cWidth) - b = rewriter.create(loc, resultTypes[0], b); + b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b); - return rewriter.create(loc, resultTypes, a, b); + return 
arith::MulIOp::create(rewriter, loc, resultTypes, a, b); } } @@ -201,13 +201,13 @@ static Value createLinalgBodyCalculationForElementwiseOp( int64_t outZp = *maybeOutZp; if (isa(elementTy)) - return rewriter.create(loc, resultTypes, args[0]); + return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa(elementTy)) { if (!inZp && !outZp) { - auto constant = rewriter.create( + auto constant = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(elementTy, 0)); - return rewriter.create(loc, resultTypes, constant, + return arith::SubIOp::create(rewriter, loc, resultTypes, constant, args[0]); } @@ -231,60 +231,60 @@ static Value createLinalgBodyCalculationForElementwiseOp( } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = rewriter.create( + Value zpAddValue = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue auto ext = - rewriter.create(loc, intermediateType, args[0]); - auto sub = rewriter.create(loc, zpAddValue, ext); + arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]); + auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext); // Clamp to the negation range. - Value min = rewriter.create( + Value min = arith::ConstantIntOp::create(rewriter, loc, intermediateType, APInt::getSignedMinValue(inputBitWidth).getSExtValue()); - Value max = rewriter.create( + Value max = arith::ConstantIntOp::create(rewriter, loc, intermediateType, APInt::getSignedMaxValue(inputBitWidth).getSExtValue()); auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false); // Truncate to the final value. 
- return rewriter.create(loc, elementTy, clamp); + return arith::TruncIOp::create(rewriter, loc, elementTy, clamp); } } // tosa::BitwiseAndOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && isa(elementTy)) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); - auto allOnes = rewriter.create(loc, allOnesAttr); - return rewriter.create(loc, resultTypes, args[0], allOnes); + auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::ShLIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return arith::ShRUIOp::create(rewriter, loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && isa(elementTy)) { - auto result = rewriter.create(loc, resultTypes, args); + auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args); auto round = cast(op->getAttr("round")).getValue(); if (!round) { return result; @@ -292,153 +292,153 @@ static Value createLinalgBodyCalculationForElementwiseOp( Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); auto one = - rewriter.create(loc, IntegerAttr::get(elementTy, 1)); + arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(elementTy, 1)); auto zero = - 
rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(elementTy, 0)); auto i1one = - rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); + arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 - auto shiftValueGreaterThanZero = rewriter.create( + auto shiftValueGreaterThanZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = - rewriter.create(loc, resultTypes, args[1], one); + arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one); auto shifted = - rewriter.create(loc, resultTypes, args[0], subtract) + arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract) ->getResults(); - auto truncated = rewriter.create( + auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted, ArrayRef()); auto isInputOdd = - rewriter.create(loc, i1Ty, truncated, i1one); + arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one); - auto shouldRound = rewriter.create( + auto shouldRound = arith::AndIOp::create(rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = - rewriter.create(loc, resultTypes, shouldRound); - return rewriter.create(loc, resultTypes, result, extended); + arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound); + return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, elementTy, args[0]); + return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]); } // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { - auto one = rewriter.create( + auto one = arith::ConstantOp::create(rewriter, loc, 
rewriter.getIntegerAttr(elementTy, 1)); - return rewriter.create(loc, resultTypes, args[0], one); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::PowOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args); // tosa::RsqrtOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args); // tosa::LogOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::LogOp::create(rewriter, loc, resultTypes, args); // tosa::ExpOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args); // tosa::SinOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::SinOp::create(rewriter, loc, resultTypes, args); // tosa::CosOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::CosOp::create(rewriter, loc, resultTypes, args); // tosa::TanhOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args); // tosa::ErfOp if (isa(op) && llvm::isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args); // tosa::GreaterOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, arith::CmpFPredicate::OGT, + return 
arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, arith::CmpIPredicate::sgt, + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, args[0], args[1]); // tosa::GreaterEqualOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, arith::CmpFPredicate::OGE, + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, arith::CmpIPredicate::sge, + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, args[0], args[1]); // tosa::EqualOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, arith::CmpFPredicate::OEQ, + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, arith::CmpIPredicate::eq, + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, args[0], args[1]); // tosa::SelectOp if (isa(op)) { elementTy = cast(op->getOperand(1).getType()).getElementType(); if (isa(elementTy) || isa(elementTy)) - return rewriter.create(loc, args[0], args[1], args[2]); + return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]); } // tosa::MaximumOp if (isa(op) && isa(elementTy)) { - auto max = rewriter.create(loc, args[0], args[1]); + auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast(op), rewriter, args[0], args[1], max); } if (isa(op) && elementTy.isSignlessInteger()) { - return rewriter.create(loc, args[0], args[1]); + return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && isa(elementTy)) { - auto min = rewriter.create(loc, args[0], args[1]); + auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); return 
materializeBinaryNanCheckIfRequired(llvm::cast(op), rewriter, args[0], args[1], min); } if (isa(op) && elementTy.isSignlessInteger()) { - return rewriter.create(loc, args[0], args[1]); + return arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::CeilOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return math::CeilOp::create(rewriter, loc, resultTypes, args); // tosa::FloorOp if (isa(op) && isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + return math::FloorOp::create(rewriter, loc, resultTypes, args); // tosa::ClampOp if (isa(op) && isa(elementTy)) { @@ -449,9 +449,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( APFloat::rmNearestTiesToEven, &losesInfo); maxApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); - auto min = rewriter.create( + auto min = arith::ConstantOp::create(rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); - auto max = rewriter.create( + auto max = arith::ConstantOp::create(rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); auto result = clampFloatHelper(loc, args[0], min, max, rewriter); @@ -478,11 +478,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( // return init if x == NaN else result // Unordered comparison of NaN against itself will always return true. - Value isNaN = rewriter.create( + Value isNaN = arith::CmpFOp::create(rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]); // TOSA specifies that in "ignore" NaN mode the result is "min" if the input // is NaN. 
- return rewriter.create(op->getLoc(), isNaN, min, result); + return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result); } if (isa(op) && isa(elementTy)) { @@ -515,9 +515,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( min = std::min(min, maxRepresentable); max = std::min(max, maxRepresentable); - auto minVal = rewriter.create( + auto minVal = arith::ConstantIntOp::create(rewriter, loc, min, intTy.getIntOrFloatBitWidth()); - auto maxVal = rewriter.create( + auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max, intTy.getIntOrFloatBitWidth()); return clampIntHelper(loc, args[0], minVal, maxVal, rewriter, intTy.isUnsignedInteger()); @@ -526,11 +526,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::SigmoidOp if (isa(op) && isa(elementTy)) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - auto negate = rewriter.create(loc, resultTypes, args[0]); - auto exp = rewriter.create(loc, resultTypes, negate); - auto added = rewriter.create(loc, resultTypes, exp, one); - return rewriter.create(loc, resultTypes, one, added); + arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); + auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); + auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate); + auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one); + return arith::DivFOp::create(rewriter, loc, resultTypes, one, added); } // tosa::CastOp @@ -549,20 +549,20 @@ static Value createLinalgBodyCalculationForElementwiseOp( return args.front(); if (isa(srcTy) && isa(dstTy) && bitExtend) - return rewriter.create(loc, resultTypes, args, + return arith::ExtFOp::create(rewriter, loc, resultTypes, args, ArrayRef()); if (isa(srcTy) && isa(dstTy) && !bitExtend) - return rewriter.create(loc, resultTypes, args, + return arith::TruncFOp::create(rewriter, loc, resultTypes, args, ArrayRef()); // 1-bit integers need to be treated as signless. 
if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, + return arith::UIToFPOp::create(rewriter, loc, resultTypes, args, ArrayRef()); if (srcTy.isInteger(1) && isa(dstTy) && bitExtend) - return rewriter.create(loc, resultTypes, args, + return arith::ExtUIOp::create(rewriter, loc, resultTypes, args, ArrayRef()); // Unsigned integers need an unrealized cast so that they can be passed @@ -574,25 +574,25 @@ static Value createLinalgBodyCalculationForElementwiseOp( loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); - return rewriter.create(loc, resultTypes[0], + return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], unrealizedCast); } // All other si-to-fp conversions should be handled by SIToFP. if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, + return arith::SIToFPOp::create(rewriter, loc, resultTypes, args, ArrayRef()); // Casting to boolean, floats need to only be checked as not-equal to zero. if (isa(srcTy) && dstTy.isInteger(1)) { - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(srcTy, 0.0)); - return rewriter.create(loc, arith::CmpFPredicate::UNE, + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE, args.front(), zero); } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto rounded = rewriter.create(loc, args[0]); + auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]); const auto &fltSemantics = cast(srcTy).getFloatSemantics(); // Check whether neither int min nor int max can be represented in the @@ -601,33 +601,33 @@ static Value createLinalgBodyCalculationForElementwiseOp( APFloat::semanticsMaxExponent(fltSemantics)) { // Use cmp + select to replace infinites by int min / int max. Other // integral values can be represented in the integer space. 
- auto conv = rewriter.create(loc, dstTy, rounded); - auto posInf = rewriter.create( + auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded); + auto posInf = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), APFloat::getInf(fltSemantics))); - auto negInf = rewriter.create( + auto negInf = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APFloat::getInf(fltSemantics, /*Negative=*/true))); - auto overflow = rewriter.create( + auto overflow = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf); - auto underflow = rewriter.create( + auto underflow = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf); - auto intMin = rewriter.create( + auto intMin = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); - auto intMax = rewriter.create( + auto intMax = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto maxClamped = - rewriter.create(loc, overflow, intMax, conv); - return rewriter.create(loc, underflow, intMin, + arith::SelectOp::create(rewriter, loc, overflow, intMax, conv); + return arith::SelectOp::create(rewriter, loc, underflow, intMin, maxClamped); } - auto intMinFP = rewriter.create( + auto intMinFP = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) @@ -640,7 +640,7 @@ static Value createLinalgBodyCalculationForElementwiseOp( // consists of a single leading bit. Therefore we can clamp the input // in the floating-point domain. 
- auto intMaxFP = rewriter.create( + auto intMaxFP = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) @@ -648,14 +648,14 @@ static Value createLinalgBodyCalculationForElementwiseOp( Value clamped = clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter); - return rewriter.create(loc, dstTy, clamped); + return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped); } // Due to earlier check we know exponant range is big enough to represent // int min. We can therefore rely on int max + 1 being representable as // well because it's just int min with a positive sign. So clamp the min // value and compare against that to select the max int value if needed. - auto intMaxPlusOneFP = rewriter.create( + auto intMaxPlusOneFP = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), static_cast( @@ -663,35 +663,35 @@ static Value createLinalgBodyCalculationForElementwiseOp( .getSExtValue()) + 1.0f)); - auto intMax = rewriter.create( + auto intMax = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto minClampedFP = - rewriter.create(loc, rounded, intMinFP); + arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP); auto minClamped = - rewriter.create(loc, dstTy, minClampedFP); - auto overflow = rewriter.create( + arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP); + auto overflow = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); - return rewriter.create(loc, overflow, intMax, + return arith::SelectOp::create(rewriter, loc, overflow, intMax, minClamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. 
if (isa(srcTy) && dstTy.isInteger(1)) { - Value zero = rewriter.create( + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, srcTy.getIntOrFloatBitWidth()); - return rewriter.create(loc, arith::CmpIPredicate::ne, + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, args.front(), zero); } if (isa(srcTy) && isa(dstTy) && bitExtend) - return rewriter.create(loc, resultTypes, args, + return arith::ExtSIOp::create(rewriter, loc, resultTypes, args, ArrayRef()); if (isa(srcTy) && isa(dstTy) && !bitExtend) { - return rewriter.create(loc, dstTy, args[0]); + return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]); } } @@ -710,14 +710,14 @@ static Value createIndex(PatternRewriter &rewriter, Location loc, auto [it, inserted] = indexPool.try_emplace(index); if (inserted) it->second = - rewriter.create(loc, rewriter.getIndexAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index)); return it->second; } static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index) { auto indexValue = createIndex(rewriter, loc, indexPool, index); - return rewriter.create(loc, tensor, indexValue).getResult(); + return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult(); } static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, @@ -783,7 +783,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) { auto nextSize = getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim); - targetSize = rewriter.create(loc, targetSize, nextSize); + targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize); } return {targetSize, nullptr}; } @@ -838,7 +838,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Check if broadcast is necessary auto one = createIndex(rewriter, loc, indexPool, 1); auto runtimeSize = 
getTensorDim(rewriter, loc, indexPool, operand, dim); - auto broadcastNecessary = rewriter.create( + auto broadcastNecessary = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one); // Emit 'then' region of 'scf.if' @@ -855,7 +855,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, operand, index); outputTensorShape.push_back(size); } - Value outputTensor = opBuilder.create( + Value outputTensor = tensor::EmptyOp::create(opBuilder, loc, outputTensorShape, rankedTensorType.getElementType()); // Emit 'linalg.generic' op @@ -866,7 +866,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { // Emit 'linalg.yield' op - opBuilder.create(loc, blockArgs.front()); + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); }) .getResult(0); @@ -875,16 +875,16 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, loc, operand.getType(), resultTensor); // Emit 'scf.yield' op - opBuilder.create(loc, castResultTensor); + scf::YieldOp::create(opBuilder, loc, castResultTensor); }; // Emit 'else' region of 'scf.if' auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { - opBuilder.create(loc, operand); + scf::YieldOp::create(opBuilder, loc, operand); }; // Emit 'scf.if' op - auto ifOp = rewriter.create(loc, broadcastNecessary, + auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary, emitThenRegion, emitElseRegion); return ifOp.getResult(0); } @@ -930,7 +930,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, if (!resultType) { return rewriter.notifyMatchFailure(operation, "failed to convert type"); } - Value outputTensor = rewriter.create( + Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape, resultType.getElementType()); // Create affine maps. 
Input affine maps broadcast static dimensions of size @@ -957,7 +957,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op bool encounteredError = false; - auto linalgOp = rewriter.create( + auto linalgOp = linalg::GenericOp::create(rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { @@ -968,7 +968,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, encounteredError = true; return; } - opBuilder.create(loc, opResult); + linalg::YieldOp::create(opBuilder, loc, opResult); }); if (encounteredError) return rewriter.notifyMatchFailure( @@ -1078,42 +1078,42 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::AddFOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::AddIOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::MulFOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args); + return arith::MulIOp::create(rewriter, loc, args); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { - return rewriter.create(loc, args[0], args[1]); + return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } if 
(isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return arith::AndIOp::create(rewriter, loc, args); if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return arith::OrIOp::create(rewriter, loc, args); return {}; } @@ -1139,7 +1139,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, if (axis != i) { reduceShape.push_back(inputTy.getDimSize(i)); if (inputTy.isDynamicDim(i)) - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -1158,7 +1158,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) @@ -1176,7 +1176,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, // Additionally we have to keep track of whether we've seen any non-NaN // values and then do a final select based on this predicate. 
auto trueAttr = rewriter.getBoolAttr(true); - auto trueValue = rewriter.create(loc, trueAttr); + auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = rewriter .create(loc, reduceShape, trueValue.getType(), @@ -1202,7 +1202,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, } bool didEncounterError = false; - linalg::LinalgOp linalgOp = rewriter.create( + linalg::LinalgOp linalgOp = linalg::ReduceOp::create(rewriter, loc, inputs, outputs, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { std::array binaryArgs{ @@ -1219,21 +1219,21 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto oldAllResultsNanFlagValue = blockArgs[3]; // Unordered comparison of NaN against itself will always return true. - Value isNaN = nestedBuilder.create( + Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue); // If we've encountered a NaN, take the non-NaN value. - auto selectOp = nestedBuilder.create( + auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(), isNaN, initialValue, result); // Update the flag which keeps track of whether we have seen a non-NaN // value. 
- auto newAllResultsNanFlagValue = nestedBuilder.create( + auto newAllResultsNanFlagValue = arith::AndIOp::create(nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN); resultsToYield.push_back(selectOp); resultsToYield.push_back(newAllResultsNanFlagValue); } else { resultsToYield.push_back(result); } - nestedBuilder.create(loc, resultsToYield); + linalg::YieldOp::create(nestedBuilder, loc, resultsToYield); }); if (!didEncounterError) @@ -1250,7 +1250,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto nanValueAttr = rewriter.getFloatAttr( elementTy, APFloat::getNaN(cast(elementTy).getFloatSemantics(), false)); - auto nanValue = rewriter.create(loc, nanValueAttr); + auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = rewriter .create(loc, reduceShape, @@ -1278,7 +1278,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, ins.push_back(linalgOp->getResult(0)); outs.push_back(finalEmptyTensor); auto linalgSelect = - rewriter.create(op->getLoc(), ins, outs); + linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs); linalgOp = linalgSelect; } @@ -1350,7 +1350,7 @@ class RescaleConverter : public OpRewritePattern { SmallVector dynDims; for (int i = 0; i < outputTy.getRank(); i++) { if (outputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -1401,7 +1401,7 @@ class RescaleConverter : public OpRewritePattern { Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { - multiplierConstant = rewriter.create( + multiplierConstant = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ @@ -1409,7 +1409,7 @@ class RescaleConverter : public OpRewritePattern { auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, 
rewriter.getI32Type()); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(arith::ConstantOp::create(rewriter, loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, @@ -1424,7 +1424,7 @@ class RescaleConverter : public OpRewritePattern { Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { - shiftConstant = rewriter.create( + shiftConstant = arith::ConstantOp::create(rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { @@ -1432,7 +1432,7 @@ class RescaleConverter : public OpRewritePattern { auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(arith::ConstantOp::create(rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, @@ -1444,11 +1444,11 @@ class RescaleConverter : public OpRewritePattern { indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. - Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), outputTy.getElementType(), ArrayRef({dynDims})); - auto linalgOp = rewriter.create( + auto linalgOp = linalg::GenericOp::create(rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, @@ -1466,7 +1466,7 @@ class RescaleConverter : public OpRewritePattern { const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); // Extend zeropoint for sub-32bits widths. const int32_t inAttrBitwidth = inBitwidth > 32 ? 
inBitwidth : 32; - auto inputZp = nestedBuilder.create( + auto inputZp = arith::ConstantOp::create(nestedBuilder, loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), *maybeIZp)); @@ -1482,7 +1482,7 @@ class RescaleConverter : public OpRewritePattern { unsigned outBitWidth = outIntType.getWidth(); const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = nestedBuilder.create( + auto outputZp = arith::ConstantOp::create(nestedBuilder, loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), *maybeOZp)); @@ -1501,24 +1501,24 @@ class RescaleConverter : public OpRewritePattern { } if (valueTy.getIntOrFloatBitWidth() < 32) { if (op.getInputUnsigned()) { - value = nestedBuilder.create( + value = arith::ExtUIOp::create(nestedBuilder, nestedLoc, nestedBuilder.getI32Type(), value); } else { - value = nestedBuilder.create( + value = arith::ExtSIOp::create(nestedBuilder, nestedLoc, nestedBuilder.getI32Type(), value); } } value = - nestedBuilder.create(nestedLoc, value, inputZp); + arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp); - value = nestedBuilder.create( + value = tosa::ApplyScaleOp::create(nestedBuilder, loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode); // Move to the new zero-point. value = - nestedBuilder.create(nestedLoc, value, outputZp); + arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp); // Saturate to the output size. 
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); @@ -1530,16 +1530,16 @@ class RescaleConverter : public OpRewritePattern { intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } - auto intMinVal = nestedBuilder.create( + auto intMinVal = arith::ConstantOp::create(nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin)); - auto intMaxVal = nestedBuilder.create( + auto intMaxVal = arith::ConstantOp::create(nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax)); value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal, nestedBuilder, /*isUnsigned=*/false); if (outIntType.getWidth() < 32) { - value = nestedBuilder.create( + value = arith::TruncIOp::create(nestedBuilder, nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), value); } @@ -1550,7 +1550,7 @@ class RescaleConverter : public OpRewritePattern { outIntType, value) .getResult(0); } - nestedBuilder.create(loc, value); + linalg::YieldOp::create(nestedBuilder, loc, value); }); rewriter.replaceOp(op, linalgOp->getResults()); @@ -1608,25 +1608,25 @@ class ResizeUnaryConverter : public OpRewritePattern { auto collapseTy = RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)}, inputTy.getElementType()); - Value collapse = builder.create(collapseTy, input, + Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input, reassociationMap); // Get any dynamic shapes that appear in the input format. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) - outputDynSize.push_back(builder.create(input, 0)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) - outputDynSize.push_back(builder.create(input, 3)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); // Generate the elementwise operation for casting scaling the input value. 
auto genericTy = collapseTy.clone(resultTy.getElementType()); - Value empty = builder.create( + Value empty = tensor::EmptyOp::create(builder, genericTy.getShape(), resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); SmallVector iterators(genericTy.getRank(), utils::IteratorType::parallel); - auto generic = builder.create( + auto generic = linalg::GenericOp::create(builder, genericTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef{genericMap, genericMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { @@ -1634,22 +1634,22 @@ class ResizeUnaryConverter : public OpRewritePattern { // This is the quantized case. if (inputTy.getElementType() != resultTy.getElementType()) { value = - b.create(loc, resultTy.getElementType(), value); + arith::ExtSIOp::create(b, loc, resultTy.getElementType(), value); if (isBilinear && scale[0] != 0) { - Value scaleY = b.create( + Value scaleY = arith::ConstantOp::create(b, loc, b.getI32IntegerAttr(scale[0])); - value = b.create(loc, value, scaleY); + value = arith::MulIOp::create(b, loc, value, scaleY); } if (isBilinear && scale[2] != 0) { - Value scaleX = b.create( + Value scaleX = arith::ConstantOp::create(b, loc, b.getI32IntegerAttr(scale[2])); - value = b.create(loc, value, scaleX); + value = arith::MulIOp::create(b, loc, value, scaleX); } } - b.create(loc, value); + linalg::YieldOp::create(b, loc, value); }); rewriter.replaceOpWithNewOp( @@ -1697,7 +1697,7 @@ class MaterializeResizeBroadcast : public OpRewritePattern { resizeShape.push_back(channels); auto resizeTy = resultTy.clone(resizeShape); - auto resize = builder.create(resizeTy, input, op.getScale(), + auto resize = tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(), op.getOffset(), op.getBorder(), op.getMode()); @@ -1720,19 +1720,19 @@ class MaterializeResizeBroadcast : public OpRewritePattern { collapseShape.push_back(channels); auto collapseTy = resultTy.clone(collapseShape); - 
Value collapse = builder.create(collapseTy, resize, + Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, resize, reassociationMap); // Broadcast the collapsed shape to the output result. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) - outputDynSize.push_back(builder.create(input, 0)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) - outputDynSize.push_back(builder.create(input, 3)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); SmallVector iterators(resultTy.getRank(), utils::IteratorType::parallel); - Value empty = builder.create( + Value empty = tensor::EmptyOp::create(builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize); SmallVector inputExprs{rewriter.getAffineDimExpr(0)}; @@ -1751,7 +1751,7 @@ class MaterializeResizeBroadcast : public OpRewritePattern { ArrayRef{inputMap, outputMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; - b.create(loc, value); + linalg::YieldOp::create(b, loc, value); }); return success(); @@ -1789,9 +1789,9 @@ class GenericResizeConverter : public OpRewritePattern { SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto emptyTensor = b.create(resultTy.getShape(), resultETy, + auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(), resultETy, *dynamicDimsOr); - auto genericOp = b.create( + auto genericOp = linalg::GenericOp::create(b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); Value resize = genericOp.getResult(0); @@ -1800,19 +1800,19 @@ class GenericResizeConverter : public OpRewritePattern { OpBuilder::InsertionGuard regionGuard(b); b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({resultETy}), loc); - Value batch = b.create(0); - Value y = b.create(1); - Value x = b.create(2); - Value channel = b.create(3); + Value batch = 
linalg::IndexOp::create(b, 0); + Value y = linalg::IndexOp::create(b, 1); + Value x = linalg::IndexOp::create(b, 2); + Value channel = linalg::IndexOp::create(b, 3); Value zeroI32 = - b.create(b.getZeroAttr(b.getI32Type())); - Value zeroFp = b.create(b.getZeroAttr(floatTy)); - Value hMax = b.create(b.getI32IntegerAttr(imageH - 1)); - Value wMax = b.create(b.getI32IntegerAttr(imageW - 1)); + arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type())); + Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy)); + Value hMax = arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1)); + Value wMax = arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1)); - Value inY = b.create(b.getI32Type(), y); - Value inX = b.create(b.getI32Type(), x); + Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y); + Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x); SmallVector scale, offset, border; if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || @@ -1824,16 +1824,16 @@ class GenericResizeConverter : public OpRewritePattern { } Value yScaleN, yScaleD, xScaleN, xScaleD; - yScaleN = b.create(b.getI32IntegerAttr(scale[0])); - yScaleD = b.create(b.getI32IntegerAttr(scale[1])); - xScaleN = b.create(b.getI32IntegerAttr(scale[2])); - xScaleD = b.create(b.getI32IntegerAttr(scale[3])); + yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0])); + yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1])); + xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2])); + xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3])); Value yOffset, xOffset, yBorder, xBorder; - yOffset = b.create(b.getI32IntegerAttr(offset[0])); - xOffset = b.create(b.getI32IntegerAttr(offset[1])); - yBorder = b.create(b.getI32IntegerAttr(border[0])); - xBorder = b.create(b.getI32IntegerAttr(border[1])); + yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0])); + xOffset = 
arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1])); + yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0])); + xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1])); // Compute the ix and dx values for both the X and Y dimensions. auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, @@ -1846,16 +1846,16 @@ class GenericResizeConverter : public OpRewritePattern { } // x = x * scale_d + offset; // ix = floor(x / scale_n) - Value val = b.create(in, scaleD); - val = b.create(val, offset); - index = b.create(val, scaleN); + Value val = arith::MulIOp::create(b, in, scaleD); + val = arith::AddIOp::create(b, val, offset); + index = arith::FloorDivSIOp::create(b, val, scaleN); // rx = x % scale_n // dx = rx / scale_n - Value r = b.create(val, scaleN); - Value rFp = b.create(floatTy, r); - Value scaleNfp = b.create(floatTy, scaleN); - delta = b.create(rFp, scaleNfp); + Value r = arith::RemSIOp::create(b, val, scaleN); + Value rFp = arith::SIToFPOp::create(b, floatTy, r); + Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN); + delta = arith::DivFOp::create(b, rFp, scaleNfp); }; // Compute the ix and dx values for the X and Y dimensions - int case. 
@@ -1870,11 +1870,11 @@ class GenericResizeConverter : public OpRewritePattern { // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x - ix * scale_n; - Value val = b.create(in, scaleD); - val = b.create(val, offset); - index = b.create(val, scaleN); - delta = b.create(index, scaleN); - delta = b.create(val, delta); + Value val = arith::MulIOp::create(b, in, scaleD); + val = arith::AddIOp::create(b, val, offset); + index = arith::DivSIOp::create(b, val, scaleN); + delta = arith::MulIOp::create(b, index, scaleN); + delta = arith::SubIOp::create(b, val, delta); }; Value ix, iy, dx, dy; @@ -1887,54 +1887,54 @@ class GenericResizeConverter : public OpRewritePattern { } if (op.getMode() == "NEAREST_NEIGHBOR") { - auto one = b.create(b.getI32IntegerAttr(1)); + auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, Value max, int size, ImplicitLocOpBuilder &b) -> Value { if (size == 1) { - return b.create(0); + return arith::ConstantIndexOp::create(b, 0); } Value pred; if (floatingPointMode) { - auto h = b.create(b.getFloatAttr(floatTy, 0.5f)); - pred = b.create(arith::CmpFPredicate::OGE, dval, h); + auto h = arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f)); + pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h); } else { - Value dvalDouble = b.create(dval, one); - pred = b.create(arith::CmpIPredicate::sge, + Value dvalDouble = arith::ShLIOp::create(b, dval, one); + pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge, dvalDouble, scale); } - auto offset = b.create(pred, one, zeroI32); - val = b.create(val, offset); + auto offset = arith::SelectOp::create(b, pred, one, zeroI32); + val = arith::AddIOp::create(b, val, offset); val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false); - return b.create(b.getIndexType(), val); + return arith::IndexCastOp::create(b, b.getIndexType(), val); }; iy = getNearestIndexAndClamp(iy, dy, yScaleN, 
hMax, imageH, b); ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); - Value result = b.create( + Value result = tensor::ExtractOp::create(b, input, ValueRange{batch, iy, ix, channel}); - b.create(result); + linalg::YieldOp::create(b, result); } else { // The mode here must be BILINEAR. assert(op.getMode() == "BILINEAR"); - auto oneVal = b.create(b.getI32IntegerAttr(1)); + auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, Value max, ImplicitLocOpBuilder &b) { val0 = in; - val1 = b.create(val0, oneVal); + val1 = arith::AddIOp::create(b, val0, oneVal); val0 = clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false); val1 = clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false); - val0 = b.create(b.getIndexType(), val0); - val1 = b.create(b.getIndexType(), val1); + val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0); + val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1); }; // Linalg equivalent to the section below: @@ -1946,27 +1946,27 @@ class GenericResizeConverter : public OpRewritePattern { getClampedIdxs(y0, y1, imageH, iy, hMax, b); getClampedIdxs(x0, x1, imageW, ix, wMax, b); - Value y0x0 = b.create( + Value y0x0 = tensor::ExtractOp::create(b, input, ValueRange{batch, y0, x0, channel}); - Value y0x1 = b.create( + Value y0x1 = tensor::ExtractOp::create(b, input, ValueRange{batch, y0, x1, channel}); - Value y1x0 = b.create( + Value y1x0 = tensor::ExtractOp::create(b, input, ValueRange{batch, y1, x0, channel}); - Value y1x1 = b.create( + Value y1x1 = tensor::ExtractOp::create(b, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { auto oneVal = - b.create(b.getFloatAttr(floatTy, 1.0f)); + arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f)); auto interpolate = [&](Value val0, Value val1, Value delta, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return val0; - Value oneMinusDelta 
= b.create(oneVal, delta); - Value mul0 = b.create(val0, oneMinusDelta); - Value mul1 = b.create(val1, delta); - return b.create(mul0, mul1); + Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta); + Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta); + Value mul1 = arith::MulFOp::create(b, val1, delta); + return arith::AddFOp::create(b, mul0, mul1); }; // Linalg equivalent to the section below: @@ -1982,18 +1982,18 @@ class GenericResizeConverter : public OpRewritePattern { // Linalg equivalent to the section below: // result = topAcc * (unit_y - dy) + bottomAcc * dy Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); - b.create(result); + linalg::YieldOp::create(b, result); } else { // Perform in quantized space. - y0x0 = b.create(resultETy, y0x0); - y0x1 = b.create(resultETy, y0x1); - y1x0 = b.create(resultETy, y1x0); - y1x1 = b.create(resultETy, y1x1); + y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0); + y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1); + y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0); + y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1); const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { - dx = b.create(resultETy, dx); - dy = b.create(resultETy, dy); + dx = arith::ExtSIOp::create(b, resultETy, dx); + dy = arith::ExtSIOp::create(b, resultETy, dy); } Value yScaleNExt = yScaleN; @@ -2002,26 +2002,26 @@ class GenericResizeConverter : public OpRewritePattern { const int64_t scaleBitwidth = xScaleN.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { - yScaleNExt = b.create(resultETy, yScaleN); - xScaleNExt = b.create(resultETy, xScaleN); + yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN); + xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN); } auto interpolate = [](Value val0, Value val1, Value weight1, Value scale, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) - 
return b.create(val0, scale); - Value weight0 = b.create(scale, weight1); - Value mul0 = b.create(val0, weight0); - Value mul1 = b.create(val1, weight1); - return b.create(mul0, mul1); + return arith::MulIOp::create(b, val0, scale); + Value weight0 = arith::SubIOp::create(b, scale, weight1); + Value mul0 = arith::MulIOp::create(b, val0, weight0); + Value mul1 = arith::MulIOp::create(b, val1, weight1); + return arith::AddIOp::create(b, mul0, mul1); }; Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); Value result = interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); - b.create(result); + linalg::YieldOp::create(b, result); } } } @@ -2072,11 +2072,11 @@ class ReverseConverter : public OpRewritePattern { SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } - Value axisDimSize = rewriter.create(loc, input, axis); + Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. 
auto emptyTensor = rewriter @@ -2094,21 +2094,21 @@ class ReverseConverter : public OpRewritePattern { llvm::SmallVector indices; for (unsigned int i = 0; i < inputTy.getRank(); i++) { Value index = - rewriter.create(nestedLoc, i).getResult(); + linalg::IndexOp::create(rewriter, nestedLoc, i).getResult(); if (i == axis) { - auto one = rewriter.create(nestedLoc, 1); + auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1); auto sizeMinusOne = - rewriter.create(nestedLoc, axisDimSize, one); - index = rewriter.create(nestedLoc, sizeMinusOne, + arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one); + index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne, index); } indices.push_back(index); } - auto extract = nestedBuilder.create( + auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc, input, indices); - nestedBuilder.create(op.getLoc(), + linalg::YieldOp::create(nestedBuilder, op.getLoc(), extract.getResult()); }); return success(); @@ -2148,11 +2148,11 @@ struct TileConverter : public OpConversionPattern { SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) || multiples[i] == -1) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } - auto emptyTensor = rewriter.create( + auto emptyTensor = tensor::EmptyOp::create(rewriter, op.getLoc(), genericShape, elementTy, dynDims); // We needs to map the input shape to the non-broadcasted dimensions. 
@@ -2168,12 +2168,12 @@ struct TileConverter : public OpConversionPattern { SmallVector affineMaps = { readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, RankedTensorType::get(genericShape, elementTy), input, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(op.getLoc(), *args.begin()); + linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin()); }); auto shapeValue = getTosaConstShape( @@ -2220,7 +2220,7 @@ class ArgMaxConverter : public OpRewritePattern { SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { - dynDims.push_back(rewriter.create(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -2229,7 +2229,7 @@ class ArgMaxConverter : public OpRewritePattern { .create(loc, resultTy.getShape(), outElementTy, dynDims) .getResult(); - auto fillValueIdx = rewriter.create( + auto fillValueIdx = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter @@ -2250,7 +2250,7 @@ class ArgMaxConverter : public OpRewritePattern { argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = - rewriter.create(loc, fillValueMaxAttr); + arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = rewriter .create(loc, ValueRange{fillValueMax}, @@ -2274,7 +2274,7 @@ class ArgMaxConverter : public OpRewritePattern { bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}, rewriter.getContext()); - auto linalgOp = rewriter.create( + auto linalgOp = linalg::GenericOp::create(rewriter, loc, ArrayRef({resultTy, resultMaxTy}), input, ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes, 
[&](OpBuilder &nestedBuilder, Location nestedLoc, @@ -2283,41 +2283,41 @@ class ArgMaxConverter : public OpRewritePattern { auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; - Value newIndex = rewriter.create( + Value newIndex = arith::IndexCastOp::create(rewriter, nestedLoc, oldIndex.getType(), - rewriter.create(loc, axis)); + linalg::IndexOp::create(rewriter, loc, axis)); Value predicate; if (isa(inElementTy)) { if (argmaxOp.getNanMode() == "IGNORE") { // Only update index & max value for non NaN values. If all // values are NaNs, the initial index will be return which is 0. - predicate = rewriter.create( + predicate = arith::CmpFOp::create(rewriter, nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else { // Update max value if either of the following is true: // - new value is bigger // - cur max is not NaN and new value is NaN - Value gt = rewriter.create( + Value gt = arith::CmpFOp::create(rewriter, nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue); - Value oldNonNaN = rewriter.create( + Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue); - predicate = rewriter.create( + predicate = arith::AndIOp::create(rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); } } else if (isa(inElementTy)) { - predicate = rewriter.create( + predicate = arith::CmpIOp::create(rewriter, nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { didEncounterError = true; return; } - auto resultMax = rewriter.create( + auto resultMax = arith::SelectOp::create(rewriter, nestedLoc, predicate, newValue, oldValue); - auto resultIndex = rewriter.create( + auto resultIndex = arith::SelectOp::create(rewriter, nestedLoc, predicate, newIndex, oldIndex); - nestedBuilder.create( + linalg::YieldOp::create(nestedBuilder, nestedLoc, ValueRange({resultIndex, resultMax})); }); @@ -2363,19 +2363,19 @@ class GatherConverter : public OpConversionPattern { rewriter.getContext()), 
rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, ArrayRef({resultTy}), ValueRange{indices}, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; - auto index0 = rewriter.create(loc, 0); - Value index1 = rewriter.create( + auto index0 = linalg::IndexOp::create(rewriter, loc, 0); + Value index1 = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), indexValue); - auto index2 = rewriter.create(loc, 2); - Value extract = rewriter.create( + auto index2 = linalg::IndexOp::create(rewriter, loc, 2); + Value extract = tensor::ExtractOp::create(rewriter, loc, input, ValueRange{index0, index1, index2}); - rewriter.create(loc, extract); + linalg::YieldOp::create(rewriter, loc, extract); }); rewriter.replaceOp(op, genericOp.getResult(0)); return success(); @@ -2424,7 +2424,7 @@ class TableConverter : public OpRewritePattern { for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { dynDims.push_back( - rewriter.create(loc, op.getOperand(0), i)); + tensor::DimOp::create(rewriter, loc, op.getOperand(0), i)); } } @@ -2437,7 +2437,7 @@ class TableConverter : public OpRewritePattern { rewriter.getMultiDimIdentityMap(resultTy.getRank()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -2452,69 +2452,69 @@ class TableConverter : public OpRewritePattern { rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { - Value index = rewriter.create( + Value index = arith::IndexCastOp::create(rewriter, loc, 
rewriter.getIndexType(), inputValue); - Value offset = rewriter.create(loc, 128); - index = rewriter.create(loc, rewriter.getIndexType(), + Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128); + index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), index, offset); Value extract = - rewriter.create(loc, table, ValueRange{index}); - rewriter.create(loc, extract); + tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); + linalg::YieldOp::create(rewriter, loc, extract); return success(); } if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { - Value extend = rewriter.create( + Value extend = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), inputValue); - auto offset = rewriter.create( + auto offset = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(32768)); - auto seven = rewriter.create( + auto seven = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(7)); - auto one = rewriter.create( + auto one = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1)); - auto b1111111 = rewriter.create( + auto b1111111 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value - auto extendAdd = rewriter.create(loc, extend, offset); - Value index = rewriter.create(loc, extendAdd, seven); + auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset); + Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven); Value fraction = - rewriter.create(loc, extendAdd, b1111111); + arith::AndIOp::create(rewriter, loc, extendAdd, b1111111); // Extract the base and next values from the table. 
// base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; - Value indexPlusOne = rewriter.create(loc, index, one); + Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one); - index = rewriter.create( + index = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), index); - indexPlusOne = rewriter.create( + indexPlusOne = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), indexPlusOne); Value base = - rewriter.create(loc, table, ValueRange{index}); - Value next = rewriter.create( + tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); + Value next = tensor::ExtractOp::create(rewriter, loc, table, ValueRange{indexPlusOne}); base = - rewriter.create(loc, rewriter.getI32Type(), base); + arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base); next = - rewriter.create(loc, rewriter.getI32Type(), next); + arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction - Value baseScaled = rewriter.create(loc, base, seven); - Value diff = rewriter.create(loc, next, base); - Value diffScaled = rewriter.create(loc, diff, fraction); + Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven); + Value diff = arith::SubIOp::create(rewriter, loc, next, base); + Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction); Value result = - rewriter.create(loc, baseScaled, diffScaled); + arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled); - rewriter.create(loc, result); + linalg::YieldOp::create(rewriter, loc, result); return success(); } @@ -2532,8 +2532,8 @@ struct RFFT2dConverter final : public OpRewritePattern { static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc, OpFoldResult ofr) { - auto one = builder.create(loc, 1); - auto two = builder.create(loc, 2); + auto one = arith::ConstantIndexOp::create(builder, loc, 1); + auto two = 
arith::ConstantIndexOp::create(builder, loc, 2); auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr); auto divBy2 = builder.createOrFold(loc, value, two); @@ -2562,9 +2562,9 @@ struct RFFT2dConverter final : public OpRewritePattern { RankedTensorType type, llvm::ArrayRef dynamicSizes) { auto emptyTensor = - rewriter.create(loc, type, dynamicSizes); + tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) @@ -2574,18 +2574,18 @@ struct RFFT2dConverter final : public OpRewritePattern { static Value castIndexToFloat(OpBuilder &builder, Location loc, FloatType type, Value value) { - auto integerVal = builder.create( + auto integerVal = arith::IndexCastUIOp::create(builder, loc, type.getIntOrFloatBitWidth() > 32 ? 
builder.getI64Type() : builder.getI32Type(), value); - return builder.create(loc, type, integerVal); + return arith::UIToFPOp::create(builder, loc, type, integerVal); } static Value createLinalgIndex(OpBuilder &builder, Location loc, FloatType type, int64_t index) { - auto indexVal = builder.create(loc, index); + auto indexVal = linalg::IndexOp::create(builder, loc, index); return castIndexToFloat(builder, loc, type, indexVal); } @@ -2640,7 +2640,7 @@ struct RFFT2dConverter final : public OpRewritePattern { // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586); - auto twoPi = rewriter.create(loc, twoPiAttr); + auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr); auto constH = castIndexToFloat(rewriter, loc, elementType, dimH); auto constW = castIndexToFloat(rewriter, loc, elementType, dimW); @@ -2650,43 +2650,43 @@ struct RFFT2dConverter final : public OpRewritePattern { Value sumImag = args[2]; // Indices for angle computation - Value oy = builder.create(loc, 1); - Value ox = builder.create(loc, 2); - Value iy = builder.create(loc, 3); - Value ix = builder.create(loc, 4); + Value oy = linalg::IndexOp::create(builder, loc, 1); + Value ox = linalg::IndexOp::create(builder, loc, 2); + Value iy = linalg::IndexOp::create(builder, loc, 3); + Value ix = linalg::IndexOp::create(builder, loc, 4); // Calculating angle without integer parts of components as sin/cos are // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) // / W); - auto iyXoy = builder.create(loc, iy, oy); - auto ixXox = builder.create(loc, ix, ox); + auto iyXoy = index::MulOp::create(builder, loc, iy, oy); + auto ixXox = index::MulOp::create(builder, loc, ix, ox); - auto iyRem = builder.create(loc, iyXoy, dimH); - auto ixRem = builder.create(loc, ixXox, dimW); + auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH); + auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW); auto iyRemFloat = 
castIndexToFloat(builder, loc, elementType, iyRem); auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem); - auto yComponent = builder.create(loc, iyRemFloat, constH); - auto xComponent = builder.create(loc, ixRemFloat, constW); - auto sumXY = builder.create(loc, yComponent, xComponent); - auto angle = builder.create(loc, twoPi, sumXY); + auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH); + auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW); + auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent); + auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY); // realComponent = valReal * cos(angle) // imagComponent = valReal * sin(angle) - auto cosAngle = builder.create(loc, angle); - auto sinAngle = builder.create(loc, angle); + auto cosAngle = math::CosOp::create(builder, loc, angle); + auto sinAngle = math::SinOp::create(builder, loc, angle); auto realComponent = - builder.create(loc, valReal, cosAngle); + arith::MulFOp::create(builder, loc, valReal, cosAngle); auto imagComponent = - builder.create(loc, valReal, sinAngle); + arith::MulFOp::create(builder, loc, valReal, sinAngle); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent - auto outReal = builder.create(loc, sumReal, realComponent); - auto outImag = builder.create(loc, sumImag, imagComponent); + auto outReal = arith::AddFOp::create(builder, loc, sumReal, realComponent); + auto outImag = arith::SubFOp::create(builder, loc, sumImag, imagComponent); - builder.create(loc, ValueRange{outReal, outImag}); + linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp( @@ -2760,7 +2760,7 @@ struct FFT2dConverter final : OpRewritePattern { // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586); - auto twoPi = rewriter.create(loc, twoPiAttr); + auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr); Value constH = 
RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH); Value constW = @@ -2773,57 +2773,57 @@ struct FFT2dConverter final : OpRewritePattern { Value sumImag = args[3]; // Indices for angle computation - Value oy = builder.create(loc, 1); - Value ox = builder.create(loc, 2); - Value iy = builder.create(loc, 3); - Value ix = builder.create(loc, 4); + Value oy = linalg::IndexOp::create(builder, loc, 1); + Value ox = linalg::IndexOp::create(builder, loc, 2); + Value iy = linalg::IndexOp::create(builder, loc, 3); + Value ix = linalg::IndexOp::create(builder, loc, 4); // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * // ox) % W ) / W); - auto iyXoy = builder.create(loc, iy, oy); - auto ixXox = builder.create(loc, ix, ox); + auto iyXoy = index::MulOp::create(builder, loc, iy, oy); + auto ixXox = index::MulOp::create(builder, loc, ix, ox); - auto iyRem = builder.create(loc, iyXoy, dimH); - auto ixRem = builder.create(loc, ixXox, dimW); + auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH); + auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW); auto iyRemFloat = RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem); auto ixRemFloat = RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem); - auto yComponent = builder.create(loc, iyRemFloat, constH); - auto xComponent = builder.create(loc, ixRemFloat, constW); + auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH); + auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW); - auto sumXY = builder.create(loc, yComponent, xComponent); - auto angle = builder.create(loc, twoPi, sumXY); + auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent); + auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY); if (inverse.getValue()) { - angle = builder.create( + angle = arith::MulFOp::create(builder, loc, angle, - rewriter.create( + arith::ConstantOp::create(rewriter, loc, 
rewriter.getFloatAttr(real_el_ty, -1.0))); } // realComponent = val_real * cos(a) + val_imag * sin(a); // imagComponent = -val_real * sin(a) + val_imag * cos(a); - auto cosAngle = builder.create(loc, angle); - auto sinAngle = builder.create(loc, angle); + auto cosAngle = math::CosOp::create(builder, loc, angle); + auto sinAngle = math::SinOp::create(builder, loc, angle); - auto rcos = builder.create(loc, valReal, cosAngle); - auto rsin = builder.create(loc, valImag, sinAngle); - auto realComponent = builder.create(loc, rcos, rsin); + auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle); + auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle); + auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin); - auto icos = builder.create(loc, valImag, cosAngle); - auto isin = builder.create(loc, valReal, sinAngle); + auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle); + auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle); - auto imagComponent = builder.create(loc, icos, isin); + auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent - auto outReal = builder.create(loc, sumReal, realComponent); - auto outImag = builder.create(loc, sumImag, imagComponent); + auto outReal = arith::AddFOp::create(builder, loc, sumReal, realComponent); + auto outImag = arith::AddFOp::create(builder, loc, sumImag, imagComponent); - builder.create(loc, ValueRange{outReal, outImag}); + linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index b89fde4fbc17e..d81b8645506c4 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -52,9 +52,9 @@ static mlir::Value applyPad(Location loc, Value input, 
ArrayRef pad, highIndices.push_back(rewriter.getIndexAttr(highPad)); } - Value padValue = rewriter.create(loc, padAttr); + Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr); - return rewriter.create( + return tensor::PadOp::create(rewriter, loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices, highIndices, padValue); } @@ -72,10 +72,10 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value biasVal = args[0]; Type resType = args[1].getType(); if (resType != biasVal.getType()) { - biasVal = builder.create(loc, resType, biasVal); + biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal); } - Value added = builder.create(loc, biasVal, args[1]); - builder.create(loc, added); + Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); }) .getResult(0); } @@ -134,19 +134,19 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter, if (resType != biasVal.getType()) { biasVal = resultTy.getElementType().isFloat() - ? builder.create(loc, resType, biasVal) + ? 
arith::ExtFOp::create(builder, loc, resType, biasVal) .getResult() - : builder.create(loc, resType, biasVal) + : arith::ExtSIOp::create(builder, loc, resType, biasVal) .getResult(); } - builder.create(loc, biasVal); + linalg::YieldOp::create(builder, loc, biasVal); }) .getResult(0); } static mlir::Value reifyConstantDim(int64_t attr, ImplicitLocOpBuilder &builder) { - return builder.create(attr); + return arith::ConstantIndexOp::create(builder, attr); } // Calculating the output width/height using the formula: @@ -160,22 +160,22 @@ static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim, int64_t dilationAttr, OpBuilder &rewriter) { ImplicitLocOpBuilder builder(loc, rewriter); - auto one = rewriter.create( + auto one = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(inputDim.getType(), 1)); Value padBefore = reifyConstantDim(padBeforeAttr, builder); - Value paddedBefore = builder.create(inputDim, padBefore); + Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore); Value padAfter = reifyConstantDim(padAfterAttr, builder); - Value paddedAfter = builder.create(paddedBefore, padAfter); + Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter); - Value subOne = builder.create(kernelDim, one); + Value subOne = arith::SubIOp::create(builder, kernelDim, one); Value dilation = reifyConstantDim(dilationAttr, builder); - Value dilated = builder.create(dilation, subOne); - Value addOne = builder.create(dilated, one); + Value dilated = arith::MulIOp::create(builder, dilation, subOne); + Value addOne = arith::AddIOp::create(builder, dilated, one); - Value subtract = builder.create(paddedAfter, addOne); + Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne); Value stride = reifyConstantDim(strideAttr, builder); - Value divide = builder.create(subtract, stride); - return builder.create(divide, one); + Value divide = arith::DivUIOp::create(builder, subtract, stride); + return 
arith::AddIOp::create(builder, divide, one); } // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D @@ -198,9 +198,9 @@ static SmallVector inferDynamicDimsForConv( auto padBottom = padAttr[i * 2 + 1]; auto stride = strideAttr[i]; auto dilation = dilationAttr[i]; - Value initDynDim = rewriter.create(loc, input, inputDim); + Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim); Value kernelDynDim = - rewriter.create(loc, weight, kernelDim); + tensor::DimOp::create(rewriter, loc, weight, kernelDim); // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y) dynDims[inputDim] = getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom, @@ -211,7 +211,7 @@ static SmallVector inferDynamicDimsForConv( // Get the batch/channels dimensions. for (int i = 0; i < inputRank; i++) { if (resultTy.isDynamicDim(i) && !dynDims[i]) - dynDims[i] = rewriter.create(loc, input, i); + dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i); } SmallVector filteredDims = condenseValues(dynDims); @@ -350,7 +350,7 @@ class ConvConverter : public OpConversionPattern { auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create(loc, newWeightTy, weight, + weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight, weightPermAttr); } } @@ -372,7 +372,7 @@ class ConvConverter : public OpConversionPattern { auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create(loc, newWeightTy, weight, + weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight, weightPermAttr); } @@ -384,7 +384,7 @@ class ConvConverter : public OpConversionPattern { auto strideAttr = rewriter.getI64TensorAttr(stride); auto dilationAttr = rewriter.getI64TensorAttr(dilation); - Value biasEmptyTensor = rewriter.create( + 
Value biasEmptyTensor = tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), accETy, filteredDims); Value broadcastBias = @@ -394,8 +394,8 @@ class ConvConverter : public OpConversionPattern { auto iZp = rewriter.getI32IntegerAttr(inputZpVal); auto kZp = rewriter.getI32IntegerAttr(weightZpVal); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); + auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp); Value conv = rewriter @@ -417,7 +417,7 @@ class ConvConverter : public OpConversionPattern { // We may need to truncate back to the result type if the accumulator was // wider than the result. if (resultTy != accTy) - conv = rewriter.create(loc, resultTy, conv); + conv = tosa::CastOp::create(rewriter, loc, resultTy, conv); rewriter.replaceOp(op, conv); return success(); @@ -526,15 +526,15 @@ class DepthwiseConvConverter accETy); auto resultZeroAttr = rewriter.getZeroAttr(accETy); - Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, ValueRange{emptyTensor}) .result(); - Value biasEmptyTensor = rewriter.create( + Value biasEmptyTensor = tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), resultETy, filteredDims); // Broadcast the initial value to the output tensor before convolving. @@ -553,7 +553,7 @@ class DepthwiseConvConverter // We may need to truncate back to the result type if the accumulator was // wider than the result. 
if (accETy != resultETy) - conv = rewriter.create( + conv = tosa::CastOp::create(rewriter, loc, RankedTensorType::get(cast(conv.getType()).getShape(), resultETy), @@ -561,7 +561,7 @@ class DepthwiseConvConverter SmallVector reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); - Value convReshape = rewriter.create( + Value convReshape = tensor::CollapseShapeOp::create(rewriter, loc, resultTy, conv, reassociationMap); Value result = @@ -574,20 +574,20 @@ class DepthwiseConvConverter ValueRange args) { Value added; if (llvm::isa(inputETy)) - added = nestedBuilder.create(loc, args[0], + added = arith::AddFOp::create(nestedBuilder, loc, args[0], args[1]); else - added = nestedBuilder.create(loc, args[0], + added = arith::AddIOp::create(nestedBuilder, loc, args[0], args[1]); - nestedBuilder.create(nestedLoc, added); + linalg::YieldOp::create(nestedBuilder, nestedLoc, added); }) .getResult(0); rewriter.replaceOp(op, result); } else { IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal); IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, wZp); + auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); + auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp); Value conv = rewriter .create( @@ -596,7 +596,7 @@ class DepthwiseConvConverter .getResult(0); SmallVector reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); - Value convReshape = rewriter.create( + Value convReshape = tensor::CollapseShapeOp::create(rewriter, loc, resultTy, conv, reassociationMap); Value result = linalgIntBroadcastExtSIAdd( rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps); @@ -621,22 +621,22 @@ class MatMulConverter : public OpConversionPattern { dynDims.resize(cast(op->getResult(0).getType()).getRank()); if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) { - dynDims[0] = rewriter.create(loc, 
op->getOperand(0), 0); + dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0); } if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) { - dynDims[1] = rewriter.create(loc, op->getOperand(0), 1); + dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1); } if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) { - dynDims[2] = rewriter.create(loc, op->getOperand(1), 2); + dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2); } SmallVector filteredDims = condenseValues(dynDims); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); - Value zero = rewriter.create(loc, zeroAttr); - auto emptyTensor = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, @@ -670,9 +670,9 @@ class MatMulConverter : public OpConversionPattern { return success(); } - auto aZp = rewriter.create( + auto aZp = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(aZpVal)); - auto bZp = rewriter.create( + auto bZp = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(bZpVal)); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, @@ -702,7 +702,7 @@ class MaxPool2dConverter : public OpConversionPattern { // Batch dimension if (resultTy.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); // Height/width dimensions for (int64_t dim : {1, 2}) { @@ -713,10 +713,10 @@ class MaxPool2dConverter : public OpConversionPattern { int64_t index = dim - 1; // Input height/width - Value ihw = rewriter.create(loc, input, dim); + Value ihw = tensor::DimOp::create(rewriter, loc, input, dim); // Kernel height/width - Value khw = rewriter.create(loc, kernel[index]); + Value khw = arith::ConstantIndexOp::create(rewriter, loc, 
kernel[index]); // Output height/width Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2], @@ -727,7 +727,7 @@ class MaxPool2dConverter : public OpConversionPattern { // Channel dimension if (resultTy.isDynamicDim(3)) - dynamicDims.push_back(rewriter.create(loc, input, 3)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3)); return dynamicDims; } @@ -776,7 +776,7 @@ class MaxPool2dConverter : public OpConversionPattern { Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr); ArrayRef kernel = op.getKernel(); ArrayRef stride = op.getStride(); @@ -785,15 +785,15 @@ class MaxPool2dConverter : public OpConversionPattern { Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims); Value filledEmptyTensor = - rewriter.create(loc, initialValue, emptyTensor) + linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor) .result(); Value fakeWindowDims = - rewriter.create(loc, kernel, resultETy); + tensor::EmptyOp::create(rewriter, loc, kernel, resultETy); if (isUnsigned) { rewriter.replaceOpWithNewOp( @@ -802,7 +802,7 @@ class MaxPool2dConverter : public OpConversionPattern { return llvm::success(); } - auto resultOp = rewriter.create( + auto resultOp = linalg::PoolingNhwcMaxOp::create(rewriter, op->getLoc(), ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, dilationAttr); @@ -821,7 +821,7 @@ class MaxPool2dConverter : public OpConversionPattern { // it to include the appropriate checks. If the current value is NaN the // old value of pool will be taken otherwise we use the result. 
if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") { - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, op->getLoc(), resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(), resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(), @@ -832,12 +832,12 @@ class MaxPool2dConverter : public OpConversionPattern { auto &oldMaxOp = *resultOp.getBlock()->begin(); map.map(oldArgs, blockArgs); auto *newOp = opBuilder.clone(oldMaxOp, map); - Value isNaN = opBuilder.create( + Value isNaN = arith::CmpFOp::create(opBuilder, op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(), blockArgs.front()); - auto selectOp = opBuilder.create( + auto selectOp = arith::SelectOp::create(opBuilder, op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0)); - opBuilder.create(loc, selectOp.getResult()); + linalg::YieldOp::create(opBuilder, loc, selectOp.getResult()); }); rewriter.replaceOp(resultOp, genericOp); } @@ -893,7 +893,7 @@ class AvgPool2dConverter : public OpRewritePattern { Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter); auto initialAttr = rewriter.getZeroAttr(accETy); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr); ArrayRef kernel = op.getKernel(); ArrayRef stride = op.getStride(); @@ -902,7 +902,7 @@ class AvgPool2dConverter : public OpRewritePattern { Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value poolEmptyTensor = rewriter.create( + Value poolEmptyTensor = tensor::EmptyOp::create(rewriter, loc, accTy.getShape(), accETy, dynamicDims); Value filledEmptyTensor = @@ -912,7 +912,7 @@ class AvgPool2dConverter : public OpRewritePattern { .result(); Value fakeWindowDims = - rewriter.create(loc, kernel, accETy); + tensor::EmptyOp::create(rewriter, loc, kernel, accETy); // Sum across the pooled region. 
Value poolingOp = rewriter @@ -924,24 +924,24 @@ class AvgPool2dConverter : public OpRewritePattern { // Normalize the summed value by the number of elements grouped in each // pool. - Value iH = rewriter.create(loc, poolingOp, 1); - Value iW = rewriter.create(loc, poolingOp, 2); + Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1); + Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2); - auto one = rewriter.create(loc, 1); - iH = rewriter.create(loc, iH, one); - iW = rewriter.create(loc, iW, one); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); + iH = arith::SubIOp::create(rewriter, loc, iH, one); + iW = arith::SubIOp::create(rewriter, loc, iW, one); - Value genericEmptyTensor = rewriter.create( + Value genericEmptyTensor = tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), resultETy, dynamicDims); auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, ArrayRef({resultTy}), ValueRange{poolingOp}, ValueRange{genericEmptyTensor}, ArrayRef({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // Determines what the portion of valid input is covered by the // kernel. 
@@ -949,30 +949,30 @@ class AvgPool2dConverter : public OpRewritePattern { if (pad == 0) return valid; - auto padVal = rewriter.create(loc, pad); - Value dpos = rewriter.create(loc, pos, padVal); + auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad); + Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal); - Value offset = rewriter.create(loc, dpos, zero); - return rewriter.create(loc, valid, offset) + Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero); + return arith::AddIOp::create(rewriter, loc, valid, offset) ->getResult(0); }; auto coverageFn = [&](int64_t i, Value isize) -> Value { Value strideVal = - rewriter.create(loc, stride[i - 1]); + arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]); Value val = - rewriter.create(loc, kernel[i - 1]); + arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]); // Find the position relative to the input tensor's ends. - Value left = rewriter.create(loc, i); - Value right = rewriter.create(loc, isize, left); - left = rewriter.create(loc, left, strideVal); - right = rewriter.create(loc, right, strideVal); + Value left = linalg::IndexOp::create(rewriter, loc, i); + Value right = arith::SubIOp::create(rewriter, loc, isize, left); + left = arith::MulIOp::create(rewriter, loc, left, strideVal); + right = arith::MulIOp::create(rewriter, loc, right, strideVal); // Determine how much padding was included. val = padFn(val, left, pad[i * 2]); val = padFn(val, right, pad[i * 2 + 1]); - return rewriter.create(loc, one, val); + return arith::MaxSIOp::create(rewriter, loc, one, val); }; // Compute the indices from either end. @@ -980,70 +980,70 @@ class AvgPool2dConverter : public OpRewritePattern { Value kW3 = coverageFn(2, iW); // Compute the total number of elements and normalize. 
- auto count = rewriter.create( + auto count = arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), - rewriter.create(loc, kH3, kW3)); + arith::MulIOp::create(rewriter, loc, kH3, kW3)); // Divide by the number of summed values. For floats this is just // a div however for quantized values input normalization had // to be applied. Value poolVal = args[0]; if (isa(accETy)) { - auto countF = rewriter.create(loc, accETy, count); - poolVal = rewriter.create(loc, poolVal, countF) + auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count); + poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF) ->getResult(0); if (accETy.getIntOrFloatBitWidth() > resultETy.getIntOrFloatBitWidth()) poolVal = - rewriter.create(loc, resultETy, poolVal); + arith::TruncFOp::create(rewriter, loc, resultETy, poolVal); } else { // If we have quantization information we need to apply an offset // for the input zp value. if (inputZpVal != 0) { - auto inputZp = rewriter.create( + auto inputZp = arith::ConstantOp::create(rewriter, loc, b.getIntegerAttr(accETy, inputZpVal)); Value offset = - rewriter.create(loc, accETy, count, inputZp); + arith::MulIOp::create(rewriter, loc, accETy, count, inputZp); poolVal = - rewriter.create(loc, accETy, poolVal, offset); + arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset); } // Compute: k = 32 - count_leading_zeros(value - 1) - Value one32 = rewriter.create( + Value one32 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1)); - Value thirtyTwo32 = rewriter.create( + Value thirtyTwo32 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(32)); Value countSubOne = - rewriter.create(loc, count, one32); + arith::SubIOp::create(rewriter, loc, count, one32); Value leadingZeros = - rewriter.create(loc, countSubOne); + math::CountLeadingZerosOp::create(rewriter, loc, countSubOne); Value k = - rewriter.create(loc, thirtyTwo32, leadingZeros); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, 
leadingZeros); // Compute: numerator = ((1 << 30) + 1) << k Value k64 = - rewriter.create(loc, rewriter.getI64Type(), k); - Value thirtyShiftPlusOne = rewriter.create( + arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k); + Value thirtyShiftPlusOne = arith::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1)); Value numerator = - rewriter.create(loc, thirtyShiftPlusOne, k64); + arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64); // Compute: scale.multiplier = numerator / value; - Value count64 = rewriter.create( + Value count64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), count); Value multiplier = - rewriter.create(loc, numerator, count64); - multiplier = rewriter.create( + arith::DivUIOp::create(rewriter, loc, numerator, count64); + multiplier = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), multiplier); // Compute: scale.shift = 30 + k Value k8 = - rewriter.create(loc, rewriter.getI8Type(), k); - Value thirty8 = rewriter.create( + arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k); + Value thirty8 = arith::ConstantOp::create(rewriter, loc, rewriter.getI8IntegerAttr(30)); - Value shift = rewriter.create(loc, k8, thirty8); + Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); auto scaled = rewriter @@ -1055,19 +1055,19 @@ class AvgPool2dConverter : public OpRewritePattern { // If we have quantization information we need to apply output // zeropoint. if (outputZpVal != 0) { - auto outputZp = rewriter.create( + auto outputZp = arith::ConstantOp::create(rewriter, loc, b.getIntegerAttr(scaled.getType(), outputZpVal)); - scaled = rewriter.create(loc, scaled, outputZp) + scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp) .getResult(); } // Apply Clip. 
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); - auto min = rewriter.create( + auto min = arith::ConstantIntOp::create(rewriter, loc, accETy, APInt::getSignedMinValue(outBitwidth).getSExtValue()); - auto max = rewriter.create( + auto max = arith::ConstantIntOp::create(rewriter, loc, accETy, APInt::getSignedMaxValue(outBitwidth).getSExtValue()); auto clamp = clampIntHelper(loc, scaled, min, max, rewriter, @@ -1077,11 +1077,11 @@ class AvgPool2dConverter : public OpRewritePattern { // Convert type. if (resultETy != clamp.getType()) { poolVal = - rewriter.create(loc, resultETy, poolVal); + arith::TruncIOp::create(rewriter, loc, resultETy, poolVal); } } - rewriter.create(loc, poolVal); + linalg::YieldOp::create(rewriter, loc, poolVal); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -1106,7 +1106,7 @@ class TransposeConverter : public OpRewritePattern { auto permutedSizes = applyTOSAPermutation(inputSizes, constantPerms); - auto permutedInit = rewriter.create( + auto permutedInit = tensor::EmptyOp::create(rewriter, loc, permutedSizes, op.getInput1().getType().getElementType()); rewriter.replaceOpWithNewOp( op, op.getInput1(), permutedInit, diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp index 7dbccd19a0518..cd36aadce410a 100644 --- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp +++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp @@ -27,7 +27,7 @@ class VariableOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::VariableOp op, PatternRewriter &rewriter) const final { auto variableType = tosa::getVariableType(op); - auto newVariable = rewriter.create( + auto newVariable = mlir::ml_program::GlobalOp::create(rewriter, op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, op.getInitialValueAttr(), /*sym_visibility=*/nullptr); newVariable.setPrivate(); @@ -45,7 +45,7 @@ class VariableWriteOpConverter PatternRewriter &rewriter) 
const final { auto globalSymbolRef = SymbolRefAttr::get(rewriter.getContext(), op.getName()); - auto newVariableWrite = rewriter.create( + auto newVariableWrite = ml_program::GlobalStoreOp::create(rewriter, op.getLoc(), globalSymbolRef, op.getInput1()); rewriter.replaceOp(op, newVariableWrite); return success(); @@ -60,7 +60,7 @@ class VariableReadOpConverter : public OpRewritePattern { PatternRewriter &rewriter) const final { auto globalSymbolRef = SymbolRefAttr::get(rewriter.getContext(), op.getName()); - auto newVariableRead = rewriter.create( + auto newVariableRead = ml_program::GlobalLoadOp::create(rewriter, op.getLoc(), op.getType(), globalSymbolRef); rewriter.replaceOp(op, newVariableRead); diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp index 03f9d20ad69de..9188b8d1368da 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp @@ -30,7 +30,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion, auto yield = cast(headBlock->getTerminator()); rewriter.setInsertionPoint(yield); - rewriter.create(yield.getLoc(), yield.getInputs()); + scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs()); rewriter.eraseOp(yield); headBlock->eraseArguments(0, headBlock->getNumArguments()); @@ -47,12 +47,12 @@ static void inlineWhileCase(Region &srcRegion, Region &dstRegion, rewriter.setInsertionPoint(yield); if (isCond) { auto condition = - rewriter.create(yield.getLoc(), yield.getOperand(0)); - rewriter.create(yield.getLoc(), condition, + tensor::ExtractOp::create(rewriter, yield.getLoc(), yield.getOperand(0)); + scf::ConditionOp::create(rewriter, yield.getLoc(), condition, headBlock->getArguments()); } else { rewriter.setInsertionPoint(yield); - rewriter.create(yield.getLoc(), yield.getInputs()); + scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs()); } rewriter.eraseOp(yield); } @@ -66,8 +66,8 @@ class IfOpConverter : public 
OpRewritePattern { LogicalResult matchAndRewrite(tosa::IfOp op, PatternRewriter &rewriter) const final { auto condition = - rewriter.create(op.getLoc(), op.getCondition()); - auto newIf = rewriter.create(op.getLoc(), op.getResultTypes(), + tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition()); + auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(), condition, true); inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(), @@ -88,7 +88,7 @@ class ScatterOpConverter : public OpRewritePattern { static Value createIndexConst(OpBuilder &builder, Location loc, int64_t value) { - return builder.create(loc, value); + return arith::ConstantIndexOp::create(builder, loc, value); } public: @@ -119,8 +119,8 @@ class ScatterOpConverter : public OpRewritePattern { auto n = ivs[0]; // Read the index and cast it to index type - auto index = builder.create(loc, indices, ivs); - auto castIndex = builder.create( + auto index = tensor::ExtractOp::create(builder, loc, indices, ivs); + auto castIndex = arith::IndexCastOp::create(builder, loc, builder.getIndexType(), index); // Offset, sizes, and strides for the input tensor @@ -130,12 +130,12 @@ class ScatterOpConverter : public OpRewritePattern { llvm::SmallVector sizes = {one, one, dimC}; llvm::SmallVector strides = {one, one, one}; - auto slice = builder.create( + auto slice = tensor::ExtractSliceOp::create(builder, loc, input, inputOffset, sizes, strides); // Insert the slice into the output accumulator tensor. 
llvm::SmallVector outputOffset = {n, castIndex, zero}; - auto updated = builder.create( + auto updated = tensor::InsertSliceOp::create(builder, loc, slice, args[0], outputOffset, sizes, strides); return {updated}; @@ -155,7 +155,7 @@ class WhileOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::WhileOp op, PatternRewriter &rewriter) const final { - auto newWhile = rewriter.create( + auto newWhile = scf::WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(), op.getInputList()); rewriter.createBlock(&newWhile.getBefore()); rewriter.createBlock(&newWhile.getAfter()); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 9ceb5c0c7f2fe..e51cbfb702083 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -308,13 +308,13 @@ class SliceConverter : public OpConversionPattern { if (ShapedType::isStatic(sizes.back())) continue; - auto dim = rewriter.create(loc, input, index); - auto offset = rewriter.create( + auto dim = tensor::DimOp::create(rewriter, loc, input, index); + auto offset = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(sliceStarts[index])); - dynSizes.push_back(rewriter.create(loc, dim, offset)); + dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset)); } - auto newSliceOp = rewriter.create( + auto newSliceOp = tensor::ExtractSliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), rewriter.getDenseI64ArrayAttr(sizes), @@ -362,7 +362,7 @@ class PadConverter : public OpConversionPattern { Value padConstant = rewriter.createOrFold( loc, padOp.getPadConst(), - ValueRange({rewriter.create(loc, 0)})); + ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)})); if (!padConstant) { return rewriter.notifyMatchFailure( @@ -376,15 +376,15 @@ class PadConverter : public 
OpConversionPattern { highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value lowVal = rewriter.create( + Value lowVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i])); - Value highVal = rewriter.create( + Value highVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1])); lowValues.push_back(lowVal); highValues.push_back(highVal); } - auto newPadOp = rewriter.create( + auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input, lowValues, highValues, padConstant); rewriter.replaceOp(padOp, newPadOp.getResult()); @@ -403,7 +403,7 @@ struct ConcatConverter : public OpConversionPattern { Location loc = op.getLoc(); int axis = op.getAxis(); Value axisValue = - rewriter.create(loc, rewriter.getIndexAttr(axis)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis)); int64_t rank = resultType.getRank(); SmallVector strides(rank, rewriter.getIndexAttr(1)); @@ -440,7 +440,7 @@ struct ConcatConverter : public OpConversionPattern { } } - Value result = rewriter.create( + Value result = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType(), dynDims); for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) { diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index d6f9495b2567c..9e4eb347b3689 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -226,7 +226,7 @@ struct BroadcastOpToArmSMELowering (srcVectorType && (srcVectorType.getRank() == 0))) { // Broadcast scalar or 0-d vector to 1-d vector. 
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - broadcastOp1D = rewriter.create( + broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType, broadcastOp.getSource()); } else if (srcVectorType && (srcVectorType.getRank() == 1)) // Value to broadcast is already a 1-d vector, nothing to do. @@ -234,13 +234,13 @@ struct BroadcastOpToArmSMELowering else return failure(); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to broadcast the value // to each tile slice. - auto nextTile = b.create( + auto nextTile = arm_sme::InsertTileSliceOp::create(b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; @@ -292,14 +292,14 @@ struct SplatOpToArmSMELowering : public OpRewritePattern { // First, broadcast the scalar to a 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - Value broadcastOp1D = rewriter.create( + Value broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType, splatOp.getInput()); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { - auto nextTile = b.create( + auto nextTile = arm_sme::InsertTileSliceOp::create(b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; @@ -370,21 +370,21 @@ struct TransposeOpToArmSMELowering // Allocate buffer to store input tile to. 
Value vscale = - rewriter.create(loc, rewriter.getIndexType()); - Value minTileSlices = rewriter.create( + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + Value minTileSlices = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0))); Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); Value numTileSlices = - rewriter.create(loc, vscale, minTileSlices); + arith::MulIOp::create(rewriter, loc, vscale, minTileSlices); auto bufferType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, tileType.getElementType()); - auto buffer = rewriter.create( + auto buffer = memref::AllocaOp::create(rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices}); // Store input tile. - auto tileStoreOp = rewriter.create( + auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input, buffer, ValueRange{c0, c0}); // Reload input tile vertically. @@ -489,9 +489,9 @@ struct VectorOuterProductToArmSMELowering VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); Value lhsMask = - rewriter.create(loc, operandMaskType, lhsMaskDim); + vector::CreateMaskOp::create(rewriter, loc, operandMaskType, lhsMaskDim); Value rhsMask = - rewriter.create(loc, operandMaskType, rhsMaskDim); + vector::CreateMaskOp::create(rewriter, loc, operandMaskType, rhsMaskDim); return std::make_pair(lhsMask, rhsMask); } @@ -531,7 +531,7 @@ struct VectorExtractToArmSMELowering } Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); - auto extractTileSlice = rewriter.create( + auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(rewriter, loc, sourceVector, sliceIndex); if (position.size() == 1) { @@ -593,9 +593,9 @@ struct VectorInsertToArmSMELowering if (position.size() == 2) { // Two indices case: Insert single element into tile. // We need to first extract the existing slice and update the element. 
- tileSlice = rewriter.create( + tileSlice = arm_sme::ExtractTileSliceOp::create(rewriter, loc, insertOp.getDest(), sliceIndex); - tileSlice = rewriter.create(loc, source, tileSlice, + tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice, position[1]); } @@ -642,22 +642,22 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern { auto loc = printOp.getLoc(); // Create a loop over the rows of the tile. - auto vscale = rewriter.create(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto minTileRows = - rewriter.create(loc, vectorType.getDimSize(0)); - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, minTileRows, vscale); - auto step = rewriter.create(loc, 1); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); { // Loop body. rewriter.setInsertionPointToStart(forOp.getBody()); // Extract the current row from the tile. Value rowIndex = forOp.getInductionVar(); - auto tileSlice = rewriter.create( + auto tileSlice = arm_sme::ExtractTileSliceOp::create(rewriter, loc, printOp.getSource(), rowIndex); // Print the row with a 1D vector.print. 
- rewriter.create(loc, tileSlice, + vector::PrintOp::create(rewriter, loc, tileSlice, printOp.getPunctuation()); } @@ -707,7 +707,7 @@ struct FoldTransferWriteOfExtractTileSlice Value mask = writeOp.getMask(); if (!mask) { auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); - mask = rewriter.create( + mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); } @@ -776,9 +776,9 @@ struct ExtractFromCreateMaskToPselLowering // Create the two 1-D masks at the location of the 2-D create_mask (which is // usually outside a loop). This prevents the need for later hoisting. rewriter.setInsertionPoint(createMaskOp); - auto rowMask = rewriter.create( + auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType, createMaskOp.getOperand(0)); - auto colMask = rewriter.create( + auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType, createMaskOp.getOperand(1)); rewriter.setInsertionPoint(extractOp); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 18adaa793787c..5b3e67509672a 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -412,22 +412,22 @@ struct PrepareContractToGPUMMA if (maps == infer({{m, k}, {k, n}, {m, n}})) return rewriter.notifyMatchFailure(op, "contraction already prepared"); if (maps == infer({{m, k}, {n, k}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); - lhs = rewriter.create(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, 
perm); } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { std::swap(rhs, lhs); - rhs = rewriter.create(loc, rhs, perm); - lhs = rewriter.create(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { std::swap(rhs, lhs); - rhs = rewriter.create(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { std::swap(lhs, rhs); - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { std::swap(lhs, rhs); } else { @@ -494,13 +494,13 @@ struct CombineTransferReadOpTranspose final // Fuse through the integer extend op. if (extOp) { if (isa(extOp)) - result = rewriter.create(loc, op.getType(), result) + result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result) .getResult(); else if (isa(extOp)) - result = rewriter.create(loc, op.getType(), result) + result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result) .getResult(); else - result = rewriter.create(loc, op.getType(), result) + result = arith::ExtFOp::create(rewriter, loc, op.getType(), result) .getResult(); } @@ -579,7 +579,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, } gpu::MMAMatrixType type = gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); - Value load = rewriter.create( + Value load = gpu::SubgroupMmaLoadMatrixOp::create(rewriter, op.getLoc(), type, op.getBase(), op.getIndices(), rewriter.getIndexAttr(*stride), isTranspose ? 
rewriter.getUnitAttr() : UnitAttr()); @@ -610,7 +610,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, } Value matrix = it->second; - auto store = rewriter.create( + auto store = gpu::SubgroupMmaStoreMatrixOp::create(rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(), rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; @@ -661,7 +661,7 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, return rewriter.notifyMatchFailure(op, "not a splat"); } - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, op.getLoc(), vectorType, DenseElementsAttr::get(vectorType, dense.getSplatValue())); valueMapping[op.getResult()] = result; @@ -743,7 +743,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, } // Adjust the load offset. - auto laneId = rewriter.create(loc, /*upperBound=*/nullptr); + auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); FailureOr offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { @@ -757,7 +757,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, getXferIndices(rewriter, op, *offsets, {laneId}, indices); - nvgpu::LdMatrixOp newOp = rewriter.create( + nvgpu::LdMatrixOp newOp = nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(), indices, *transpose, params->numTiles); valueMapping[op] = newOp->getResult(0); return success(); @@ -782,17 +782,17 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, "conversion to distributed non-ldmatrix compatible load"); } - Value laneId = rewriter.create(loc, /*upperBound=*/nullptr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); // This is the individual element type. 
Type loadedElType = regInfo->registerLLVMType; VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - Value fill = rewriter.create( + Value fill = arith::ConstantOp::create(rewriter, op.getLoc(), vectorType.getElementType(), rewriter.getZeroAttr(vectorType.getElementType())); Value result = - rewriter.create(op.getLoc(), fill, vectorType); + vector::SplatOp::create(rewriter, op.getLoc(), fill, vectorType); bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); @@ -809,16 +809,16 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, if (failed(coords)) return rewriter.notifyMatchFailure(op, "no coords"); - Value logicalValueId = rewriter.create( + Value logicalValueId = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); SmallVector newIndices; getXferIndices( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - Value el = rewriter.create(loc, loadedElType, + Value el = vector::LoadOp::create(rewriter, loc, loadedElType, op.getBase(), newIndices); - result = rewriter.create(loc, el, result, i); + result = vector::InsertOp::create(rewriter, loc, el, result, i); } } else { if (auto vecType = dyn_cast(loadedElType)) { @@ -828,7 +828,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; innerIdx++) { - Value logicalValueId = rewriter.create( + Value logicalValueId = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( @@ -839,9 +839,9 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, SmallVector newIndices; getXferIndices( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - Value el = rewriter.create(op.getLoc(), loadedElType, + Value el = memref::LoadOp::create(rewriter, 
op.getLoc(), loadedElType, op.getBase(), newIndices); - result = rewriter.create( + result = vector::InsertOp::create(rewriter, op.getLoc(), el, result, ArrayRef{i, innerIdx}); } } @@ -916,10 +916,10 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, return rewriter.notifyMatchFailure(op, "not mma sync reg info"); VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - Value laneId = rewriter.create(loc, /*upperBound=*/nullptr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { - Value logicalValueId = rewriter.create( + Value logicalValueId = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( @@ -928,11 +928,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, return rewriter.notifyMatchFailure(op, "no coords"); Value el = - rewriter.create(loc, matrix, ArrayRef{i}); + vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef{i}); SmallVector newIndices; getXferIndices( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - rewriter.create(loc, el, op.getBase(), newIndices); + vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); @@ -1015,7 +1015,7 @@ convertExtractStridedSlice(RewriterBase &rewriter, else if (offsets[1]) sliceOffset[0] = (warpVectorShape[1] / offsets[1]); - Value newOp = rewriter.create( + Value newOp = vector::ExtractStridedSliceOp::create(rewriter, loc, sourceVector, sliceOffset, sliceShape, strides); valueMapping[op] = newOp; @@ -1035,7 +1035,7 @@ convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, itC == valueMapping.end()) return rewriter.notifyMatchFailure(op, "no mapping"); Value opA = itA->second, opB = itB->second, opC = itC->second; - Value matmul = 
rewriter.create( + Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), /*b_transpose=*/UnitAttr()); valueMapping[op.getResult()] = matmul; @@ -1058,7 +1058,7 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, int64_t m = cast(op.getLhs().getType()).getShape()[0]; int64_t n = cast(op.getRhs().getType()).getShape()[0]; int64_t k = cast(op.getLhs().getType()).getShape()[1]; - Value matmul = rewriter.create( + Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); valueMapping[op.getResult()] = matmul; return success(); @@ -1076,12 +1076,12 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, auto splat = cast(op.getValue()).getSplatValue(); auto scalarConstant = - rewriter.create(op.getLoc(), splat.getType(), splat); + arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); auto vecType = cast(op.getType()); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); - auto matrix = rewriter.create( + auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(), type, scalarConstant); valueMapping[op.getResult()] = matrix; return success(); @@ -1100,7 +1100,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, auto vecType = op.getResultVectorType(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); - auto matrix = rewriter.create( + auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(), type, op.getSource()); valueMapping[op.getResult()] = matrix; return success(); @@ -1118,7 +1118,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, rewriter.setInsertionPoint(loop); auto operands = 
llvm::to_vector<4>(loop.getInitArgs()); llvm::append_range(operands, newInitArgs); - scf::ForOp newLoop = rewriter.create( + scf::ForOp newLoop = scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), operands); rewriter.eraseBlock(newLoop.getBody()); @@ -1189,7 +1189,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; yieldOperands.push_back(it->second); } - rewriter.create(op.getLoc(), yieldOperands); + scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); rewriter.eraseOp(op); @@ -1220,7 +1220,7 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op, resultType.getOperand()); } - Value newOp = rewriter.create( + Value newOp = gpu::SubgroupMmaElementwiseOp::create(rewriter, op->getLoc(), resultType, matrixOperands, opType); valueMapping[op->getResult(0)] = newOp; return success(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 501d98862672d..d21833736a718 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -43,13 +43,13 @@ static Value insertOne(ConversionPatternRewriter &rewriter, assert(rank > 0 && "0-D vector corner case should have been handled already"); if (rank == 1) { auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create( + auto constant = LLVM::ConstantOp::create(rewriter, loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create(loc, llvmType, val1, val2, + return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2, constant); } - return rewriter.create(loc, val1, val2, pos); + return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos); } // Helper that picks the proper sequence for extracting. 
@@ -58,13 +58,13 @@ static Value extractOne(ConversionPatternRewriter &rewriter, Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank <= 1) { auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create( + auto constant = LLVM::ConstantOp::create(rewriter, loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create(loc, llvmType, val, + return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val, constant); } - return rewriter.create(loc, val, pos); + return LLVM::ExtractValueOp::create(rewriter, loc, val, pos); } // Helper that returns data layout alignment of a vector. @@ -141,7 +141,7 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, auto ptrsType = LLVM::getVectorType(pType, vectorType.getDimSize(0), /*isScalable=*/vectorType.getScalableDims()[0]); - return rewriter.create( + return LLVM::GEPOp::create(rewriter, loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), base, index); } @@ -152,7 +152,7 @@ static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult) { if (auto attr = dyn_cast(foldResult)) { auto intAttr = cast(attr); - return builder.create(loc, intAttr).getResult(); + return LLVM::ConstantOp::create(builder, loc, intAttr).getResult(); } return cast(foldResult); @@ -475,7 +475,7 @@ class ReductionNeutralFPMax {}; static Value createReductionNeutralValue(ReductionNeutralZero neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create(loc, llvmType, + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getZeroAttr(llvmType)); } @@ -483,7 +483,7 @@ static Value createReductionNeutralValue(ReductionNeutralZero neutral, static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, 
rewriter.getIntegerAttr(llvmType, 1)); } @@ -491,7 +491,7 @@ static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); } @@ -499,7 +499,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getIntegerAttr( llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); @@ -509,7 +509,7 @@ static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( llvmType.getIntOrFloatBitWidth()))); @@ -519,7 +519,7 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( llvmType.getIntOrFloatBitWidth()))); @@ -529,7 +529,7 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, 
rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( llvmType.getIntOrFloatBitWidth()))); @@ -539,7 +539,7 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( llvmType.getIntOrFloatBitWidth()))); @@ -550,7 +550,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { auto floatType = cast(llvmType); - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getFloatAttr( llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), @@ -562,7 +562,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { auto floatType = cast(llvmType); - return rewriter.create( + return LLVM::ConstantOp::create(rewriter, loc, llvmType, rewriter.getFloatAttr( llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), @@ -591,7 +591,7 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, auto vShape = vType.getShape(); assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); - Value baseVecLength = rewriter.create( + Value baseVecLength = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); @@ -599,11 +599,11 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, return baseVecLength; // For a scalable vector type, create and return `vScale * baseVecLength`. 
- Value vScale = rewriter.create(loc); + Value vScale = vector::VectorScaleOp::create(rewriter, loc); vScale = - rewriter.create(loc, rewriter.getI32Type(), vScale); + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale); Value scalableVecLength = - rewriter.create(loc, baseVecLength, vScale); + arith::MulIOp::create(rewriter, loc, baseVecLength, vScale); return scalableVecLength; } @@ -616,10 +616,10 @@ static Value createIntegerReductionArithmeticOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator) { - Value result = rewriter.create(loc, llvmType, vectorOperand); + Value result = LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand); if (accumulator) - result = rewriter.create(loc, accumulator, result); + result = ScalarOp::create(rewriter, loc, accumulator, result); return result; } @@ -631,11 +631,11 @@ template static Value createIntegerReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { - Value result = rewriter.create(loc, llvmType, vectorOperand); + Value result = LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand); if (accumulator) { Value cmp = - rewriter.create(loc, predicate, accumulator, result); - result = rewriter.create(loc, cmp, accumulator, result); + LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result); + result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result); } return result; } @@ -666,7 +666,7 @@ static Value createFPReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { Value result = - rewriter.create(loc, llvmType, vectorOperand, fmf); + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf); if (accumulator) { result = @@ -702,7 +702,7 @@ static Value 
createMaskNeutralValue(ConversionPatternRewriter &rewriter, const auto &floatSemantics = cast(llvmType).getFloatSemantics(); auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); auto denseValue = DenseElementsAttr::get(cast(vectorType), value); - return rewriter.create(loc, vectorType, denseValue); + return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue); } /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked @@ -717,7 +717,7 @@ lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, Value mask, LLVM::FastmathFlagsAttr fmf) { const Value vectorMaskNeutral = createMaskNeutralValue( rewriter, loc, llvmType, vectorOperand.getType()); - const Value selectedVectorByMask = rewriter.create( + const Value selectedVectorByMask = LLVM::SelectOp::create(rewriter, loc, mask, vectorOperand, vectorMaskNeutral); return createFPReductionComparisonOpLowering( rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); @@ -730,7 +730,7 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, Value accumulator, LLVM::FastmathFlagsAttr fmf) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - return rewriter.create(loc, llvmType, + return LLVMRedIntrinOp::create(rewriter, loc, llvmType, /*startValue=*/accumulator, vectorOperand, fmf); } @@ -745,7 +745,7 @@ lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, Value vectorOperand, Value accumulator) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - return rewriter.create(loc, llvmType, + return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, /*startValue=*/accumulator, vectorOperand); } @@ -758,7 +758,7 @@ static Value lowerPredicatedReductionWithStartValue( llvmType, accumulator); Value vectorLength = createVectorLengthValue(rewriter, loc, vectorOperand.getType()); - return rewriter.create(loc, llvmType, + return LLVMVPRedIntrinOp::create(rewriter, loc, 
llvmType, /*startValue=*/accumulator, vectorOperand, mask, vectorLength); } @@ -1071,7 +1071,7 @@ class VectorShuffleOpConversion // For rank 0 and 1, where both operands have *exactly* the same vector // type, there is direct shuffle support in LLVM. Use it! if (rank <= 1 && v1Type == v2Type) { - Value llvmShuffleOp = rewriter.create( + Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(rewriter, loc, adaptor.getV1(), adaptor.getV2(), llvm::to_vector_of(mask)); rewriter.replaceOp(shuffleOp, llvmShuffleOp); @@ -1085,7 +1085,7 @@ class VectorShuffleOpConversion eltType = arrayType.getElementType(); else eltType = cast(llvmType).getElementType(); - Value insert = rewriter.create(loc, llvmType); + Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType); int64_t insPos = 0; for (int64_t extPos : mask) { Value value = adaptor.getV1(); @@ -1122,7 +1122,7 @@ class VectorExtractElementOpConversion if (vectorType.getRank() == 0) { Location loc = extractEltOp.getLoc(); auto idxType = rewriter.getIndexType(); - auto zero = rewriter.create( + auto zero = LLVM::ConstantOp::create(rewriter, loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( @@ -1193,12 +1193,12 @@ class VectorExtractOpConversion if (!llvm::all_of(position, llvm::IsaPred)) { return failure(); } - extracted = rewriter.create( + extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted, getAsIntegers(position)); } if (extractsScalar) { - extracted = rewriter.create( + extracted = LLVM::ExtractElementOp::create(rewriter, loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back())); } @@ -1256,7 +1256,7 @@ class VectorInsertElementOpConversion if (vectorType.getRank() == 0) { Location loc = insertEltOp.getLoc(); auto idxType = rewriter.getIndexType(); - auto zero = rewriter.create( + auto zero = LLVM::ConstantOp::create(rewriter, loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( 
@@ -1342,7 +1342,7 @@ class VectorInsertOpConversion // llvm.extractvalue does not support dynamic dimensions. return failure(); } - sourceAggregate = rewriter.create( + sourceAggregate = LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getDest(), getAsIntegers(positionOf1DVectorWithinAggregate)); } else { @@ -1351,7 +1351,7 @@ class VectorInsertOpConversion sourceAggregate = adaptor.getDest(); } // Insert the scalar into the 1D vector. - sourceAggregate = rewriter.create( + sourceAggregate = LLVM::InsertElementOp::create(rewriter, loc, sourceAggregate.getType(), sourceAggregate, adaptor.getValueToStore(), getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector)); @@ -1359,7 +1359,7 @@ class VectorInsertOpConversion Value result = sourceAggregate; if (isNestedAggregate) { - result = rewriter.create( + result = LLVM::InsertValueOp::create(rewriter, loc, adaptor.getDest(), sourceAggregate, getAsIntegers(positionOf1DVectorWithinAggregate)); } @@ -1439,15 +1439,15 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern { auto loc = op.getLoc(); auto elemType = vType.getElementType(); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value desc = rewriter.create(loc, vType, zero); + Value desc = vector::SplatOp::create(rewriter, loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { - Value extrLHS = rewriter.create(loc, op.getLhs(), i); - Value extrRHS = rewriter.create(loc, op.getRhs(), i); - Value extrACC = rewriter.create(loc, op.getAcc(), i); - Value fma = rewriter.create(loc, extrLHS, extrRHS, extrACC); - desc = rewriter.create(loc, fma, desc, i); + Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i); + Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i); + Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i); + Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC); + desc = 
InsertOp::create(rewriter, loc, fma, desc, i); } rewriter.replaceOp(op, desc); return success(); @@ -1537,7 +1537,7 @@ class VectorTypeCastOpConversion desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); - auto zero = rewriter.create(loc, int64Ty, attr); + auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr); desc.setOffset(rewriter, loc, zero); // Fill size and stride descriptors in memref. @@ -1546,11 +1546,11 @@ class VectorTypeCastOpConversion int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); - auto size = rewriter.create(loc, int64Ty, sizeAttr); + auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr); desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), (*targetStrides)[index]); - auto stride = rewriter.create(loc, int64Ty, strideAttr); + auto stride = LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr); desc.setStride(rewriter, loc, index, stride); } @@ -1578,13 +1578,13 @@ class VectorCreateMaskOpConversion IntegerType idxType = force32BitVectorIndices ? 
rewriter.getI32Type() : rewriter.getI64Type(); auto loc = op->getLoc(); - Value indices = rewriter.create( + Value indices = LLVM::StepVectorOp::create(rewriter, loc, LLVM::getVectorType(idxType, dstType.getShape()[0], /*isScalable=*/true)); auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, adaptor.getOperands()[0]); - Value bounds = rewriter.create(loc, indices.getType(), bound); - Value comp = rewriter.create(loc, arith::CmpIPredicate::slt, + Value bounds = SplatOp::create(rewriter, loc, indices.getType(), bound); + Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); rewriter.replaceOp(op, comp); return success(); @@ -1741,15 +1741,15 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { switch (conversion) { case PrintConversion::ZeroExt64: - value = rewriter.create( + value = arith::ExtUIOp::create(rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::SignExt64: - value = rewriter.create( + value = arith::ExtSIOp::create(rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::Bitcast16: - value = rewriter.create( + value = LLVM::BitcastOp::create(rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value); break; case PrintConversion::None: @@ -1762,7 +1762,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { // Helper to emit a call. static void emitCall(ConversionPatternRewriter &rewriter, Location loc, Operation *ref, ValueRange params = ValueRange()) { - rewriter.create(loc, TypeRange(), SymbolRefAttr::get(ref), + LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref), params); } }; @@ -1782,8 +1782,8 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern { // First insert it into a poison vector so we can shuffle it. 
auto vectorType = typeConverter->convertType(splatOp.getType()); Value poison = - rewriter.create(splatOp.getLoc(), vectorType); - auto zero = rewriter.create( + LLVM::PoisonOp::create(rewriter, splatOp.getLoc(), vectorType); + auto zero = LLVM::ConstantOp::create(rewriter, splatOp.getLoc(), typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); @@ -1796,7 +1796,7 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern { } // For 1-d vector, we additionally do a `vectorshuffle`. - auto v = rewriter.create( + auto v = LLVM::InsertElementOp::create(rewriter, splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero); int64_t width = cast(splatOp.getType()).getDimSize(0); @@ -1832,26 +1832,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern { return failure(); // Construct returned value. - Value desc = rewriter.create(loc, llvmNDVectorTy); + Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. - Value vdesc = rewriter.create(loc, llvm1DVectorTy); - auto zero = rewriter.create( + Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy); + auto zero = LLVM::ConstantOp::create(rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - Value v = rewriter.create(loc, llvm1DVectorTy, vdesc, + Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy, vdesc, adaptor.getInput(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); SmallVector zeroValues(width, 0); - v = rewriter.create(loc, v, v, zeroValues); + v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues); // Iterate of linear index, convert to coords space and insert splatted 1-D // vector in each position. 
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef position) { - desc = rewriter.create(loc, desc, v, position); + desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position); }); rewriter.replaceOp(splatOp, desc); return success(); @@ -1921,12 +1921,12 @@ struct VectorDeinterleaveOpLowering auto deinterleaveResults = deinterleaveOp.getResultTypes(); auto packedOpResults = llvmTypeConverter->packOperationResults(deinterleaveResults); - auto intrinsic = rewriter.create( + auto intrinsic = LLVM::vector_deinterleave2::create(rewriter, loc, packedOpResults, adaptor.getSource()); - auto evenResult = rewriter.create( + auto evenResult = LLVM::ExtractValueOp::create(rewriter, loc, intrinsic->getResult(0), 0); - auto oddResult = rewriter.create( + auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc, intrinsic->getResult(0), 1); rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult}); @@ -1950,10 +1950,10 @@ struct VectorDeinterleaveOpLowering oddShuffleMask.push_back(i); } - auto poison = rewriter.create(loc, sourceType); - auto evenShuffle = rewriter.create( + auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType); + auto evenShuffle = LLVM::ShuffleVectorOp::create(rewriter, loc, adaptor.getSource(), poison, evenShuffleMask); - auto oddShuffle = rewriter.create( + auto oddShuffle = LLVM::ShuffleVectorOp::create(rewriter, loc, adaptor.getSource(), poison, oddShuffleMask); rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle}); @@ -1977,9 +1977,9 @@ struct VectorFromElementsLowering return rewriter.notifyMatchFailure(fromElementsOp, "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); - Value result = rewriter.create(loc, llvmType); + Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = rewriter.create(loc, val, result, idx); + result = vector::InsertOp::create(rewriter, loc, val, result, 
idx); rewriter.replaceOp(fromElementsOp, result); return success(); } @@ -2003,11 +2003,11 @@ struct VectorToElementsLowering if (element.use_empty()) continue; - auto constIdx = rewriter.create( + auto constIdx = LLVM::ConstantOp::create(rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx)); auto llvmType = typeConverter->convertType(element.getType()); - Value result = rewriter.create(loc, llvmType, + Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType, source, constIdx); results[idx] = result; } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 9f5b8fcca6c26..9b79b913f2d4b 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -132,9 +132,9 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, Value value) { if (hasRetVal) { assert(value && "Expected non-empty value"); - b.create(loc, value); + scf::YieldOp::create(b, loc, value); } else { - b.create(loc); + scf::YieldOp::create(b, loc); } } @@ -154,7 +154,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { return Value(); Location loc = xferOp.getLoc(); - return b.create(loc, xferOp.getMask(), iv); + return vector::ExtractElementOp::create(b, loc, xferOp.getMask(), iv); } /// Helper function TransferOpConversion and TransferOp1dConversion. @@ -201,21 +201,21 @@ static Value generateInBoundsCheck( Value base = xferOp.getIndices()[*dim]; Value memrefIdx = affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); - cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, + cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim, memrefIdx); } // Condition check 2: Masked in? 
if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { if (cond) - cond = lb.create(cond, maskCond); + cond = arith::AndIOp::create(lb, cond, maskCond); else cond = maskCond; } // If the condition is non-empty, generate an SCF::IfOp. if (cond) { - auto check = lb.create( + auto check = scf::IfOp::create(lb, cond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { @@ -226,7 +226,7 @@ static Value generateInBoundsCheck( if (outOfBoundsCase) { maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); } else { - b.create(loc); + scf::YieldOp::create(b, loc); } }); @@ -303,14 +303,14 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { BufferAllocs result; auto bufferType = MemRefType::get({}, xferOp.getVectorType()); - result.dataBuffer = b.create(loc, bufferType); + result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType); if (xferOp.getMask()) { auto maskType = MemRefType::get({}, xferOp.getMask().getType()); - auto maskBuffer = b.create(loc, maskType); + auto maskBuffer = memref::AllocaOp::create(b, loc, maskType); b.setInsertionPoint(xferOp); - b.create(loc, xferOp.getMask(), maskBuffer); - result.maskBuffer = b.create(loc, maskBuffer, ValueRange()); + memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer); + result.maskBuffer = memref::LoadOp::create(b, loc, maskBuffer, ValueRange()); } return result; @@ -421,14 +421,14 @@ struct Strategy { auto bufferType = dyn_cast(buffer.getType()); auto vecType = dyn_cast(bufferType.getElementType()); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create( + auto newXferOp = vector::TransferReadOp::create(b, loc, vecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); maybeApplyPassLabel(b, newXferOp, options.targetRank); - b.create(loc, newXferOp.getVector(), buffer, storeIndices); + memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer, storeIndices); return 
newXferOp; } @@ -444,8 +444,8 @@ struct Strategy { Location loc = xferOp.getLoc(); auto bufferType = dyn_cast(buffer.getType()); auto vecType = dyn_cast(bufferType.getElementType()); - auto vec = b.create(loc, vecType, xferOp.getPadding()); - b.create(loc, vec, buffer, storeIndices); + auto vec = vector::SplatOp::create(b, loc, vecType, xferOp.getPadding()); + memref::StoreOp::create(b, loc, vec, buffer, storeIndices); return Value(); } @@ -506,11 +506,11 @@ struct Strategy { getXferIndices(b, xferOp, iv, xferIndices); Location loc = xferOp.getLoc(); - auto vec = b.create(loc, buffer, loadIndices); + auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto source = loopState.empty() ? xferOp.getBase() : loopState[0]; Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); - auto newXferOp = b.create( + auto newXferOp = vector::TransferWriteOp::create(b, loc, type, vec, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), inBoundsAttr); @@ -610,7 +610,7 @@ struct PrepareTransferReadConversion } Location loc = xferOp.getLoc(); - rewriter.create(loc, newXfer->getResult(0), + memref::StoreOp::create(rewriter, loc, newXfer->getResult(0), buffers.dataBuffer); rewriter.replaceOpWithNewOp(xferOp, buffers.dataBuffer); @@ -653,9 +653,9 @@ struct PrepareTransferWriteConversion Location loc = xferOp.getLoc(); auto buffers = allocBuffers(rewriter, xferOp); - rewriter.create(loc, xferOp.getVector(), + memref::StoreOp::create(rewriter, loc, xferOp.getVector(), buffers.dataBuffer); - auto loadedVec = rewriter.create(loc, buffers.dataBuffer); + auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer); rewriter.modifyOpInPlace(xferOp, [&]() { xferOp.getValueToStoreMutable().assign(loadedVec); xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); @@ -735,17 +735,17 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { auto 
signlessTargetVectorType = vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy)); auto targetVectorType = vectorType.cloneWith({}, legalIntTy); - value = rewriter.create(loc, signlessSourceVectorType, + value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType, value); if (value.getType() != signlessTargetVectorType) { if (width == 1 || intTy.isUnsigned()) - value = rewriter.create(loc, signlessTargetVectorType, + value = arith::ExtUIOp::create(rewriter, loc, signlessTargetVectorType, value); else - value = rewriter.create(loc, signlessTargetVectorType, + value = arith::ExtSIOp::create(rewriter, loc, signlessTargetVectorType, value); } - value = rewriter.create(loc, targetVectorType, value); + value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value); vectorType = targetVectorType; } @@ -763,28 +763,28 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { std::multiplies()); auto flatVectorType = VectorType::get({flatLength}, vectorType.getElementType()); - value = rewriter.create(loc, flatVectorType, value); + value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value); } vector::PrintOp firstClose; SmallVector loopIndices; for (unsigned d = 0; d < shape.size(); d++) { // Setup loop bounds and step. 
- Value lowerBound = rewriter.create(loc, 0); - Value upperBound = rewriter.create(loc, shape[d]); - Value step = rewriter.create(loc, 1); + Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value upperBound = arith::ConstantIndexOp::create(rewriter, loc, shape[d]); + Value step = arith::ConstantIndexOp::create(rewriter, loc, 1); if (!scalableDimensions.empty() && scalableDimensions[d]) { - auto vscale = rewriter.create( + auto vscale = vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); - upperBound = rewriter.create(loc, upperBound, vscale); + upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale); } - auto lastIndex = rewriter.create(loc, upperBound, step); + auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step); // Create a loop to print the elements surrounded by parentheses. - rewriter.create(loc, vector::PrintPunctuation::Open); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); auto loop = - rewriter.create(loc, lowerBound, upperBound, step); - auto printClose = rewriter.create( + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); + auto printClose = vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close); if (!firstClose) firstClose = printClose; @@ -794,13 +794,13 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { // Print a comma after all but the last element. 
rewriter.setInsertionPointToStart(loop.getBody()); - auto notLastIndex = rewriter.create( + auto notLastIndex = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); - rewriter.create(loc, notLastIndex, + scf::IfOp::create(rewriter, loc, notLastIndex, [&](OpBuilder &builder, Location loc) { - builder.create( + vector::PrintOp::create(builder, loc, vector::PrintPunctuation::Comma); - builder.create(loc); + scf::YieldOp::create(builder, loc); }); rewriter.setInsertionPointToStart(loop.getBody()); @@ -811,10 +811,10 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { Value flatIndex; auto currentStride = 1; for (int d = shape.size() - 1; d >= 0; d--) { - auto stride = rewriter.create(loc, currentStride); - auto index = rewriter.create(loc, stride, loopIndices[d]); + auto stride = arith::ConstantIndexOp::create(rewriter, loc, currentStride); + auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]); if (flatIndex) - flatIndex = rewriter.create(loc, flatIndex, index); + flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index); else flatIndex = index; currentStride *= shape[d]; @@ -822,12 +822,12 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern { // Print the scalar elements in the inner most loop. 
auto element = - rewriter.create(loc, value, flatIndex); - rewriter.create(loc, element, + vector::ExtractElementOp::create(rewriter, loc, value, flatIndex); + vector::PrintOp::create(rewriter, loc, element, vector::PrintPunctuation::NoPunctuation); rewriter.setInsertionPointAfter(firstClose); - rewriter.create(loc, printOp.getPunctuation()); + vector::PrintOp::create(rewriter, loc, printOp.getPunctuation()); rewriter.eraseOp(printOp); return success(); } @@ -918,7 +918,7 @@ struct TransferOpConversion : public VectorToSCFPattern { "Failed to unpack one vector dim."); auto castedDataBuffer = - locB.create(*castedDataType, dataBuffer); + vector::TypeCastOp::create(locB, *castedDataType, dataBuffer); // If the xferOp has a mask: Find and cast mask buffer. Value castedMaskBuffer; @@ -937,21 +937,21 @@ struct TransferOpConversion : public VectorToSCFPattern { auto maskBufferType = cast(maskBuffer.getType()); MemRefType castedMaskType = *unpackOneDim(maskBufferType); castedMaskBuffer = - locB.create(castedMaskType, maskBuffer); + vector::TypeCastOp::create(locB, castedMaskType, maskBuffer); } } // Loop bounds and step. - auto lb = locB.create(0); - auto ub = locB.create( + auto lb = arith::ConstantIndexOp::create(locB, 0); + auto ub = arith::ConstantIndexOp::create(locB, castedDataType->getDimSize(castedDataType->getRank() - 1)); - auto step = locB.create(1); + auto step = arith::ConstantIndexOp::create(locB, 1); // TransferWriteOps that operate on tensors return the modified tensor and // require a loop state. auto loopState = Strategy::initialLoopState(xferOp); // Generate for loop. - auto result = locB.create( + auto result = scf::ForOp::create(locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { Type stateType = loopState.empty() ? 
Type() : loopState[0].getType(); @@ -977,7 +977,7 @@ struct TransferOpConversion : public VectorToSCFPattern { SmallVector loadIndices; getMaskBufferLoadIndices(xferOp, castedMaskBuffer, loadIndices, iv); - auto mask = b.create(loc, castedMaskBuffer, + auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer, loadIndices); rewriter.modifyOpInPlace(newXfer, [&]() { newXfer.getMaskMutable().assign(mask); @@ -1121,29 +1121,29 @@ struct ScalableTransposeTransferWriteConversion auto transposeSource = transposeOp.getVector(); SmallVector transposeSourceSlices = llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { - return rewriter.create(loc, transposeSource, idx); + return vector::ExtractOp::create(rewriter, loc, transposeSource, idx); }); // Loop bounds and step. - auto lb = rewriter.create(loc, 0); + auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0); auto ub = maskDims->empty() ? Value(createVscaleMultiple(vectorType.getDimSize(0))) : vector::getAsValues(rewriter, loc, maskDims->front()).front(); - auto step = rewriter.create(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); // Generate a new mask for the slice. VectorType sliceType = VectorType::Builder(vectorType).dropDim(0); Value sliceMask = nullptr; if (!maskDims->empty()) { - sliceMask = rewriter.create( + sliceMask = vector::CreateMaskOp::create(rewriter, loc, sliceType.clone(rewriter.getI1Type()), ArrayRef(*maskDims).drop_front()); } Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{}; ValueRange initLoopArgs = initDest ? initDest : ValueRange{}; - auto result = rewriter.create( + auto result = scf::ForOp::create(rewriter, loc, lb, ub, step, initLoopArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) { // Indices for the new transfer op. @@ -1153,23 +1153,23 @@ struct ScalableTransposeTransferWriteConversion // Extract a transposed slice from the source vector. 
SmallVector transposeElements = llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { - return b.create( + return vector::ExtractOp::create(b, loc, transposeSourceSlices[idx], iv); }); - auto sliceVec = b.create(loc, sliceType, + auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType, transposeElements); // Create the transfer_write for the slice. Value dest = loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front(); - auto newWriteOp = b.create( + auto newWriteOp = vector::TransferWriteOp::create(b, loc, sliceVec, dest, xferIndices, ArrayRef(writeOp.getInBoundsValues()).drop_front()); if (sliceMask) newWriteOp.getMaskMutable().assign(sliceMask); // Yield from the loop. - b.create(loc, loopIterArgs.empty() + scf::YieldOp::create(b, loc, loopIterArgs.empty() ? ValueRange{} : newWriteOp.getResult()); }); @@ -1209,7 +1209,7 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, llvm::SmallVector indices({i}); Location loc = xferOp.getLoc(); - auto newMask = b.create(loc, xferOp.getMask(), indices); + auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices); newXferOp.getMaskMutable().assign(newMask); } @@ -1263,7 +1263,7 @@ struct UnrollTransferReadConversion if (auto insertOp = getInsertOp(xferOp)) return insertOp.getDest(); Location loc = xferOp.getLoc(); - return rewriter.create(loc, xferOp.getVectorType(), + return vector::SplatOp::create(rewriter, loc, xferOp.getVectorType(), xferOp.getPadding()); } @@ -1319,7 +1319,7 @@ struct UnrollTransferReadConversion // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = arith::ConstantIndexOp::create(rewriter, loc, i); // FIXME: Rename this lambda - it does much more than just // in-bounds-check generation. 
@@ -1338,7 +1338,7 @@ struct UnrollTransferReadConversion auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create( + auto newXferOp = vector::TransferReadOp::create(b, loc, newXferVecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); @@ -1348,10 +1348,10 @@ struct UnrollTransferReadConversion if (newXferVecType.getRank() == 0) { // vector.insert does not accept rank-0 as the non-indexed // argument. Extract the scalar before inserting. - valToInser = b.create(loc, valToInser, + valToInser = vector::ExtractOp::create(b, loc, valToInser, SmallVector()); } - return b.create(loc, valToInser, vec, + return vector::InsertOp::create(b, loc, valToInser, vec, insertionIndices); }, /*outOfBoundsCase=*/ @@ -1462,7 +1462,7 @@ struct UnrollTransferWriteConversion // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = arith::ConstantIndexOp::create(rewriter, loc, i); auto updatedSource = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), @@ -1479,19 +1479,19 @@ struct UnrollTransferWriteConversion extractionIndices.push_back(b.getI64IntegerAttr(i)); auto extracted = - b.create(loc, vec, extractionIndices); + vector::ExtractOp::create(b, loc, vec, extractionIndices); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); Value xferVec; if (inputVectorTy.getRank() == 1) { // When target-rank=0, unrolling would causes the vector input // argument into `transfer_write` to become a scalar. We solve // this by broadcasting the scalar to a 0D vector. 
- xferVec = b.create( + xferVec = vector::BroadcastOp::create(b, loc, VectorType::get({}, extracted.getType()), extracted); } else { xferVec = extracted; } - auto newXferOp = b.create( + auto newXferOp = vector::TransferWriteOp::create(b, loc, sourceType, xferVec, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), inBoundsAttr); @@ -1574,18 +1574,18 @@ struct Strategy1d { b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), /*inBoundsCase=*/ [&](OpBuilder &b, Location loc) { - Value val = b.create(loc, xferOp.getBase(), indices); - return b.create(loc, val, vec, iv); + Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices); + return vector::InsertElementOp::create(b, loc, val, vec, iv); }, /*outOfBoundsCase=*/ [&](OpBuilder & /*b*/, Location loc) { return vec; }); - b.create(loc, nextVec); + scf::YieldOp::create(b, loc, nextVec); } static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) { // Inititalize vector with padding value. Location loc = xferOp.getLoc(); - return b.create(loc, xferOp.getVectorType(), + return vector::SplatOp::create(b, loc, xferOp.getVectorType(), xferOp.getPadding()); } }; @@ -1604,10 +1604,10 @@ struct Strategy1d { b, xferOp, iv, dim, /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { auto val = - b.create(loc, xferOp.getVector(), iv); - b.create(loc, val, xferOp.getBase(), indices); + vector::ExtractElementOp::create(b, loc, xferOp.getVector(), iv); + memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices); }); - b.create(loc); + scf::YieldOp::create(b, loc); } static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) { @@ -1668,15 +1668,15 @@ struct TransferOp1dConversion : public VectorToSCFPattern { // Loop bounds, step, state... 
Location loc = xferOp.getLoc(); auto vecType = xferOp.getVectorType(); - auto lb = rewriter.create(loc, 0); + auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0); Value ub = - rewriter.create(loc, vecType.getDimSize(0)); + arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0)); if (vecType.isScalable()) { Value vscale = - rewriter.create(loc, rewriter.getIndexType()); - ub = rewriter.create(loc, ub, vscale); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + ub = arith::MulIOp::create(rewriter, loc, ub, vscale); } - auto step = rewriter.create(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loopState = Strategy1d::initialLoopState(rewriter, xferOp); // Generate for loop. diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 21d8e1d9f1156..2466e72277441 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -147,18 +147,18 @@ static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter, Location loc, Value dynamicIndex, int64_t kPoisonIndex, unsigned vectorSize) { if (llvm::isPowerOf2_32(vectorSize)) { - Value inBoundsMask = rewriter.create( + Value inBoundsMask = spirv::ConstantOp::create(rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1)); - return rewriter.create(loc, dynamicIndex, + return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex, inBoundsMask); } - Value poisonIndex = rewriter.create( + Value poisonIndex = spirv::ConstantOp::create(rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex)); Value cmpResult = - rewriter.create(loc, dynamicIndex, poisonIndex); - return rewriter.create( + spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex); + return spirv::SelectOp::create(rewriter, loc, cmpResult, 
spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter), dynamicIndex); @@ -427,7 +427,7 @@ static SmallVector extractAllElements( Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { - values.push_back(rewriter.create( + values.push_back(spirv::CompositeExtractOp::create(rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(), rewriter.getI32ArrayAttr({i}))); } @@ -481,16 +481,16 @@ struct VectorReductionPattern final : OpConversionPattern { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ if (llvm::isa(resultType)) { \ - result = rewriter.create(loc, resultType, result, next); \ + result = spirv::iop::create(rewriter, loc, resultType, result, next); \ } else { \ assert(llvm::isa(resultType)); \ - result = rewriter.create(loc, resultType, result, next); \ + result = spirv::fop::create(rewriter, loc, resultType, result, next); \ } \ break #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ + result = fop::create(rewriter, loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); @@ -537,7 +537,7 @@ struct VectorReductionFloatMinMax final #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ + result = fop::create(rewriter, loc, resultType, result, next); \ break INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); @@ -613,7 +613,7 @@ struct VectorShuffleOpConvert final auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( Value scalarOrVec, int32_t idx) -> Value { if (auto vecTy = dyn_cast(scalarOrVec.getType())) - return rewriter.create(loc, scalarOrVec, + return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec, idx); assert(idx == 0 && "Invalid scalar element index"); @@ -712,10 +712,10 @@ struct VectorDeinterleaveOpConvert final // We cannot use `spirv::VectorShuffleOp` directly in this case, 
and need to // use `spirv::CompositeExtractOp`. if (n == 2) { - auto elem0 = rewriter.create( + auto elem0 = spirv::CompositeExtractOp::create(rewriter, loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0})); - auto elem1 = rewriter.create( + auto elem1 = spirv::CompositeExtractOp::create(rewriter, loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1})); rewriter.replaceOp(deinterleaveOp, {elem0, elem1}); @@ -733,11 +733,11 @@ struct VectorDeinterleaveOpConvert final llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; }); // Create two SPIR-V shuffles. - auto shuffleEven = rewriter.create( + auto shuffleEven = spirv::VectorShuffleOp::create(rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesEven)); - auto shuffleOdd = rewriter.create( + auto shuffleOdd = spirv::VectorShuffleOp::create(rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesOdd)); @@ -781,7 +781,7 @@ struct VectorLoadOpConverter final // to a scalar. Value castedAccessChain = (vectorType.getNumElements() == 1) ? accessChain - : rewriter.create( + : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, accessChain); rewriter.replaceOpWithNewOp(loadOp, spirvVectorType, @@ -823,7 +823,7 @@ struct VectorStoreOpConverter final // to a scalar. Value castedAccessChain = (vectorType.getNumElements() == 1) ? 
accessChain - : rewriter.create( + : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, accessChain); rewriter.replaceOpWithNewOp(storeOp, castedAccessChain, @@ -905,9 +905,9 @@ struct VectorReductionToIntDotProd final auto v4i8Type = VectorType::get({4}, i8Type); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); - lhsIn = rewriter.create( + lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, ValueRange{lhsIn, zero}); - rhsIn = rewriter.create( + rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, ValueRange{rhsIn, zero}); } @@ -971,14 +971,14 @@ struct VectorReductionToFPDotProd final Attribute oneAttr = rewriter.getFloatAttr(vectorType.getElementType(), 1.0); oneAttr = SplatElementsAttr::get(vectorType, oneAttr); - rhs = rewriter.create(loc, vectorType, oneAttr); + rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr); } assert(lhs); assert(rhs); - Value res = rewriter.create(loc, resultType, lhs, rhs); + Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs); if (acc) - res = rewriter.create(loc, acc, res); + res = spirv::FAddOp::create(rewriter, loc, acc, res); rewriter.replaceOp(op, res); return success(); @@ -1013,7 +1013,7 @@ struct VectorStepOpConvert final : OpConversionPattern { source.reserve(numElements); for (int64_t i = 0; i < numElements; ++i) { Attribute intAttr = rewriter.getIntegerAttr(intType, i); - Value constOp = rewriter.create(loc, intType, intAttr); + Value constOp = spirv::ConstantOp::create(rewriter, loc, intType, intAttr); source.push_back(constOp); } rewriter.replaceOpWithNewOp(stepOp, dstType, @@ -1056,7 +1056,7 @@ struct VectorToElementOpConvert final if (element.use_empty()) continue; - Value result = rewriter.create( + Value result = spirv::CompositeExtractOp::create(rewriter, loc, elementType, adaptor.getSource(), rewriter.getI32ArrayAttr({static_cast(idx)})); results[idx] = result; diff --git 
a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 0ec7129a40a66..255013ee07c88 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -108,7 +108,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, xegpu::CreateNdDescOp ndDesc; if (srcTy.hasStaticShape()) { - ndDesc = rewriter.create(loc, descType, src, + ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src, getAsOpFoldResult(offsets)); } else { // In case of any dynamic shapes, source's shape and strides have to be @@ -116,7 +116,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, SmallVector sourceDims; unsigned srcRank = srcTy.getRank(); for (unsigned i = 0; i < srcRank; ++i) - sourceDims.push_back(rewriter.create(loc, src, i)); + sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i)); SmallVector constOffsets; SmallVector dynOffsets; @@ -135,17 +135,17 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, // Compute strides in reverse order. SmallVector dynStrides; - Value accStride = rewriter.create(loc, 1); + Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1); // Last stride is guaranteed to be static and unit. 
for (int i = static_cast(strides.size()) - 2; i >= 0; --i) { accStride = - rewriter.create(loc, accStride, sourceDims[i + 1]); + arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]); if (strides[i] == ShapedType::kDynamic) dynStrides.push_back(accStride); } std::reverse(dynStrides.begin(), dynStrides.end()); - ndDesc = rewriter.create( + ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides, DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets), DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()), @@ -200,7 +200,7 @@ struct TransferReadLowering : public OpRewritePattern { ArrayRef{1, 0}); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadOp = rewriter.create( + auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); @@ -238,7 +238,7 @@ struct TransferWriteLowering // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; auto storeOp = - rewriter.create(loc, writeOp.getVector(), ndDesc, + xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(writeOp, storeOp); @@ -269,7 +269,7 @@ struct LoadLowering : public OpRewritePattern { // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadNdOp = rewriter.create( + auto loadNdOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); @@ -303,7 +303,7 @@ struct StoreLowering : public OpRewritePattern { // By default, no specific caching policy is assigned. 
xegpu::CachePolicyAttr hint = nullptr; auto storeNdOp = - rewriter.create(loc, vector, ndDesc, + xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(storeOp, storeNdOp); @@ -346,7 +346,7 @@ struct ContractionLowering : public OpRewritePattern { return rewriter.notifyMatchFailure(contractOp, "Invalid operand dimensions"); - auto dpasOp = rewriter.create( + auto dpasOp = xegpu::DpasOp::create(rewriter, loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc}); rewriter.replaceOp(contractOp, dpasOp); diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 4613d14461969..79859f0513e84 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index fd2ba0683786e..4a1a2487806a7 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -100,8 +100,8 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); Type allBitsType = rewriter.getIntegerType(bitwidth); auto allBitsVecType = VectorType::get({1}, allBitsType); - Value bitcast = rewriter.create(loc, allBitsVecType, val); - Value scalar = rewriter.create(loc, bitcast, 0); + Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val); + Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0); return scalar; } @@ -120,25 +120,25 @@ LogicalResult 
RawBufferAtomicByCasPattern::matchAndRewrite( SmallVector loadAttrs; patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop); Value initialLoad = - rewriter.create(loc, dataType, invariantArgs, loadAttrs); + RawBufferLoadOp::create(rewriter, loc, dataType, invariantArgs, loadAttrs); Block *currentBlock = rewriter.getInsertionBlock(); Block *afterAtomic = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc}); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, loopBlock, initialLoad); + cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad); rewriter.setInsertionPointToEnd(loopBlock); Value prevLoad = loopBlock->getArgument(0); - Value operated = rewriter.create(loc, data, prevLoad); + Value operated = ArithOp::create(rewriter, loc, data, prevLoad); dataType = operated.getType(); SmallVector cmpswapAttrs; patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate); SmallVector cmpswapArgs = {operated, prevLoad}; cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end()); - Value atomicRes = rewriter.create( + Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType, cmpswapArgs, cmpswapAttrs); // We care about exact bitwise equality here, so do some bitcasts. 
@@ -151,13 +151,13 @@ LogicalResult RawBufferAtomicByCasPattern::matchAndRewrite( if (auto floatDataTy = dyn_cast(dataType)) { Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth()); prevLoadForCompare = - rewriter.create(loc, equivInt, prevLoad); + arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad); atomicResForCompare = - rewriter.create(loc, equivInt, atomicRes); + arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes); } - Value canLeave = rewriter.create( + Value canLeave = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare); - rewriter.create(loc, canLeave, afterAtomic, ValueRange{}, + cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{}, loopBlock, atomicRes); rewriter.eraseOp(atomicOp); return success(); diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp index 9a368f372c296..cba0611f05ef8 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp @@ -54,9 +54,9 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::MaskedLoadOp maskedOp) { VectorType vectorType = maskedOp.getVectorType(); - Value load = builder.create( + Value load = vector::LoadOp::create(builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices()); - Value res = builder.create( + Value res = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru()); return res; } @@ -87,7 +87,7 @@ struct MaskedLoadLowering final : OpRewritePattern { SmallVector indices = maskedOp.getIndices(); auto stridedMetadata = - rewriter.create(loc, src); + memref::ExtractStridedMetadataOp::create(rewriter, loc, src); SmallVector strides = stridedMetadata.getConstifiedMixedStrides(); SmallVector sizes = 
stridedMetadata.getConstifiedMixedSizes(); @@ -101,46 +101,46 @@ struct MaskedLoadLowering final : OpRewritePattern { // delta = bufferSize - linearizedOffset Value vectorSizeOffset = - rewriter.create(loc, vectorSize); + arith::ConstantIndexOp::create(rewriter, loc, vectorSize); Value linearIndex = getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices); Value totalSize = getValueOrCreateConstantIndexOp( rewriter, loc, linearizedInfo.linearizedSize); - Value delta = rewriter.create(loc, totalSize, linearIndex); + Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex); // 1) check if delta < vectorSize - Value isOutofBounds = rewriter.create( + Value isOutofBounds = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset); // 2) check if (detla % elements_per_word != 0) - Value elementsPerWord = rewriter.create( + Value elementsPerWord = arith::ConstantIndexOp::create(rewriter, loc, llvm::divideCeil(32, elementBitWidth)); - Value isNotWordAligned = rewriter.create( + Value isNotWordAligned = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, - rewriter.create(loc, delta, elementsPerWord), - rewriter.create(loc, 0)); + arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord), + arith::ConstantIndexOp::create(rewriter, loc, 0)); // We take the fallback of maskedload default lowering only it is both // out-of-bounds and not word aligned. The fallback ensures correct results // when loading at the boundary of the buffer since buffer load returns // inconsistent zeros for the whole word when boundary is crossed. 
Value ifCondition = - rewriter.create(loc, isOutofBounds, isNotWordAligned); + arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned); auto thenBuilder = [&](OpBuilder &builder, Location loc) { Operation *read = builder.clone(*maskedOp.getOperation()); read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr()); Value readResult = read->getResult(0); - builder.create(loc, readResult); + scf::YieldOp::create(builder, loc, readResult); }; auto elseBuilder = [&](OpBuilder &builder, Location loc) { Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp); - rewriter.create(loc, res); + scf::YieldOp::create(rewriter, loc, res); }; auto ifOp = - rewriter.create(loc, ifCondition, thenBuilder, elseBuilder); + scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder); rewriter.replaceOp(maskedOp, ifOp); diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp index 195f59d625554..4771d655e36c2 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp @@ -37,7 +37,7 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final return rewriter.notifyMatchFailure(metadataOp, "not a fat raw buffer cast"); Location loc = castOp.getLoc(); - auto sourceMetadata = rewriter.create( + auto sourceMetadata = memref::ExtractStridedMetadataOp::create(rewriter, loc, castOp.getSource()); SmallVector results; if (metadataOp.getBaseBuffer().use_empty()) { @@ -48,13 +48,13 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final if (baseBufferType == castOp.getResult().getType()) { results.push_back(castOp.getResult()); } else { - results.push_back(rewriter.create( + results.push_back(memref::ReinterpretCastOp::create(rewriter, loc, baseBufferType, castOp.getResult(), /*offset=*/0, /*sizes=*/ArrayRef{}, /*strides=*/ArrayRef{})); } } if (castOp.getResetOffset()) - 
results.push_back(rewriter.create(loc, 0)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); else results.push_back(sourceMetadata.getOffset()); llvm::append_range(results, sourceMetadata.getSizes()); diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 12b375b373fa9..111d93e39c15e 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" @@ -76,8 +77,8 @@ static SmallVector getTileSizes(Location loc, amx::TileType tType, auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); return SmallVector{ - rewriter.create(loc, llvmInt16Type, mattr), - rewriter.create(loc, llvmInt16Type, nattr)}; + LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr), + LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer @@ -95,7 +96,7 @@ static Value getStride(Location loc, MemRefType mType, Value base, // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = rewriter.create(loc, llvmInt64Type, attr); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); return rewriter .create(loc, llvmInt64Type, scale, memrefDescriptor.stride(rewriter, loc, preLast)) @@ -103,7 +104,7 @@ static Value getStride(Location loc, MemRefType mType, Value base, } // Use direct constant for static stride. 
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); - return rewriter.create(loc, llvmInt64Type, attr) + return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr) .getResult(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index bd4b2e56808b6..91f85f2f366e8 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" @@ -242,7 +243,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -1284,7 +1285,7 @@ mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, map = foldAttributesIntoMap(b, map, operands, valueOperands); composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin); assert(map); - return b.create(loc, map, valueOperands); + return AffineApplyOp::create(b, loc, map, valueOperands); } AffineApplyOp @@ -1391,7 +1392,7 @@ static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, SmallVector valueOperands; map = foldAttributesIntoMap(b, map, operands, valueOperands); composeMultiResultAffineMap(map, valueOperands); - return b.create(loc, b.getIndexType(), map, valueOperands); + return OpTy::create(b, loc, b.getIndexType(), map, valueOperands); } AffineMinOp @@ -1749,6 +1750,32 @@ void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, } } +AffineDmaStartOp AffineDmaStartOp::create( + OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap, + 
ValueRange srcIndices, Value destMemRef, AffineMap dstMap, + ValueRange destIndices, Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements, Value stride, + Value elementsPerStride) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap, + destIndices, tagMemRef, tagMap, tagIndices, numElements, stride, + elementsPerStride); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +AffineDmaStartOp AffineDmaStartOp::create( + ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap, + ValueRange srcIndices, Value destMemRef, AffineMap dstMap, + ValueRange destIndices, Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements, Value stride, + Value elementsPerStride) { + return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices, + destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices, + numElements, stride, elementsPerStride); +} + void AffineDmaStartOp::print(OpAsmPrinter &p) { p << " " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); @@ -1919,6 +1946,25 @@ void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result, result.addOperands(numElements); } +AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location, + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, + Value numElements) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, tagMemRef, tagMap, tagIndices, numElements); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder, + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, + Value numElements) { + return create(builder, builder.getLoc(), tagMemRef, tagMap, 
tagIndices, + numElements); +} + void AffineDmaWaitOp::print(OpAsmPrinter &p) { p << " " << getTagMemRef() << '['; SmallVector operands(getTagIndices()); @@ -2690,8 +2736,8 @@ FailureOr AffineForOp::replaceWithAdditionalYields( rewriter.setInsertionPoint(getOperation()); auto inits = llvm::to_vector(getInits()); inits.append(newInitOperands.begin(), newInitOperands.end()); - AffineForOp newLoop = rewriter.create( - getLoc(), getLowerBoundOperands(), getLowerBoundMap(), + AffineForOp newLoop = AffineForOp::create( + rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(), getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits); // Generate the new yield values and append them to the scf.yield operation. @@ -2833,7 +2879,7 @@ static void buildAffineLoopNestImpl( OpBuilder::InsertionGuard nestedGuard(nestedBuilder); bodyBuilderFn(nestedBuilder, nestedLoc, ivs); } - nestedBuilder.create(nestedLoc); + AffineYieldOp::create(nestedBuilder, nestedLoc); }; // Delegate actual loop creation to the callback in order to dispatch @@ -2848,8 +2894,8 @@ static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { - return builder.create(loc, lb, ub, step, - /*iterArgs=*/ValueRange(), bodyBuilderFn); + return AffineForOp::create(builder, loc, lb, ub, step, + /*iterArgs=*/ValueRange(), bodyBuilderFn); } /// Creates an affine loop from the bounds that may or may not be constants. 
@@ -2862,9 +2908,9 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, if (lbConst && ubConst) return buildAffineLoopFromConstants(builder, loc, lbConst.value(), ubConst.value(), step, bodyBuilderFn); - return builder.create(loc, lb, builder.getDimIdentityMap(), ub, - builder.getDimIdentityMap(), step, - /*iterArgs=*/ValueRange(), bodyBuilderFn); + return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub, + builder.getDimIdentityMap(), step, + /*iterArgs=*/ValueRange(), bodyBuilderFn); } void mlir::affine::buildAffineLoopNest( @@ -4885,7 +4931,7 @@ struct DropUnitExtentBasis Location loc = delinearizeOp->getLoc(); auto getZero = [&]() -> Value { if (!zero) - zero = rewriter.create(loc, 0); + zero = arith::ConstantIndexOp::create(rewriter, loc, 0); return zero.value(); }; @@ -4908,8 +4954,8 @@ struct DropUnitExtentBasis if (!newBasis.empty()) { // Will drop the leading nullptr from `basis` if there was no outer bound. - auto newDelinearizeOp = rewriter.create( - loc, delinearizeOp.getLinearIndex(), newBasis); + auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create( + rewriter, loc, delinearizeOp.getLinearIndex(), newBasis); int newIndex = 0; // Map back the new delinearized indices to the values they replace. 
for (auto &replacement : replacements) { @@ -4973,12 +5019,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail return success(); } - Value newLinearize = rewriter.create( - linearizeOp.getLoc(), linearizeIns.drop_back(numMatches), + Value newLinearize = affine::AffineLinearizeIndexOp::create( + rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches), ArrayRef{linearizeBasis}.drop_back(numMatches), linearizeOp.getDisjoint()); - auto newDelinearize = rewriter.create( - delinearizeOp.getLoc(), newLinearize, + auto newDelinearize = affine::AffineDelinearizeIndexOp::create( + rewriter, delinearizeOp.getLoc(), newLinearize, ArrayRef{delinearizeBasis}.drop_back(numMatches), delinearizeOp.hasOuterBound()); SmallVector mergedResults(newDelinearize.getResults()); @@ -5050,19 +5096,16 @@ struct SplitDelinearizeSpanningLastLinearizeArg final delinearizeOp, "need at least two elements to form the basis product"); - Value linearizeWithoutBack = - rewriter.create( - linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(), - linearizeOp.getDynamicBasis(), - linearizeOp.getStaticBasis().drop_back(), - linearizeOp.getDisjoint()); - auto delinearizeWithoutSplitPart = - rewriter.create( - delinearizeOp.getLoc(), linearizeWithoutBack, - delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit), - delinearizeOp.hasOuterBound()); - auto delinearizeBack = rewriter.create( - delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(), + Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create( + rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(), + linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(), + linearizeOp.getDisjoint()); + auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create( + rewriter, delinearizeOp.getLoc(), linearizeWithoutBack, + delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit), + delinearizeOp.hasOuterBound()); + auto delinearizeBack = 
affine::AffineDelinearizeIndexOp::create( + rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(), basis.take_back(elemsToSplit), /*hasOuterBound=*/true); SmallVector results = llvm::to_vector( llvm::concat(delinearizeWithoutSplitPart.getResults(), @@ -5274,7 +5317,7 @@ OpFoldResult computeProduct(Location loc, OpBuilder &builder, } if (auto constant = dyn_cast(result)) return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); - return builder.create(loc, result, dynamicPart).getResult(); + return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult(); } /// If conseceutive outputs of a delinearize_index are linearized with the same @@ -5439,16 +5482,16 @@ struct CancelLinearizeOfDelinearizePortion final newDelinBasis.erase(newDelinBasis.begin() + m.delinStart, newDelinBasis.begin() + m.delinStart + m.length); newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize); - auto newDelinearize = rewriter.create( - m.delinearize.getLoc(), m.delinearize.getLinearIndex(), + auto newDelinearize = AffineDelinearizeIndexOp::create( + rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(), newDelinBasis); // Since there may be other uses of the indices we just merged together, // create a residual affine.delinearize_index that delinearizes the // merged output into its component parts. 
Value combinedElem = newDelinearize.getResult(m.delinStart); - auto residualDelinearize = rewriter.create( - m.delinearize.getLoc(), combinedElem, basisToMerge); + auto residualDelinearize = AffineDelinearizeIndexOp::create( + rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge); // Swap all the uses of the unaffected delinearize outputs to the new // delinearization so that the old code can be removed if this diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index b1e40d9b289ec..267179f26c52e 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index c11f1bca5d49d..23facb83caebd 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -204,7 +204,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block, void AffineDataCopyGeneration::runOnOperation() { func::FuncOp f = getOperation(); OpBuilder topBuilder(f.getBody()); - zeroIndex = topBuilder.create(f.getLoc(), 0); + zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0); // Nests that are copy-in's or copy-out's; the root AffineForOps of those // nests are stored herein. 
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index c0ef28c648ac5..5ad4893c18094 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -60,7 +60,7 @@ static SmallVector computeStrides(Location loc, RewriterBase &rewriter, // Note: basis elements and their products are, definitionally, // non-negative, so `nuw` is justified. if (dynamicPart) - dynamicPart = rewriter.create( + dynamicPart = arith::MulIOp::create(rewriter, loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags); else dynamicPart = dynamicBasis[dynamicIndex - 1]; @@ -76,7 +76,7 @@ static SmallVector computeStrides(Location loc, RewriterBase &rewriter, rewriter.createOrFold(loc, staticPart); if (dynamicPart) stride = - rewriter.create(loc, dynamicPart, stride, ovflags); + arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags); result.push_back(stride); } } @@ -108,19 +108,19 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, Value zero = rewriter.createOrFold(loc, 0); Value initialPart = - rewriter.create(loc, linearIdx, strides.front()); + arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front()); results.push_back(initialPart); auto emitModTerm = [&](Value stride) -> Value { - Value remainder = rewriter.create(loc, linearIdx, stride); - Value remainderNegative = rewriter.create( + Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride); + Value remainderNegative = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, remainder, zero); // If the correction is relevant, this term is <= stride, which is known // to be positive in `index`. Otherwise, while 2 * stride might overflow, // this branch won't be taken, so the risk of `poison` is fine. 
- Value corrected = rewriter.create( + Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride, arith::IntegerOverflowFlags::nsw); - Value mod = rewriter.create(loc, remainderNegative, + Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative, corrected, remainder); return mod; }; @@ -133,7 +133,7 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, // We know both inputs are positive, so floorDiv == div. // This could potentially be a divui, but it's not clear if that would // cause issues. - Value divided = rewriter.create(loc, modulus, nextStride); + Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride); results.push_back(divided); } @@ -169,7 +169,7 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter, // our hands on an `OpOperand&` for the loop invariant counting function. for (auto [stride, idxOp] : llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) { - Value scaledIdx = rewriter.create( + Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw); int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp); scaledValues.emplace_back(scaledIdx, numHoistableLoops); @@ -186,7 +186,7 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter, Value result = scaledValues.front().first; for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) { std::ignore = numHoistableLoops; - result = rewriter.create(loc, result, scaledValue, + result = arith::AddIOp::create(rewriter, loc, result, scaledValue, arith::IntegerOverflowFlags::nsw); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index f28fb3acb7db7..bdaa2a968a0cb 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -89,7 
+89,7 @@ static AffineApplyOp createSubApply(RewriterBase &rewriter, auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx); SmallVector rhsOperands = originalOp->getOperands(); canonicalizeMapAndOperands(&rhsMap, &rhsOperands); - return rewriter.create(originalOp.getLoc(), rhsMap, + return AffineApplyOp::create(rewriter, originalOp.getLoc(), rhsMap, rhsOperands); } @@ -161,7 +161,7 @@ FailureOr mlir::affine::decompose(RewriterBase &rewriter, auto current = createSubApply(rewriter, op, subExpressions[0]); for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) { Value tmp = createSubApply(rewriter, op, subExpressions[i]); - current = rewriter.create(op.getLoc(), binMap, + current = AffineApplyOp::create(rewriter, op.getLoc(), binMap, ValueRange{current, tmp}); LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n"); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 8f42586e5d18a..bfc8a7296f554 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -426,7 +426,7 @@ static Value createPrivateMemRef(AffineForOp forOp, // consumer loop nests to reduce their live range. Currently they are added // at the beginning of the block, because loop nests can be reordered // during the fusion pass. - Value newMemRef = top.create(forOp.getLoc(), newMemRefType); + Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType); // Build an AffineMap to remap access functions based on lower bound offsets. 
SmallVector remapExprs; diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index 92cb7075005a3..123d15f2f8d39 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -102,7 +102,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { } // Create and place the alloc right before the 'affine.for' operation. - Value newMemRef = bOuter.create( + Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. @@ -110,7 +110,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { int64_t step = forOp.getStepAsInt(); auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2); - auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, + auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap, forOp.getInductionVar()); // replaceAllMemRefUsesWith will succeed unless the forOp body has @@ -132,7 +132,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { } // Insert the dealloc op right after the for loop. bOuter.setInsertionPointAfter(forOp); - bOuter.create(forOp.getLoc(), newMemRef); + memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef); return true; } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 1a266b72d1f8d..9537d3e75c26a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -51,10 +51,10 @@ OpFoldResult affine::materializeComputedBound( "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. 
- operands.push_back(b.create(loc, value, *dim)); + operands.push_back(tensor::DimOp::create(b, loc, value, *dim)); } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. - operands.push_back(b.create(loc, value, *dim)); + operands.push_back(memref::DimOp::create(b, loc, value, *dim)); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } @@ -76,7 +76,7 @@ OpFoldResult affine::materializeComputedBound( operands[expr.getPosition() + boundMap.getNumDims()]); // General case: build affine.apply op. return static_cast( - b.create(loc, boundMap, operands).getResult()); + affine::AffineApplyOp::create(b, loc, boundMap, operands).getResult()); } FailureOr mlir::affine::reifyShapedValueDimBound( diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 7fae260767e0a..827e1dff9f319 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -905,8 +905,8 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, for (auto resultExpr : map.getResults()) { auto singleResMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); - auto afOp = state.builder.create(op->getLoc(), singleResMap, - mapOperands); + auto afOp = AffineApplyOp::create(state.builder, op->getLoc(), singleResMap, + mapOperands); results.push_back(afOp); } } @@ -961,7 +961,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, auto vecForOp = cast(parentOp); state.builder.setInsertionPointToStart(vecForOp.getBody()); auto newConstOp = - state.builder.create(constOp.getLoc(), vecAttr); + arith::ConstantOp::create(state.builder, constOp.getLoc(), vecAttr); // Register vector replacement for future uses in the scope. 
state.registerOpVectorReplacement(constOp, newConstOp); @@ -986,8 +986,8 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp, } } - auto newApplyOp = state.builder.create( - applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands); + auto newApplyOp = AffineApplyOp::create( + state.builder, applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands); // Register the new affine.apply result. state.registerValueScalarReplacement(applyOp.getResult(), @@ -1010,7 +1010,7 @@ static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind, auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = - state.builder.create(oldOperand.getLoc(), vecAttr); + arith::ConstantOp::create(state.builder, oldOperand.getLoc(), vecAttr); return newConstOp; } @@ -1062,11 +1062,11 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) { AffineMap ubMap = vecForOp.getUpperBoundMap(); Value ub; if (ubMap.getNumResults() == 1) - ub = state.builder.create(loc, vecForOp.getUpperBoundMap(), - vecForOp.getUpperBoundOperands()); + ub = AffineApplyOp::create(state.builder, loc, vecForOp.getUpperBoundMap(), + vecForOp.getUpperBoundOperands()); else - ub = state.builder.create(loc, vecForOp.getUpperBoundMap(), - vecForOp.getUpperBoundOperands()); + ub = AffineMinOp::create(state.builder, loc, vecForOp.getUpperBoundMap(), + vecForOp.getUpperBoundOperands()); // Then we compute the number of (original) iterations left in the loop. 
AffineExpr subExpr = state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1); @@ -1080,7 +1080,7 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) { Type maskTy = VectorType::get(state.strategy->vectorSizes, state.builder.getIntegerType(1)); Value mask = - state.builder.create(loc, maskTy, itersLeft); + vector::CreateMaskOp::create(state.builder, loc, maskTy, itersLeft); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n" << itersLeft << "\n" @@ -1123,8 +1123,8 @@ static Operation *vectorizeUniform(Value uniformVal, state.builder.setInsertionPointAfterValue(uniformScalarRepl); auto vectorTy = getVectorType(uniformVal.getType(), state.strategy); - auto bcastOp = state.builder.create(uniformVal.getLoc(), - vectorTy, uniformScalarRepl); + auto bcastOp = BroadcastOp::create(state.builder, uniformVal.getLoc(), + vectorTy, uniformScalarRepl); state.registerValueVectorReplacement(uniformVal, bcastOp); return bcastOp; } @@ -1256,8 +1256,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - auto transfer = state.builder.create( - loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, + auto transfer = vector::TransferReadOp::create( + state.builder, loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, /*padding=*/std::nullopt, permutationMap); // Register replacement for future uses in the scope. 
@@ -1303,9 +1303,9 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - auto transfer = state.builder.create( - storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices, - permutationMap); + auto transfer = vector::TransferWriteOp::create( + state.builder, storeOp.getLoc(), vectorValue, storeOp.getMemRef(), + indices, permutationMap); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer); // Register replacement for future uses in the scope. @@ -1387,10 +1387,10 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp, } } - auto vecForOp = state.builder.create( - forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), - forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep, - vecIterOperands, + auto vecForOp = AffineForOp::create( + state.builder, forOp.getLoc(), forOp.getLowerBoundOperands(), + forOp.getLowerBoundMap(), forOp.getUpperBoundOperands(), + forOp.getUpperBoundMap(), newStep, vecIterOperands, /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) { // Make sure we don't create a default terminator in the loop body as // the proper terminator will be added during vectorization. @@ -1512,8 +1512,8 @@ static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp, // IterOperands are neutral element vectors. 
Value neutralVal = cast(newParentOp).getInits()[i]; state.builder.setInsertionPoint(combinerOps.back()); - Value maskedReducedVal = state.builder.create( - reducedVal.getLoc(), mask, reducedVal, neutralVal); + Value maskedReducedVal = arith::SelectOp::create( + state.builder, reducedVal.getLoc(), mask, reducedVal, neutralVal); LLVM_DEBUG( dbgs() << "\n[early-vect]+++++ masking an input to a binary op that" "produces value for a yield Op: " @@ -1865,7 +1865,6 @@ verifyLoopNesting(const std::vector> &loops) { return success(); } - /// External utility to vectorize affine loops in 'loops' using the n-D /// vectorization factors in 'vectorSizes'. By default, each vectorization /// factor is applied inner-to-outer to the loops of each loop nest. diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 0501616ad912c..0f16e84263ec1 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -55,7 +55,7 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, OpBuilder b(forOp); auto lbMap = forOp.getLowerBoundMap(); - auto lb = b.create(forOp.getLoc(), lbMap, + auto lb = AffineApplyOp::create(b, forOp.getLoc(), lbMap, forOp.getLowerBoundOperands()); // For each upper bound expr, get the range. 
@@ -72,7 +72,7 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, auto bumpMap = AffineMap::get(tripCountMap.getNumDims(), tripCountMap.getNumSymbols(), bumpExprs[i]); bumpValues[i] = - b.create(forOp.getLoc(), bumpMap, tripCountOperands); + AffineApplyOp::create(b, forOp.getLoc(), bumpMap, tripCountOperands); } SmallVector newUbExprs(tripCountMap.getNumResults()); @@ -135,7 +135,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { builder.setInsertionPointToStart(&func.getFunctionBody().front()); else builder.setInsertionPoint(forOp); - auto constOp = builder.create( + auto constOp = arith::ConstantIndexOp::create(builder, forOp.getLoc(), forOp.getConstantLowerBound()); iv.replaceAllUsesWith(constOp); } else { @@ -147,7 +147,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { iv.replaceAllUsesWith(lbOperands[0]); } else { auto affineApplyOp = - builder.create(forOp.getLoc(), lbMap, lbOperands); + AffineApplyOp::create(builder, forOp.getLoc(), lbMap, lbOperands); iv.replaceAllUsesWith(affineApplyOp); } } @@ -182,7 +182,7 @@ static AffineForOp generateShiftedLoop( assert(ubMap.getNumInputs() == ubOperands.size()); auto loopChunk = - b.create(srcForOp.getLoc(), lbOperands, lbMap, ubOperands, + AffineForOp::create(b, srcForOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap, srcForOp.getStepAsInt()); auto loopChunkIV = loopChunk.getInductionVar(); auto srcIV = srcForOp.getInductionVar(); @@ -198,7 +198,7 @@ static AffineForOp generateShiftedLoop( // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. 
if (!srcIV.use_empty() && shift != 0) { - auto ivRemap = bodyBuilder.create( + auto ivRemap = AffineApplyOp::create(bodyBuilder, srcForOp.getLoc(), bodyBuilder.getSingleDimShiftAffineMap( -static_cast(srcForOp.getStepAsInt() * shift)), @@ -434,7 +434,7 @@ static void constructTiledLoopNest(MutableArrayRef origLoops, for (unsigned i = 0; i < width; i++) { OpBuilder b(topLoop); // Loop bounds will be set later. - AffineForOp pointLoop = b.create(loc, 0, 0); + AffineForOp pointLoop = AffineForOp::create(b, loc, 0, 0); pointLoop.getBody()->getOperations().splice( pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -448,7 +448,7 @@ static void constructTiledLoopNest(MutableArrayRef origLoops, for (unsigned i = width; i < 2 * width; i++) { OpBuilder b(topLoop); // Loop bounds will be set later. - AffineForOp tileSpaceLoop = b.create(loc, 0, 0); + AffineForOp tileSpaceLoop = AffineForOp::create(b, loc, 0, 0); tileSpaceLoop.getBody()->getOperations().splice( tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -1063,7 +1063,7 @@ LogicalResult mlir::affine::loopUnrollByFactor( // iv' = iv + i * step auto d0 = b.getAffineDimExpr(0); auto bumpMap = AffineMap::get(1, 0, d0 + i * step); - return b.create(forOp.getLoc(), bumpMap, iv); + return AffineApplyOp::create(b, forOp.getLoc(), bumpMap, iv); }, /*annotateFn=*/annotateFn, /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues); @@ -1227,7 +1227,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp, auto d0 = builder.getAffineDimExpr(0); auto bumpMap = AffineMap::get(1, 0, d0 + i * step); auto ivUnroll = - builder.create(forOp.getLoc(), bumpMap, forOpIV); + AffineApplyOp::create(builder, forOp.getLoc(), bumpMap, forOpIV); operandMaps[i - 1].map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. 
@@ -1556,7 +1556,7 @@ stripmineSink(AffineForOp forOp, uint64_t factor, for (auto t : targets) { // Insert newForOp before the terminator of `t`. auto b = OpBuilder::atBlockTerminator(t.getBody()); - auto newForOp = b.create(t.getLoc(), lbOperands, lbMap, + auto newForOp = AffineForOp::create(b, t.getLoc(), lbOperands, lbMap, ubOperands, ubMap, originalStep); auto begin = t.getBody()->begin(); // Skip terminator and `newForOp` which is just before the terminator. @@ -1631,9 +1631,9 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { // 1. Store the upper bound of the outermost loop in a variable. Value prev; if (!llvm::hasSingleElement(origUbMap.getResults())) - prev = builder.create(loc, origUbMap, ubOperands); + prev = AffineMinOp::create(builder, loc, origUbMap, ubOperands); else - prev = builder.create(loc, origUbMap, ubOperands); + prev = AffineApplyOp::create(builder, loc, origUbMap, ubOperands); upperBoundSymbols.push_back(prev); // 2. Emit code computing the upper bound of the coalesced loop as product of @@ -1645,15 +1645,15 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { Value upperBound; // If upper bound map has more than one result, take their minimum. if (!llvm::hasSingleElement(origUbMap.getResults())) - upperBound = builder.create(loc, origUbMap, ubOperands); + upperBound = AffineMinOp::create(builder, loc, origUbMap, ubOperands); else - upperBound = builder.create(loc, origUbMap, ubOperands); + upperBound = AffineApplyOp::create(builder, loc, origUbMap, ubOperands); upperBoundSymbols.push_back(upperBound); SmallVector operands; operands.push_back(prev); operands.push_back(upperBound); // Maintain running product of loop upper bounds. 
- prev = builder.create( + prev = AffineApplyOp::create(builder, loc, AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, @@ -1683,7 +1683,7 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { SmallVector operands; operands.push_back(previous); operands.push_back(upperBoundSymbols[idx]); - previous = builder.create( + previous = AffineApplyOp::create(builder, loc, AffineMap::get( /*dimCount=*/1, /*symbolCount=*/1, @@ -1700,7 +1700,7 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { SmallVector applyOperands; applyOperands.push_back(previous); applyOperands.push_back(upperBoundSymbols[idx - 1]); - inductionVariable = builder.create( + inductionVariable = AffineApplyOp::create(builder, loc, AffineMap::get( /*dimCount=*/1, /*symbolCount=*/1, @@ -1738,21 +1738,21 @@ void mlir::affine::mapLoopToProcessorIds(scf::ForOp forOp, Value linearIndex = processorId.front(); for (unsigned i = 1, e = processorId.size(); i < e; ++i) { - auto mulApplyOp = b.create( + auto mulApplyOp = AffineApplyOp::create(b, loc, mulMap, ValueRange{linearIndex, numProcessors[i]}); - linearIndex = b.create( + linearIndex = AffineApplyOp::create(b, loc, addMap, ValueRange{mulApplyOp, processorId[i]}); } - auto mulApplyOp = b.create( + auto mulApplyOp = AffineApplyOp::create(b, loc, mulMap, ValueRange{linearIndex, forOp.getStep()}); - Value lb = b.create( + Value lb = AffineApplyOp::create(b, loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()}); forOp.setLowerBound(lb); Value step = forOp.getStep(); for (auto numProcs : numProcessors) - step = b.create(loc, mulMap, ValueRange{numProcs, step}); + step = AffineApplyOp::create(b, loc, mulMap, ValueRange{numProcs, step}); forOp.setStep(step); } @@ -1889,7 +1889,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, auto fastBufOffsetMap = AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]); - auto offset = b.create(loc, fastBufOffsetMap, lbOperands); + auto offset = AffineApplyOp::create(b, 
loc, fastBufOffsetMap, lbOperands); // Construct the subscript for the fast memref being copied into/from: // x - offset_x. @@ -1916,16 +1916,16 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, if (!isCopyOut) { // Copy in. - auto load = b.create(loc, memref, memIndices); - b.create(loc, load, fastMemRef, fastBufMap, + auto load = AffineLoadOp::create(b, loc, memref, memIndices); + AffineStoreOp::create(b, loc, load, fastMemRef, fastBufMap, fastBufMapOperands); return copyNestRoot; } // Copy out. auto load = - b.create(loc, fastMemRef, fastBufMap, fastBufMapOperands); - b.create(loc, load, memref, memIndices); + AffineLoadOp::create(b, loc, fastMemRef, fastBufMap, fastBufMapOperands); + AffineStoreOp::create(b, loc, load, memref, memIndices); return copyNestRoot; } @@ -1960,7 +1960,7 @@ static LogicalResult generateCopy( auto f = begin->getParentOfType(); OpBuilder topBuilder(f.getFunctionBody()); - Value zeroIndex = topBuilder.create(f.getLoc(), 0); + Value zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0); *sizeInBytes = 0; @@ -2071,7 +2071,7 @@ static LogicalResult generateCopy( memIndices.push_back(zeroIndex); } else { memIndices.push_back( - top.create(loc, indexVal).getResult()); + arith::ConstantIndexOp::create(top, loc, indexVal).getResult()); } } else { // The coordinate for the start location is just the lower bound along the @@ -2085,7 +2085,7 @@ static LogicalResult generateCopy( lbs[d] = lbs[d].replaceDimsAndSymbols( /*dimReplacements=*/{}, symReplacements, lbs[d].getNumSymbols(), /*numResultSyms=*/0); - memIndices.push_back(b.create(loc, lbs[d], regionSymbols)); + memIndices.push_back(AffineApplyOp::create(b, loc, lbs[d], regionSymbols)); } // The fast buffer is copied into at location zero; addressing is relative. bufIndices.push_back(zeroIndex); @@ -2109,7 +2109,7 @@ static LogicalResult generateCopy( // Create the fast memory space buffer just before the 'affine.for' // operation. 
fastMemRef = - prologue.create(loc, fastMemRefType).getResult(); + memref::AllocOp::create(prologue, loc, fastMemRefType).getResult(); // Record it. fastBufferMap[memref] = fastMemRef; // fastMemRefType is a constant shaped memref. @@ -2126,7 +2126,7 @@ static LogicalResult generateCopy( fastMemRef = fastBufferMap[memref]; } - auto numElementsSSA = top.create(loc, *numElements); + auto numElementsSSA = arith::ConstantIndexOp::create(top, loc, *numElements); Value dmaStride; Value numEltPerDmaStride; @@ -2143,8 +2143,8 @@ static LogicalResult generateCopy( if (!dmaStrideInfos.empty()) { dmaStride = - top.create(loc, dmaStrideInfos[0].stride); - numEltPerDmaStride = top.create( + arith::ConstantIndexOp::create(top, loc, dmaStrideInfos[0].stride); + numEltPerDmaStride = arith::ConstantIndexOp::create(top, loc, dmaStrideInfos[0].numEltPerStride); } } @@ -2175,20 +2175,20 @@ static LogicalResult generateCopy( // Create a tag (single element 1-d memref) for the DMA. auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {}, copyOptions.tagMemorySpace); - auto tagMemRef = prologue.create(loc, tagMemRefType); + auto tagMemRef = memref::AllocOp::create(prologue, loc, tagMemRefType); SmallVector tagIndices({zeroIndex}); auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); if (!region.isWrite()) { // DMA non-blocking read from original buffer to fast buffer. - b.create(loc, memref, memAffineMap, memIndices, + AffineDmaStartOp::create(b, loc, memref, memAffineMap, memIndices, fastMemRef, bufAffineMap, bufIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA, dmaStride, numEltPerDmaStride); } else { // DMA non-blocking write from fast buffer to the original memref. 
- auto op = b.create( + auto op = AffineDmaStartOp::create(b, loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA, dmaStride, numEltPerDmaStride); @@ -2199,11 +2199,11 @@ static LogicalResult generateCopy( } // Matching DMA wait to block on completion; tag always has a 0 index. - b.create(loc, tagMemRef, tagAffineMap, zeroIndex, + AffineDmaWaitOp::create(b, loc, tagMemRef, tagAffineMap, zeroIndex, numElementsSSA); // Generate dealloc for the tag. - auto tagDeallocOp = epilogue.create(loc, tagMemRef); + auto tagDeallocOp = memref::DeallocOp::create(epilogue, loc, tagMemRef); if (*nEnd == end && isCopyOutAtEndOfBlock) // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. @@ -2212,7 +2212,7 @@ static LogicalResult generateCopy( // Generate dealloc for the buffer. if (!existingBuf) { - auto bufDeallocOp = epilogue.create(loc, fastMemRef); + auto bufDeallocOp = memref::DeallocOp::create(epilogue, loc, fastMemRef); // When generating pointwise copies, `nEnd' has to be set to deallocOp on // the fast buffer (since it marks the new end insertion point). 
if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock) @@ -2582,7 +2582,7 @@ AffineForOp mlir::affine::createCanonicalizedAffineForOp( canonicalizeMapAndOperands(&ubMap, &upperOperands); ubMap = removeDuplicateExprs(ubMap); - return b.create(loc, lowerOperands, lbMap, upperOperands, ubMap, + return AffineForOp::create(b, loc, lowerOperands, lbMap, upperOperands, ubMap, step); } @@ -2666,7 +2666,7 @@ static AffineIfOp createSeparationCondition(MutableArrayRef loops, SmallVector setOperands; cst.getValues(0, cst.getNumDimAndSymbolVars(), &setOperands); canonicalizeSetAndOperands(&ifCondSet, &setOperands); - return b.create(loops[0].getLoc(), ifCondSet, setOperands, + return AffineIfOp::create(b, loops[0].getLoc(), ifCondSet, setOperands, /*withElseRegion=*/true); } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 66b3f2a4f93a5..1666825aeac6a 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -58,7 +58,7 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) return nullptr; - auto op = builder.create(loc, lhs, rhs, overflowFlags); + auto op = OpTy::create(builder, loc, lhs, rhs, overflowFlags); return op.getResult(); } @@ -92,13 +92,13 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value remainder = builder.create(loc, lhs, rhs); - Value zeroCst = builder.create(loc, 0); - Value isRemainderNegative = builder.create( + Value remainder = arith::RemSIOp::create(builder, loc, lhs, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value isRemainderNegative = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt, remainder, zeroCst); Value correctedRemainder = - builder.create(loc, remainder, rhs); - Value result = builder.create( + arith::AddIOp::create(builder, loc, remainder, rhs); + Value result = 
arith::SelectOp::create(builder, loc, isRemainderNegative, correctedRemainder, remainder); return result; } @@ -131,17 +131,17 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value noneCst = builder.create(loc, -1); - Value negative = builder.create( + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value noneCst = arith::ConstantIndexOp::create(builder, loc, -1); + Value negative = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create(loc, noneCst, lhs); + Value negatedDecremented = arith::SubIOp::create(builder, loc, noneCst, lhs); Value dividend = - builder.create(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create(loc, dividend, rhs); + arith::SelectOp::create(builder, loc, negative, negatedDecremented, lhs); + Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs); Value correctedQuotient = - builder.create(loc, noneCst, quotient); - Value result = builder.create(loc, negative, + arith::SubIOp::create(builder, loc, noneCst, quotient); + Value result = arith::SelectOp::create(builder, loc, negative, correctedQuotient, quotient); return result; } @@ -170,26 +170,26 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value oneCst = builder.create(loc, 1); - Value nonPositive = builder.create( + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value oneCst = arith::ConstantIndexOp::create(builder, loc, 1); + Value nonPositive = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sle, lhs, zeroCst); - Value negated = builder.create(loc, zeroCst, lhs); - Value decremented = builder.create(loc, lhs, oneCst); + Value negated = arith::SubIOp::create(builder, loc, zeroCst, lhs); + Value 
decremented = arith::SubIOp::create(builder, loc, lhs, oneCst); Value dividend = - builder.create(loc, nonPositive, negated, decremented); - Value quotient = builder.create(loc, dividend, rhs); + arith::SelectOp::create(builder, loc, nonPositive, negated, decremented); + Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs); Value negatedQuotient = - builder.create(loc, zeroCst, quotient); + arith::SubIOp::create(builder, loc, zeroCst, quotient); Value incrementedQuotient = - builder.create(loc, quotient, oneCst); - Value result = builder.create( + arith::AddIOp::create(builder, loc, quotient, oneCst); + Value result = arith::SelectOp::create(builder, loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { - auto op = builder.create(loc, expr.getValue()); + auto op = arith::ConstantIndexOp::create(builder, loc, expr.getValue()); return op.getResult(); } @@ -299,7 +299,7 @@ static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) { // block. 
IRMapping operandMap; OpBuilder b(hoistOverOp); - auto hoistedIfOp = b.create(ifOp.getLoc(), ifOp.getIntegerSet(), + auto hoistedIfOp = AffineIfOp::create(b, ifOp.getLoc(), ifOp.getIntegerSet(), ifOp.getOperands(), /*elseBlock=*/true); @@ -370,7 +370,7 @@ mlir::affine::affineParallelize(AffineForOp forOp, parallelReductions, [](const LoopReduction &red) { return red.value; })); auto reductionKinds = llvm::to_vector<4>(llvm::map_range( parallelReductions, [](const LoopReduction &red) { return red.kind; })); - AffineParallelOp newPloop = outsideBuilder.create( + AffineParallelOp newPloop = AffineParallelOp::create(outsideBuilder, loc, ValueRange(reducedValues).getTypes(), reductionKinds, llvm::ArrayRef(lowerBoundMap), lowerBoundOperands, llvm::ArrayRef(upperBoundMap), upperBoundOperands, @@ -542,7 +542,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) { SmallVector applyOperands{dimOperands}; applyOperands.push_back(iv); applyOperands.append(symbolOperands.begin(), symbolOperands.end()); - auto apply = builder.create(op.getLoc(), map, applyOperands); + auto apply = AffineApplyOp::create(builder, op.getLoc(), map, applyOperands); iv.replaceAllUsesExcept(apply, apply); } @@ -623,7 +623,7 @@ LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op, AffineValueMap newIvToOldIvMap; AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap); (void)newIvToOldIvMap.canonicalize(); - auto newIV = opBuilder.create( + auto newIV = AffineApplyOp::create(opBuilder, loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands()); op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); return success(); @@ -1188,7 +1188,7 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( for (auto resultExpr : oldMap.getResults()) { auto singleResMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); - auto afOp = builder.create(op->getLoc(), singleResMap, + auto afOp = AffineApplyOp::create(builder, op->getLoc(), 
singleResMap, oldMapOperands); oldMemRefOperands.push_back(afOp); affineApplyOps.push_back(afOp); @@ -1215,7 +1215,7 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( for (auto resultExpr : indexRemap.getResults()) { auto singleResMap = AffineMap::get( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); - auto afOp = builder.create(op->getLoc(), singleResMap, + auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap, remapOperands); remapOutputs.push_back(afOp); affineApplyOps.push_back(afOp); @@ -1265,7 +1265,7 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( // AffineMapAccessInterface, we need to apply the values of `newMapOperands` // to the `newMap` to get the correct indices. for (unsigned i = 0; i < newMemRefRank; i++) { - state.operands.push_back(builder.create( + state.operands.push_back(AffineApplyOp::create(builder, op->getLoc(), AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(), newMap.getResult(i)), @@ -1451,7 +1451,7 @@ void mlir::affine::createAffineComputationSlice( for (auto resultExpr : composedMap.getResults()) { auto singleResMap = AffineMap::get(composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); - sliceOps->push_back(builder.create( + sliceOps->push_back(AffineApplyOp::create(builder, opInst->getLoc(), singleResMap, composedOpOperands)); } @@ -1682,7 +1682,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, // Create ConstantOp for static dimension. 
auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); inAffineApply.emplace_back( - b.create(allocOp.getLoc(), constantAttr)); + arith::ConstantOp::create(b, allocOp.getLoc(), constantAttr)); } } @@ -1706,7 +1706,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, AffineMap newMap = AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput); Value affineApp = - b.create(allocOp.getLoc(), newMap, inAffineApply); + AffineApplyOp::create(b, allocOp.getLoc(), newMap, inAffineApply); newDynamicSizes.emplace_back(affineApp); } newDimIdx++; @@ -1742,10 +1742,10 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) { newDynamicSizes); // Add the new dynamic sizes in new AllocOp. newAlloc = - b.create(allocOp.getLoc(), newMemRefType, newDynamicSizes, + AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType, newDynamicSizes, allocOp.getAlignmentAttr()); } else { - newAlloc = b.create(allocOp.getLoc(), newMemRefType, + newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType, allocOp.getAlignmentAttr()); } // Replace all uses of the old memref. 
@@ -1804,10 +1804,10 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) { if (memrefType.isDynamicDim(i)) mapOperands[i] = - b.create(loc, oldSizes[0].getType(), oldSizes[idx++], - b.create(loc, 1)); + arith::SubIOp::create(b, loc, oldSizes[0].getType(), oldSizes[idx++], + arith::ConstantIndexOp::create(b, loc, 1)); else - mapOperands[i] = b.create(loc, oldShape[i] - 1); + mapOperands[i] = arith::ConstantIndexOp::create(b, loc, oldShape[i] - 1); } for (unsigned i = 0, e = oldStrides.size(); i < e; i++) mapOperands[memrefType.getRank() + i] = oldStrides[i]; @@ -1817,7 +1817,7 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { for (unsigned i = 0; i < newRank; i++) { if (!newMemRefType.isDynamicDim(i)) continue; - newSizes.push_back(b.create( + newSizes.push_back(AffineApplyOp::create(b, loc, AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(), oldLayoutMap.getResult(i)), @@ -1825,11 +1825,11 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { } for (unsigned i = 0, e = newSizes.size(); i < e; i++) { newSizes[i] = - b.create(loc, newSizes[i].getType(), newSizes[i], - b.create(loc, 1)); + arith::AddIOp::create(b, loc, newSizes[i].getType(), newSizes[i], + arith::ConstantIndexOp::create(b, loc, 1)); } // Create the new reinterpret_cast op. 
- auto newReinterpretCast = b.create( + auto newReinterpretCast = memref::ReinterpretCastOp::create(b, loc, newMemRefType, reinterpretCastOp.getSource(), /*offsets=*/ValueRange(), newSizes, /*strides=*/ValueRange(), diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index ebcb951cf3518..e7cbee6b06c45 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -64,7 +64,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index ae2a00cedf6f5..c4c2a043a8024 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -243,7 +244,7 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, Type type, Location loc) { if (isBuildableWith(value, type)) - return builder.create(loc, cast(value)); + return arith::ConstantOp::create(builder, loc, cast(value)); return nullptr; } @@ -256,18 +257,66 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, builder.getIntegerAttr(type, value)); } +arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder, + Location location, + int64_t value, + unsigned width) { + mlir::OperationState state(location, getOperationName()); + build(builder, 
state, value, width); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder, + int64_t value, + unsigned width) { + return create(builder, builder.getLoc(), value, width); +} + void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, Type type, int64_t value) { arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } +arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder, + Location location, Type type, + int64_t value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, type, value); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder, + Type type, int64_t value) { + return create(builder, builder.getLoc(), type, value); +} + void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, Type type, const APInt &value) { arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } +arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder, + Location location, Type type, + const APInt &value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, type, value); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder, + Type type, + const APInt &value) { + return create(builder, builder.getLoc(), type, value); +} + bool arith::ConstantIntOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isSignlessInteger(); @@ -280,6 +329,23 @@ void 
arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, builder.getFloatAttr(type, value)); } +arith::ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder, + Location location, + FloatType type, + const APFloat &value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, type, value); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantFloatOp +arith::ConstantFloatOp::create(ImplicitLocOpBuilder &builder, FloatType type, + const APFloat &value) { + return create(builder, builder.getLoc(), type, value); +} + bool arith::ConstantFloatOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return llvm::isa(constOp.getType()); @@ -292,6 +358,21 @@ void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, builder.getIndexAttr(value)); } +arith::ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder, + Location location, + int64_t value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, value); + auto result = llvm::dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIndexOp +arith::ConstantIndexOp::create(ImplicitLocOpBuilder &builder, int64_t value) { + return create(builder, builder.getLoc(), value); +} + bool arith::ConstantIndexOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isIndex(); @@ -305,7 +386,7 @@ Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc, "type doesn't have a zero representation"); TypedAttr zeroAttr = builder.getZeroAttr(type); assert(zeroAttr && "unsupported type for zero attribute"); - return builder.create(loc, zeroAttr); + return arith::ConstantOp::create(builder, loc, zeroAttr); } //===----------------------------------------------------------------------===// 
@@ -2335,9 +2416,8 @@ class CmpFIntToFPConst final : public OpRewritePattern { // comparison. rewriter.replaceOpWithNewOp( op, pred, intVal, - rewriter.create( - op.getLoc(), intVal.getType(), - rewriter.getIntegerAttr(intVal.getType(), rhsInt))); + ConstantOp::create(rewriter, op.getLoc(), intVal.getType(), + rewriter.getIntegerAttr(intVal.getType(), rhsInt))); return success(); } }; @@ -2374,10 +2454,10 @@ struct SelectToExtUI : public OpRewritePattern { matchPattern(op.getFalseValue(), m_One())) { rewriter.replaceOpWithNewOp( op, op.getType(), - rewriter.create( - op.getLoc(), op.getCondition(), - rewriter.create( - op.getLoc(), op.getCondition().getType(), 1))); + arith::XOrIOp::create( + rewriter, op.getLoc(), op.getCondition(), + arith::ConstantIntOp::create(rewriter, op.getLoc(), + op.getCondition().getType(), 1))); return success(); } @@ -2440,12 +2520,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { // Constant-fold constant operands over non-splat constant condition. 
// select %cst_vec, %cst0, %cst1 => %cst2 - if (auto cond = - llvm::dyn_cast_if_present(adaptor.getCondition())) { - if (auto lhs = - llvm::dyn_cast_if_present(adaptor.getTrueValue())) { - if (auto rhs = - llvm::dyn_cast_if_present(adaptor.getFalseValue())) { + if (auto cond = llvm::dyn_cast_if_present( + adaptor.getCondition())) { + if (auto lhs = llvm::dyn_cast_if_present( + adaptor.getTrueValue())) { + if (auto rhs = llvm::dyn_cast_if_present( + adaptor.getFalseValue())) { SmallVector results; results.reserve(static_cast(cond.getNumElements())); auto condVals = llvm::make_range(cond.value_begin(), @@ -2693,7 +2773,7 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, bool useOnlyFiniteValue) { auto attr = getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); - return builder.create(loc, attr); + return arith::ConstantOp::create(builder, loc, attr); } /// Return the value obtained by applying the reduction operation kind @@ -2702,33 +2782,33 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs) { switch (op) { case AtomicRMWKind::addf: - return builder.create(loc, lhs, rhs); + return arith::AddFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::addi: - return builder.create(loc, lhs, rhs); + return arith::AddIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::mulf: - return builder.create(loc, lhs, rhs); + return arith::MulFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::muli: - return builder.create(loc, lhs, rhs); + return arith::MulIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::maximumf: - return builder.create(loc, lhs, rhs); + return arith::MaximumFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::minimumf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxnumf: - return builder.create(loc, lhs, rhs); + return arith::MinimumFOp::create(builder, loc, lhs, rhs); + case AtomicRMWKind::maxnumf: + return 
arith::MaxNumFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::minnumf: - return builder.create(loc, lhs, rhs); + return arith::MinNumFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::maxs: - return builder.create(loc, lhs, rhs); + return arith::MaxSIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::mins: - return builder.create(loc, lhs, rhs); + return arith::MinSIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::maxu: - return builder.create(loc, lhs, rhs); + return arith::MaxUIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::minu: - return builder.create(loc, lhs, rhs); + return arith::MinUIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::ori: - return builder.create(loc, lhs, rhs); + return arith::OrIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::andi: - return builder.create(loc, lhs, rhs); + return arith::AndIOp::create(builder, loc, lhs, rhs); // TODO: Add remaining reduction operations. default: (void)emitOptionalError(loc, "Reduction operation type not supported"); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp index f2e7732e8ea4a..d21da23903544 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -67,7 +67,7 @@ struct SelectOpInterface return state.getMemrefWithUniqueOwnership(builder, value, value.getParentBlock()); - Value ownership = builder.create( + Value ownership = arith::SelectOp::create(builder, op->getLoc(), selectOp.getCondition(), state.getOwnership(selectOp.getTrueValue(), block).getIndicator(), state.getOwnership(selectOp.getFalseValue(), block).getIndicator()); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index afee162053bea..b073a31850678 100644 --- 
a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -170,10 +170,10 @@ struct SelectOpInterface return failure(); if (trueBuffer.getType() != *targetType) trueBuffer = - rewriter.create(loc, *targetType, trueBuffer); + memref::CastOp::create(rewriter, loc, *targetType, trueBuffer); if (falseBuffer.getType() != *targetType) falseBuffer = - rewriter.create(loc, *targetType, falseBuffer); + memref::CastOp::create(rewriter, loc, *targetType, falseBuffer); } replaceOpWithNewBufferizedOp( diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 62022bfb7df1e..674bd7b8ee201 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -75,7 +75,7 @@ LogicalResult EmulateFloatPattern::matchAndRewrite( for (auto [res, oldType, newType] : llvm::zip_equal( MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { if (oldType != newType) { - auto truncFOp = rewriter.create(loc, oldType, res); + auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res); truncFOp.setFastmath(arith::FastMathFlags::contract); res = truncFOp.getResult(); } @@ -98,7 +98,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( }); converter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); + auto extFOp = arith::ExtFOp::create(b, loc, target, input); extFOp.setFastmath(arith::FastMathFlags::contract); return extFOp; }); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index d5d1559c658ff..44e896a97cb4b 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -72,7 +72,7 @@ static 
Value extractLastDimSlice(ConversionPatternRewriter &rewriter, // Scalarize the result in case of 1D vectors. if (shape.size() == 1) - return rewriter.create(loc, input, lastOffset); + return vector::ExtractOp::create(rewriter, loc, input, lastOffset); SmallVector offsets(shape.size(), 0); offsets.back() = lastOffset; @@ -80,7 +80,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, sizes.back() = 1; SmallVector strides(shape.size(), 1); - return rewriter.create(loc, input, offsets, + return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets, sizes, strides); } @@ -107,7 +107,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, assert(shape.back() == 1 && "Expected the last vector dim to be x1"); auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType()); - return rewriter.create(loc, newVecTy, input); + return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input); } /// Performs a vector shape cast to append an x1 dimension. If the @@ -122,7 +122,7 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, auto newShape = llvm::to_vector(vecTy.getShape()); newShape.push_back(1); auto newTy = VectorType::get(newShape, vecTy.getElementType()); - return rewriter.create(loc, newTy, input); + return vector::ShapeCastOp::create(rewriter, loc, newTy, input); } /// Inserts the `source` vector slice into the `dest` vector at offset @@ -136,12 +136,12 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, // Handle scalar source. 
if (isa(source.getType())) - return rewriter.create(loc, source, dest, lastOffset); + return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset); SmallVector offsets(shape.size(), 0); offsets.back() = lastOffset; SmallVector strides(shape.size(), 1); - return rewriter.create(loc, source, dest, + return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest, offsets, strides); } @@ -254,12 +254,12 @@ struct ConvertAddI final : OpConversionPattern { extractLastDimHalves(rewriter, loc, adaptor.getRhs()); auto lowSum = - rewriter.create(loc, lhsElem0, rhsElem0); + arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0); Value overflowVal = - rewriter.create(loc, newElemTy, lowSum.getOverflow()); + arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow()); - Value high0 = rewriter.create(loc, overflowVal, lhsElem1); - Value high = rewriter.create(loc, high0, rhsElem1); + Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1); + Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high}); @@ -293,8 +293,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern { auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - Value resElem0 = rewriter.create(loc, lhsElem0, rhsElem0); - Value resElem1 = rewriter.create(loc, lhsElem1, rhsElem1); + Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); rewriter.replaceOp(op, resultVec); @@ -346,26 +346,26 @@ struct ConvertCmpI final : OpConversionPattern { extractLastDimHalves(rewriter, loc, adaptor.getRhs()); Value lowCmp = - rewriter.create(loc, lowPred, lhsElem0, rhsElem0); + arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0); Value highCmp = - 
rewriter.create(loc, highPred, lhsElem1, rhsElem1); + arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1); Value cmpResult{}; switch (highPred) { case arith::CmpIPredicate::eq: { - cmpResult = rewriter.create(loc, lowCmp, highCmp); + cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp); break; } case arith::CmpIPredicate::ne: { - cmpResult = rewriter.create(loc, lowCmp, highCmp); + cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp); break; } default: { // Handle inequality checks. - Value highEq = rewriter.create( + Value highEq = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1); cmpResult = - rewriter.create(loc, highEq, lowCmp, highCmp); + arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp); break; } } @@ -401,14 +401,14 @@ struct ConvertMulI final : OpConversionPattern { // Multiplying two i2N integers produces (at most) an i4N result, but // because the calculation of top i2N is not necessary, we omit it. 
auto mulLowLow = - rewriter.create(loc, lhsElem0, rhsElem0); - Value mulLowHi = rewriter.create(loc, lhsElem0, rhsElem1); - Value mulHiLow = rewriter.create(loc, lhsElem1, rhsElem0); + arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1); + Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0); Value resLow = mulLowLow.getLow(); Value resHi = - rewriter.create(loc, mulLowLow.getHigh(), mulLowHi); - resHi = rewriter.create(loc, resHi, mulHiLow); + arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi); + resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow); Value resultVec = constructResultVector(rewriter, loc, newTy, {resLow, resHi}); @@ -443,10 +443,10 @@ struct ConvertExtSI final : OpConversionPattern { loc, newResultComponentTy, newOperand); Value operandZeroCst = createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0); - Value signBit = rewriter.create( + Value signBit = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst); Value signValue = - rewriter.create(loc, newResultComponentTy, signBit); + arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit); Value resultVec = constructResultVector(rewriter, loc, newTy, {extended, signValue}); @@ -508,7 +508,7 @@ struct ConvertMaxMin final : OpConversionPattern { // Rewrite Max*I/Min*I as compare and select over original operands. Let // the CmpI and Select emulation patterns handle the final legalization. Value cmp = - rewriter.create(loc, CmpPred, op.getLhs(), op.getRhs()); + arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs()); rewriter.replaceOpWithNewOp(op, cmp, op.getLhs(), op.getRhs()); return success(); @@ -587,7 +587,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern { // Sign or zero-extend the result. Let the matching conversion pattern // legalize the extension op. 
Value underlyingVal = - rewriter.create(loc, narrowTy, adaptor.getIn()); + CastOp::create(rewriter, loc, narrowTy, adaptor.getIn()); rewriter.replaceOpWithNewOp(op, resultType, underlyingVal); return success(); } @@ -616,9 +616,9 @@ struct ConvertSelect final : OpConversionPattern { Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition()); Value resElem0 = - rewriter.create(loc, cond, trueElem0, falseElem0); + arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0); Value resElem1 = - rewriter.create(loc, cond, trueElem1, falseElem1); + arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); rewriter.replaceOp(op, resultVec); @@ -680,33 +680,33 @@ struct ConvertShLI final : OpConversionPattern { Value elemBitWidth = createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); - Value illegalElemShift = rewriter.create( + Value illegalElemShift = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); Value shiftedElem0 = - rewriter.create(loc, lhsElem0, rhsElem0); - Value resElem0 = rewriter.create(loc, illegalElemShift, + arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift, zeroCst, shiftedElem0); - Value cappedShiftAmount = rewriter.create( + Value cappedShiftAmount = arith::SelectOp::create(rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0); Value rightShiftAmount = - rewriter.create(loc, elemBitWidth, cappedShiftAmount); + arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount); Value shiftedRight = - rewriter.create(loc, lhsElem0, rightShiftAmount); + arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount); Value overshotShiftAmount = - rewriter.create(loc, rhsElem0, elemBitWidth); + arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth); Value shiftedLeft = - rewriter.create(loc, 
lhsElem0, overshotShiftAmount); + arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount); Value shiftedElem1 = - rewriter.create(loc, lhsElem1, rhsElem0); - Value resElem1High = rewriter.create( + arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0); + Value resElem1High = arith::SelectOp::create(rewriter, loc, illegalElemShift, zeroCst, shiftedElem1); - Value resElem1Low = rewriter.create( + Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift, shiftedLeft, shiftedRight); Value resElem1 = - rewriter.create(loc, resElem1Low, resElem1High); + arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); @@ -769,33 +769,33 @@ struct ConvertShRUI final : OpConversionPattern { Value elemBitWidth = createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); - Value illegalElemShift = rewriter.create( + Value illegalElemShift = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); Value shiftedElem0 = - rewriter.create(loc, lhsElem0, rhsElem0); - Value resElem0Low = rewriter.create(loc, illegalElemShift, + arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift, zeroCst, shiftedElem0); Value shiftedElem1 = - rewriter.create(loc, lhsElem1, rhsElem0); - Value resElem1 = rewriter.create(loc, illegalElemShift, + arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0); + Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift, zeroCst, shiftedElem1); - Value cappedShiftAmount = rewriter.create( + Value cappedShiftAmount = arith::SelectOp::create(rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0); Value leftShiftAmount = - rewriter.create(loc, elemBitWidth, cappedShiftAmount); + arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount); Value shiftedLeft = - rewriter.create(loc, lhsElem1, 
leftShiftAmount); + arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount); Value overshotShiftAmount = - rewriter.create(loc, rhsElem0, elemBitWidth); + arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth); Value shiftedRight = - rewriter.create(loc, lhsElem1, overshotShiftAmount); + arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount); - Value resElem0High = rewriter.create( + Value resElem0High = arith::SelectOp::create(rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft); Value resElem0 = - rewriter.create(loc, resElem0Low, resElem0High); + arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); @@ -832,32 +832,32 @@ struct ConvertShRSI final : OpConversionPattern { // Perform as many ops over the narrow integer type as possible and let the // other emulation patterns convert the rest. Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0); - Value signBit = rewriter.create( + Value signBit = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); signBit = dropTrailingX1Dim(rewriter, loc, signBit); // Create a bit pattern of either all ones or all zeros. Then shift it left // to calculate the sign extension bits created by shifting the original // sign bit right. 
- Value allSign = rewriter.create(loc, oldTy, signBit); + Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit); Value maxShift = createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth); Value numNonSignExtBits = - rewriter.create(loc, maxShift, rhsElem0); + arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0); numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits); numNonSignExtBits = - rewriter.create(loc, oldTy, numNonSignExtBits); + arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits); Value signBits = - rewriter.create(loc, allSign, numNonSignExtBits); + arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits); // Use original arguments to create the right shift. Value shrui = - rewriter.create(loc, op.getLhs(), op.getRhs()); - Value shrsi = rewriter.create(loc, shrui, signBits); + arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs()); + Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits); // Handle shifting by zero. This is necessary when the `signBits` shift is // invalid. - Value isNoop = rewriter.create(loc, arith::CmpIPredicate::eq, + Value isNoop = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero); isNoop = dropTrailingX1Dim(rewriter, loc, isNoop); rewriter.replaceOpWithNewOp(op, isNoop, op.getLhs(), @@ -892,14 +892,14 @@ struct ConvertSubI final : OpConversionPattern { // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where // CARRY is 1 or 0. - Value low = rewriter.create(loc, lhsElem0, rhsElem0); + Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0); // We have a carry if lhsElem0 < rhsElem0. 
- Value carry0 = rewriter.create( + Value carry0 = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0); - Value carryVal = rewriter.create(loc, newElemTy, carry0); + Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0); - Value high0 = rewriter.create(loc, lhsElem1, carryVal); - Value high = rewriter.create(loc, high0, rhsElem1); + Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal); + Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high}); rewriter.replaceOp(op, resultVec); @@ -933,13 +933,13 @@ struct ConvertSIToFP final : OpConversionPattern { // result or not based on that sign bit. We implement negation by // subtracting from zero. Note that this relies on the the other conversion // patterns to legalize created ops and narrow the bit widths. - Value isNeg = rewriter.create(loc, arith::CmpIPredicate::slt, + Value isNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, in, zeroCst); - Value neg = rewriter.create(loc, zeroCst, in); - Value abs = rewriter.create(loc, isNeg, neg, in); + Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in); + Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in); - Value absResult = rewriter.create(loc, op.getType(), abs); - Value negResult = rewriter.create(loc, absResult); + Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs); + Value negResult = arith::NegFOp::create(rewriter, loc, absResult); rewriter.replaceOpWithNewOp(op, isNeg, negResult, absResult); return success(); @@ -985,13 +985,13 @@ struct ConvertUIToFP final : OpConversionPattern { // // Note 2: We do not strictly need the `hi == 0`, case, but it makes // constant folding easier. 
- Value hiEqZero = rewriter.create( + Value hiEqZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst); Type resultTy = op.getType(); Type resultElemTy = getElementTypeOrSelf(resultTy); - Value lowFp = rewriter.create(loc, resultTy, lowInt); - Value hiFp = rewriter.create(loc, resultTy, hiInt); + Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt); + Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt); int64_t pow2Int = int64_t(1) << newBitWidth; TypedAttr pow2Attr = @@ -999,10 +999,10 @@ struct ConvertUIToFP final : OpConversionPattern { if (auto vecTy = dyn_cast(resultTy)) pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr); - Value pow2Val = rewriter.create(loc, resultTy, pow2Attr); + Value pow2Val = arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr); - Value hiVal = rewriter.create(loc, hiFp, pow2Val); - Value result = rewriter.create(loc, lowFp, hiVal); + Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val); + Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal); rewriter.replaceOpWithNewOp(op, hiEqZero, lowFp, result); return success(); @@ -1037,22 +1037,22 @@ struct ConvertFPToSI final : OpConversionPattern { // result is UB. TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy); - Value zeroCst = rewriter.create(loc, zeroAttr); + Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr); Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0); // Get the absolute value. One could have used math.absf here, but that // introduces an extra dependency. 
- Value isNeg = rewriter.create(loc, arith::CmpFPredicate::OLT, + Value isNeg = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst); - Value negInFp = rewriter.create(loc, inFp); + Value negInFp = arith::NegFOp::create(rewriter, loc, inFp); - Value absVal = rewriter.create(loc, isNeg, negInFp, inFp); + Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp); // Defer the absolute value to fptoui. - Value res = rewriter.create(loc, intTy, absVal); + Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal); // Negate the value if < 0 . - Value neg = rewriter.create(loc, zeroCstInt, res); + Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res); rewriter.replaceOpWithNewOp(op, isNeg, neg, res); return success(); @@ -1109,17 +1109,17 @@ struct ConvertFPToUI final : OpConversionPattern { if (auto vecType = dyn_cast(fpTy)) powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr); Value powBitwidthFloatCst = - rewriter.create(loc, powBitwidthAttr); + arith::ConstantOp::create(rewriter, loc, powBitwidthAttr); Value fpDivPowBitwidth = - rewriter.create(loc, inFp, powBitwidthFloatCst); + arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst); Value resHigh = - rewriter.create(loc, newHalfType, fpDivPowBitwidth); + arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth); // Calculate fp - resHigh * 2^N by getting the remainder of the division Value remainder = - rewriter.create(loc, inFp, powBitwidthFloatCst); + arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst); Value resLow = - rewriter.create(loc, newHalfType, remainder); + arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder); Value high = appendX1Dim(rewriter, loc, resHigh); Value low = appendX1Dim(rewriter, loc, resLow); diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index dfa01844737c6..7dd0541f3c4b5 100644 --- 
a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -31,10 +31,10 @@ static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter) { auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { - return rewriter.create( + return arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); } /// Create a float constant. @@ -42,11 +42,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value, PatternRewriter &rewriter) { auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { - return rewriter.create( + return arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); } /// Creates shapedType using shape from cloneFrom and base type from cloneTo @@ -70,11 +70,11 @@ struct CeilDivUIOpConverter : public OpRewritePattern { Value b = op.getRhs(); Value zero = createConst(loc, a.getType(), 0, rewriter); Value compare = - rewriter.create(loc, arith::CmpIPredicate::eq, a, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero); Value one = createConst(loc, a.getType(), 1, rewriter); - Value minusOne = rewriter.create(loc, a, one); - Value quotient = rewriter.create(loc, minusOne, b); - Value plusOne = rewriter.create(loc, quotient, one); + Value minusOne = arith::SubIOp::create(rewriter, loc, a, one); + Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b); + Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one); rewriter.replaceOpWithNewOp(op, compare, zero, plusOne); return success(); } @@ -99,22 +99,22 @@ struct CeilDivSIOpConverter : public OpRewritePattern { Value zero = createConst(loc, 
type, 0, rewriter); Value one = createConst(loc, type, 1, rewriter); - Value quotient = rewriter.create(loc, a, b); - Value product = rewriter.create(loc, quotient, b); - Value notEqualDivisor = rewriter.create( + Value quotient = arith::DivSIOp::create(rewriter, loc, a, b); + Value product = arith::MulIOp::create(rewriter, loc, quotient, b); + Value notEqualDivisor = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, a, product); Value aNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, a, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, a, zero); Value bNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, b, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, b, zero); - Value signEqual = rewriter.create( + Value signEqual = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg); Value cond = - rewriter.create(loc, notEqualDivisor, signEqual); + arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual); - Value quotientPlusOne = rewriter.create(loc, quotient, one); + Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one); rewriter.replaceOpWithNewOp(op, cond, quotientPlusOne, quotient); @@ -138,25 +138,25 @@ struct FloorDivSIOpConverter : public OpRewritePattern { Value a = op.getLhs(); Value b = op.getRhs(); - Value quotient = rewriter.create(loc, a, b); - Value product = rewriter.create(loc, quotient, b); - Value notEqualDivisor = rewriter.create( + Value quotient = arith::DivSIOp::create(rewriter, loc, a, b); + Value product = arith::MulIOp::create(rewriter, loc, quotient, b); + Value notEqualDivisor = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, a, product); Value zero = createConst(loc, type, 0, rewriter); Value aNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, a, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, a, zero); Value bNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, b, zero); + 
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, b, zero); - Value signOpposite = rewriter.create( + Value signOpposite = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg); Value cond = - rewriter.create(loc, notEqualDivisor, signOpposite); + arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite); Value minusOne = createConst(loc, type, -1, rewriter); Value quotientMinusOne = - rewriter.create(loc, quotient, minusOne); + arith::AddIOp::create(rewriter, loc, quotient, minusOne); rewriter.replaceOpWithNewOp(op, cond, quotientMinusOne, quotient); @@ -174,7 +174,7 @@ struct MaxMinIOpConverter : public OpRewritePattern { Value lhs = op.getLhs(); Value rhs = op.getRhs(); - Value cmp = rewriter.create(op.getLoc(), pred, lhs, rhs); + Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs); rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); return success(); } @@ -195,11 +195,11 @@ struct MaximumMinimumFOpConverter : public OpRewritePattern { static_assert(pred == arith::CmpFPredicate::UGT || pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); - Value cmp = rewriter.create(loc, pred, lhs, rhs); - Value select = rewriter.create(loc, cmp, lhs, rhs); + Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs); + Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs); // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. 
- Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, + Value isNaN = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNO, rhs, rhs); rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); @@ -221,11 +221,11 @@ struct MaxNumMinNumFOpConverter : public OpRewritePattern { static_assert(pred == arith::CmpFPredicate::UGT || pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); - Value cmp = rewriter.create(loc, pred, lhs, rhs); - Value select = rewriter.create(loc, cmp, lhs, rhs); + Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs); + Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs); // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'. - Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, + Value isNaN = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNO, lhs, lhs); rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); @@ -250,12 +250,12 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern { Type i16Ty = cloneToShapedType(operandTy, b.getI16Type()); Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); - Value bitcast = b.create(i16Ty, operand); - Value exti = b.create(i32Ty, bitcast); + Value bitcast = arith::BitcastOp::create(b, i16Ty, operand); + Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast); Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); - Value shl = b.create(exti, c16); - Value result = b.create(resultTy, shl); + Value shl = arith::ShLIOp::create(b, exti, c16); + Value result = arith::BitcastOp::create(b, resultTy, shl); rewriter.replaceOp(op, result); return success(); @@ -299,7 +299,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { // exponent bits, that simple truncation is the desired outcome for // infinities. 
Value isNan = - b.create(arith::CmpFPredicate::UNE, operand, operand); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand); // Constant used to make the rounding bias. Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); // Constant used to generate a quiet NaN. @@ -308,30 +308,30 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter); // Reinterpret the input f32 value as bits. - Value bitcast = b.create(i32Ty, operand); + Value bitcast = arith::BitcastOp::create(b, i32Ty, operand); // Read bit 16 as a value in {0,1}. Value bit16 = - b.create(b.create(bitcast, c16), c1); + arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1); // Determine the rounding bias to add as either 0x7fff or 0x8000 depending // on bit 16, implementing the tie-breaking "to nearest even". - Value roundingBias = b.create(bit16, c7FFF); + Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF); // Add the rounding bias. Generally we want this to be added to the // mantissa, but nothing prevents this to from carrying into the exponent // bits, which would feel like a bug, but this is the magic trick here: // when that happens, the mantissa gets reset to zero and the exponent // gets incremented by the carry... which is actually exactly what we // want. - Value biased = b.create(bitcast, roundingBias); + Value biased = arith::AddIOp::create(b, bitcast, roundingBias); // Now that the rounding-bias has been added, truncating the low bits // yields the correctly rounded result. - Value biasedAndShifted = b.create(biased, c16); + Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16); Value normalCaseResultI16 = - b.create(i16Ty, biasedAndShifted); + arith::TruncIOp::create(b, i16Ty, biasedAndShifted); // Select either the above-computed result, or a quiet NaN constant // if the input was NaN. 
Value select = - b.create(isNan, c7FC0I16, normalCaseResultI16); - Value result = b.create(resultTy, select); + arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16); + Value result = arith::BitcastOp::create(b, resultTy, select); rewriter.replaceOp(op, result); return success(); } @@ -384,7 +384,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); Type i4Ty = cloneToShapedType(operandTy, b.getI4Type()); Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); - Value i4Bits = b.create(i4Ty, operand); + Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand); Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter); Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter); @@ -393,38 +393,38 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern { // Set last Exponent bit and Mantissa. Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter); - Value bits1To24 = b.create(i4Bits, c0x2); + Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2); Value isHalf = - b.create(arith::CmpIPredicate::eq, i4Bits, c0x1); - bits1To24 = b.create(isHalf, c0x0, bits1To24); - bits1To24 = b.create(i32Ty, bits1To24); - bits1To24 = b.create(bits1To24, c0x00000014); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1); + bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24); + bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24); + bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014); // Set first 7 bits of Exponent. 
Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter); Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter); Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter); Value useLargerExp = - b.create(arith::CmpIPredicate::uge, i4Bits, c0x4); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4); Value bits25To31 = - b.create(useLargerExp, highExpBits, lowExpBits); + arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits); Value zeroExp = - b.create(arith::CmpIPredicate::eq, i4Bits, c0x0); - bits25To31 = b.create(zeroExp, zeroExpBits, bits25To31); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0); + bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31); // Set sign. Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter); Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter); Value negative = - b.create(arith::CmpIPredicate::uge, i4Bits, c0x8); - Value bit32 = b.create(negative, c0x80000000, zeroExpBits); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8); + Value bit32 = arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits); // Add segments together. 
- Value bits1To31 = b.create(bits1To24, bits25To31); - Value bits1To32 = b.create(bits1To31, bit32); - Value result = b.create(f32Ty, bits1To32); + Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31); + Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32); + Value result = arith::BitcastOp::create(b, f32Ty, bits1To32); if (!isa(resultETy)) - result = b.create(resultTy, result); + result = arith::TruncFOp::create(b, resultTy, result); rewriter.replaceOp(op, result); return success(); @@ -450,25 +450,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern { Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); - Value bitcast = b.create(i8Ty, operand); + Value bitcast = arith::BitcastOp::create(b, i8Ty, operand); // create constants for NaNs Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter); Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); - Value exti = b.create(i32Ty, bitcast); - Value f32Bits = b.create(exti, cF32MantissaWidth); + Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast); + Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth); Value isNan = - b.create(arith::CmpIPredicate::eq, bitcast, cF8NaN); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN); // select for NaNs - f32Bits = b.create(isNan, cF32NaN, f32Bits); - Value result = b.create(f32Ty, f32Bits); + f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits); + Value result = arith::BitcastOp::create(b, f32Ty, f32Bits); if (resultETy.getIntOrFloatBitWidth() < 32) { - result = b.create(resultTy, result, nullptr, + result = arith::TruncFOp::create(b, resultTy, result, nullptr, op.getFastmathAttr()); } else if (resultETy.getIntOrFloatBitWidth() > 32) { - result = b.create(resultTy, result, op.getFastmathAttr()); + result = arith::ExtFOp::create(b, resultTy, result, 
op.getFastmathAttr()); } rewriter.replaceOp(op, result); return success(); @@ -521,7 +521,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); if (!isa(operandETy)) - operand = b.create(f32Ty, operand); + operand = arith::ExtFOp::create(b, f32Ty, operand); if (!isa(resultETy)) return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN"); @@ -535,65 +535,65 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern { // Step 0: Clamp to bounds. Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter); Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter); - Value operandClamped = b.create(cHigherBound, operand); - operandClamped = b.create(cLowerBound, operandClamped); - Value f32Bits = b.create(i32Ty, operandClamped); + Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand); + operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped); + Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped); // Step 1: Set sign bit. Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23 - Value f32Sign = b.create(f32Bits, cF32ExpManWidth); - Value f4Sign = b.create(i4Ty, f32Sign); - Value f4Bits = b.create(f4Sign, c0x3); + Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth); + Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign); + Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3); // Step 2: Convert exponent by adjusting bias. 
Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter); Value cF4MantissaWidth = c0x1; // 1 Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23 - Value f32SignExp = b.create(f32Bits, cF32MantissaWidth); + Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth); Value biasAdjustedSignExp = - b.create(f32SignExp, biasAdjustment); - Value f4Exp = b.create(i4Ty, biasAdjustedSignExp); - f4Exp = b.create(f4Exp, cF4MantissaWidth); - f4Bits = b.create(f4Bits, f4Exp); + arith::SubIOp::create(b, f32SignExp, biasAdjustment); + Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp); + f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth); + f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp); // Step 3: Set mantissa to first bit. Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter); - Value man1Bit = b.create(f32Bits, cF32FirstBitMask); - man1Bit = b.create(man1Bit, c0x00000016); - Value f4Man = b.create(i4Ty, man1Bit); - f4Bits = b.create(f4Bits, f4Man); + Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask); + man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016); + Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit); + f4Bits = arith::AddIOp::create(b, f4Bits, f4Man); // Step 4: Special consideration for conversion to 0.5. 
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter); - Value f8Exp = b.create(i8Ty, biasAdjustedSignExp); + Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp); Value isSubnormal = - b.create(arith::CmpIPredicate::sle, f8Exp, c0x00); + arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00); Value isNegOneExp = - b.create(arith::CmpIPredicate::eq, f8Exp, c0xff); - Value man23Bits = b.create(f32Bits, cF32MantissaMask); - Value isNonZeroMan = b.create(arith::CmpIPredicate::ugt, + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff); + Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask); + Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt, man23Bits, zeroExpBits); - Value roundToHalf = b.create(isNegOneExp, isNonZeroMan); + Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan); Value isZeroExp = - b.create(arith::CmpIPredicate::eq, f8Exp, c0x00); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00); Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter); Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter); Value subResult = - b.create(isSubnormal, subnormalF4Bits, f4Bits); - subResult = b.create(roundToHalf, halfF4Bits, subResult); - f4Bits = b.create(isZeroExp, f4Bits, subResult); + arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits); + subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult); + f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult); // Step 5: Round up if necessary. Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter); Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000... 
- Value man22Bits = b.create(f32Bits, cF32Last22BitMask); + Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask); Value shouldRound = - b.create(arith::CmpIPredicate::uge, man22Bits, cRound); - shouldRound = b.create(shouldRound, isSubnormal); - Value roundedF4Bits = b.create(f4Bits, c0x1); - f4Bits = b.create(shouldRound, roundedF4Bits, f4Bits); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound); + shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal); + Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1); + f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits); - Value result = b.create(resultTy, f4Bits); + Value result = arith::BitcastOp::create(b, resultTy, f4Bits); rewriter.replaceOp(op, result); return success(); } @@ -628,16 +628,16 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); if (operandETy.getIntOrFloatBitWidth() < 32) { - operand = b.create(f32Ty, operand, op.getFastmathAttr()); + operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr()); } else if (operandETy.getIntOrFloatBitWidth() > 32) { - operand = b.create( + operand = arith::TruncFOp::create(b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); } - Value f32Bits = b.create(i32Ty, operand); + Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); - Value f32SignExp = b.create(f32Bits, cF32MantissaWidth); - Value exp8Bits = b.create(i8Ty, f32SignExp); - Value result = b.create(resultTy, exp8Bits); + Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth); + Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp); + Value result = arith::BitcastOp::create(b, resultTy, exp8Bits); rewriter.replaceOp(op, result); return success(); } @@ -656,7 +656,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern { if 
(scaleETy.getIntOrFloatBitWidth() >= 16) { scaleETy = b.getF8E8M0Type(); scaleTy = cloneToShapedType(scaleTy, scaleETy); - scaleOperand = b.create(scaleTy, scaleOperand, nullptr, + scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr, op.getFastmathAttr()); } if (!llvm::isa(scaleETy)) { @@ -668,11 +668,11 @@ struct ScalingExtFOpConverter : public OpRewritePattern { // extf on scale will essentially create floating point number // of type resulTy that is 2^scale and will also propagate NaNs Value scaleExt = - b.create(resultTy, scaleOperand, op.getFastmathAttr()); + arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr()); Value inputExt = - b.create(resultTy, inputOperand, op.getFastmathAttr()); + arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr()); Value result = - b.create(inputExt, scaleExt, op.getFastmathAttr()); + arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr()); rewriter.replaceOp(op, result); return success(); } @@ -697,7 +697,7 @@ struct ScalingTruncFOpConverter if (scaleETy.getIntOrFloatBitWidth() >= 16) { scaleETy = b.getF8E8M0Type(); scaleTy = cloneToShapedType(scaleTy, scaleETy); - scaleOperand = b.create(scaleTy, scaleOperand, nullptr, + scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr, op.getFastmathAttr()); } if (!llvm::isa(scaleETy)) { @@ -710,10 +710,10 @@ struct ScalingTruncFOpConverter // this will create a floating point number of type // inputTy that is 2^scale and will also propagate NaNs scaleOperand = - b.create(inputTy, scaleOperand, op.getFastmathAttr()); - Value result = b.create(inputOperand, scaleOperand, + arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr()); + Value result = arith::DivFOp::create(b, inputOperand, scaleOperand, op.getFastmathAttr()); - Value resultCast = b.create( + Value resultCast = arith::TruncFOp::create(b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); rewriter.replaceOp(op, 
resultCast); return success(); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index f2f93883eb2b7..777ff0ecaa314 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -305,18 +305,18 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, if (isa(srcElemType) || isa(dstElemType)) { if (castKind == CastKind::Signed) - return builder.create(loc, dstType, src); - return builder.create(loc, dstType, src); + return arith::IndexCastOp::create(builder, loc, dstType, src); + return arith::IndexCastUIOp::create(builder, loc, dstType, src); } auto srcInt = cast(srcElemType); auto dstInt = cast(dstElemType); if (dstInt.getWidth() < srcInt.getWidth()) - return builder.create(loc, dstType, src); + return arith::TruncIOp::create(builder, loc, dstType, src); if (castKind == CastKind::Signed) - return builder.create(loc, dstType, src); - return builder.create(loc, dstType, src); + return arith::ExtSIOp::create(builder, loc, dstType, src); + return arith::ExtUIOp::create(builder, loc, dstType, src); } struct NarrowElementwise final : OpTraitRewritePattern { diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index 5fb7953f93700..904aab33345f0 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -23,7 +23,7 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, std::function buildExpr = [&](AffineExpr e) -> Value { switch (e.getKind()) { case AffineExprKind::Constant: - return b.create(loc, + return ConstantIndexOp::create(b, loc, cast(e).getValue()); case AffineExprKind::DimId: return operands[cast(e).getPosition()]; @@ -32,27 +32,27 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, 
map.getNumDims()]; case AffineExprKind::Add: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), + return AddIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mul: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), + return MulIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::FloorDiv: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), + return DivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::CeilDiv: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), + return CeilDivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mod: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), + return RemSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } } @@ -89,10 +89,10 @@ FailureOr mlir::arith::reifyValueBound( "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. - operands.push_back(b.create(loc, value, *dim)); + operands.push_back(tensor::DimOp::create(b, loc, value, *dim)); } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. 
- operands.push_back(b.create(loc, value, *dim)); + operands.push_back(memref::DimOp::create(b, loc, value, *dim)); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index 62d137a4cfb0e..10c5b50cb771a 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -84,7 +84,7 @@ struct ConstantShardingInterface cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable), sharding)); auto newValue = value.resizeSplat(newType); - auto newOp = builder.create(op->getLoc(), newType, newValue); + auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue); spmdizationMap.map(op->getResult(0), newOp.getResult()); spmdizationMap.map(op, newOp.getOperation()); } else { diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 3cd8684878a11..6e9cc14c6ada5 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -68,7 +68,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, // dynamism. 
Value indexGroupSize = cast(inputShape[inputIndex]); Value indexGroupStaticSizesProduct = - b.create(loc, indexGroupStaticSizesProductInt); + arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt); Value dynamicDimSize = b.createOrFold( loc, indexGroupSize, indexGroupStaticSizesProduct); outputShapeValues.push_back(dynamicDimSize); @@ -105,7 +105,7 @@ Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, if (auto value = dyn_cast_if_present(ofr)) return value; auto attr = cast(cast(ofr)); - return b.create( + return arith::ConstantOp::create(b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); } @@ -114,7 +114,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, if (auto value = dyn_cast_if_present(ofr)) return value; auto attr = cast(cast(ofr)); - return b.create(loc, attr.getValue().getSExtValue()); + return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue()); } Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, @@ -125,7 +125,7 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, bool targetIsIndex = targetType.isIndex(); bool valueIsIndex = value.getType().isIndex(); if (targetIsIndex ^ valueIsIndex) - return b.create(loc, targetType, value); + return arith::IndexCastOp::create(b, loc, targetType, value); auto targetIntegerType = dyn_cast(targetType); auto valueIntegerType = dyn_cast(value.getType()); @@ -134,8 +134,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) - return b.create(loc, targetIntegerType, value); - return b.create(loc, targetIntegerType, value); + return arith::ExtSIOp::create(b, loc, targetIntegerType, value); + return arith::TruncIOp::create(b, loc, targetIntegerType, value); } static Value convertScalarToIntDtype(ImplicitLocOpBuilder 
&b, Value operand, @@ -143,21 +143,21 @@ static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, // If operand is floating point, cast directly to the int type. if (isa(operand.getType())) { if (isUnsigned) - return b.create(toType, operand); - return b.create(toType, operand); + return arith::FPToUIOp::create(b, toType, operand); + return arith::FPToSIOp::create(b, toType, operand); } // Cast index operands directly to the int type. if (operand.getType().isIndex()) - return b.create(toType, operand); + return arith::IndexCastOp::create(b, toType, operand); if (auto fromIntType = dyn_cast(operand.getType())) { // Either extend or truncate. if (toType.getWidth() > fromIntType.getWidth()) { if (isUnsigned) - return b.create(toType, operand); - return b.create(toType, operand); + return arith::ExtUIOp::create(b, toType, operand); + return arith::ExtSIOp::create(b, toType, operand); } if (toType.getWidth() < fromIntType.getWidth()) - return b.create(toType, operand); + return arith::TruncIOp::create(b, toType, operand); return operand; } @@ -170,14 +170,14 @@ static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, // Note that it is unclear how to cast from BF16<->FP16. 
if (isa(operand.getType())) { if (isUnsigned) - return b.create(toType, operand); - return b.create(toType, operand); + return arith::UIToFPOp::create(b, toType, operand); + return arith::SIToFPOp::create(b, toType, operand); } if (auto fromFpTy = dyn_cast(operand.getType())) { if (toType.getWidth() > fromFpTy.getWidth()) - return b.create(toType, operand); + return arith::ExtFOp::create(b, toType, operand); if (toType.getWidth() < fromFpTy.getWidth()) - return b.create(toType, operand); + return arith::TruncFOp::create(b, toType, operand); return operand; } @@ -190,18 +190,18 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, if (auto fromComplexType = dyn_cast(operand.getType())) { if (isa(targetType.getElementType()) && isa(fromComplexType.getElementType())) { - Value real = b.create(operand); - Value imag = b.create(operand); + Value real = complex::ReOp::create(b, operand); + Value imag = complex::ImOp::create(b, operand); Type targetETy = targetType.getElementType(); if (targetType.getElementType().getIntOrFloatBitWidth() < fromComplexType.getElementType().getIntOrFloatBitWidth()) { - real = b.create(targetETy, real); - imag = b.create(targetETy, imag); + real = arith::TruncFOp::create(b, targetETy, real); + imag = arith::TruncFOp::create(b, targetETy, imag); } else { - real = b.create(targetETy, real); - imag = b.create(targetETy, imag); + real = arith::ExtFOp::create(b, targetETy, real); + imag = arith::ExtFOp::create(b, targetETy, imag); } - return b.create(targetType, real, imag); + return complex::CreateOp::create(b, targetType, real, imag); } } @@ -210,27 +210,27 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); Value from = operand; if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { - from = b.create(toFpTy, from); + from = arith::ExtFOp::create(b, toFpTy, from); } if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { - from 
= b.create(toFpTy, from); + from = arith::TruncFOp::create(b, toFpTy, from); } - Value zero = b.create( + Value zero = mlir::arith::ConstantFloatOp::create(b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); - return b.create(targetType, from, zero); + return complex::CreateOp::create(b, targetType, from, zero); } if (isa(operand.getType())) { FloatType toFpTy = cast(targetType.getElementType()); Value from = operand; if (isUnsigned) { - from = b.create(toFpTy, from); + from = arith::UIToFPOp::create(b, toFpTy, from); } else { - from = b.create(toFpTy, from); + from = arith::SIToFPOp::create(b, toFpTy, from); } - Value zero = b.create( + Value zero = mlir::arith::ConstantFloatOp::create(b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); - return b.create(targetType, from, zero); + return complex::CreateOp::create(b, targetType, from, zero); } return {}; @@ -278,7 +278,7 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, attr = SplatElementsAttr::get(vecTy, value); } - return builder.create(loc, attr); + return arith::ConstantOp::create(builder, loc, attr); } Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, @@ -310,35 +310,35 @@ Type mlir::getType(OpFoldResult ofr) { } Value ArithBuilder::_and(Value lhs, Value rhs) { - return b.create(loc, lhs, rhs); + return arith::AndIOp::create(b, loc, lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs, ovf); + return arith::AddFOp::create(b, loc, lhs, rhs); + return arith::AddIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::sub(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs, ovf); + return arith::SubFOp::create(b, loc, lhs, rhs); + return arith::SubIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::mul(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, lhs, rhs); - 
return b.create(loc, lhs, rhs, ovf); + return arith::MulFOp::create(b, loc, lhs, rhs); + return arith::MulIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::sgt(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, arith::CmpFPredicate::OGT, lhs, rhs); - return b.create(loc, arith::CmpIPredicate::sgt, lhs, rhs); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, arith::CmpFPredicate::OLT, lhs, rhs); - return b.create(loc, arith::CmpIPredicate::slt, lhs, rhs); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs); } Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { - return b.create(loc, cmp, lhs, rhs); + return arith::SelectOp::create(b, loc, cmp, lhs, rhs); } namespace mlir::arith { @@ -349,7 +349,7 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef values) { Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, Type resultType) { - Value one = builder.create(loc, resultType, + Value one = ConstantOp::create(builder, loc, resultType, builder.getOneAttr(resultType)); ArithBuilder arithBuilder(builder, loc); return std::accumulate( diff --git a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp index 1ded35af78052..5f9cd148f5e0a 100644 --- a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp +++ b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" diff --git 
a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp index d07e6a52d8b5f..f0b7a8a4b0954 100644 --- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp +++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" using namespace mlir; diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp index 7180884c77e98..0d5640b202fb3 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp @@ -231,7 +231,7 @@ class LowerContractionToNeonI8MMPattern // Initial accumulator for the final result. This is the un-tiled result if // tiling is done. - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); SmallVector unrolledSize = *op.getShapeForUnroll(); @@ -283,7 +283,7 @@ class LowerContractionToNeonI8MMPattern if (isVecmat) { auto expandForSMMLA = [&](Value tiledOperand, VectorType expandedTypeType) { - auto emptyOperand = rewriter.create( + auto emptyOperand = arith::ConstantOp::create(rewriter, loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); SmallVector offsets( cast(emptyOperand.getType()).getRank(), 0); @@ -300,7 +300,7 @@ class LowerContractionToNeonI8MMPattern // using the instruction for unsigned by signed multiplication with // reversed operands. 
if (mmlaOp == MMLA::MixedSwapped) - tiledAcc = rewriter.create( + tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc, ArrayRef({1, 0})); // Collapse tiled operands to 1D vectors required by smmla intrinsic @@ -333,7 +333,7 @@ class LowerContractionToNeonI8MMPattern // Because of the reversed operands the result is obtained transposed. // Transpose it back, if (mmlaOp == MMLA::MixedSwapped) - tiledRes = rewriter.create( + tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes, ArrayRef({1, 0})); // With vecmat, only one row of tiled ACC can be inserted into the final diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp index cb3a665844872..2d8903df00055 100644 --- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp index 1f7305a5f8141..a4aebb585c3de 100644 --- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp @@ -76,21 +76,21 @@ scf::ForOp createLoopOverTileSlices( PatternRewriter &rewriter, Location loc, Value initTile, std::function makeLoopBody) { OpBuilder::InsertionGuard g(rewriter); - auto step = rewriter.create(loc, 1); - auto minTileSlices = rewriter.create( + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create(rewriter, loc, llvm::cast(initTile.getType()).getDimSize(0)); auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto 
numTileSlices = - rewriter.create(loc, minTileSlices, vscale); - auto forOp = rewriter.create(loc, lowerBound, numTileSlices, step, + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); Value nextTile = makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(), /*currentTile=*/forOp.getRegionIterArg(0)); - rewriter.create(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 23f2c2bf65e47..4d03bbdabdf09 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -136,7 +136,7 @@ class OuterProductFusion2Way auto loc = op.getLoc(); auto packInputs = [&](Value lhs, Value rhs) { - return rewriter.create(loc, lhs, rhs); + return vector::InterleaveOp::create(rewriter, loc, lhs, rhs); }; auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), @@ -284,7 +284,7 @@ class OuterProductFusion4Way auto loc = op.getLoc(); auto packInputs = [&](Value lhs, Value rhs) { - return rewriter.create(loc, lhs, rhs); + return vector::InterleaveOp::create(rewriter, loc, lhs, rhs); }; auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), @@ -456,7 +456,7 @@ struct SwapVectorExtractOfArithExtend Value extendSource = extendOp->getOperand(0); // Create new extract from source of extend. - Value newExtract = rewriter.create( + Value newExtract = vector::ExtractOp::create(rewriter, loc, extendSource, extractOp.getMixedPosition()); // Extend new extract to original result type. @@ -503,7 +503,7 @@ struct SwapVectorScalableExtractOfArithExtend // Create new extract from source of extend. 
VectorType extractResultVectorType = resultType.clone(extendSourceVectorType.getElementType()); - Value newExtract = rewriter.create( + Value newExtract = vector::ScalableExtractOp::create(rewriter, loc, extractResultVectorType, extendSource, extractOp.getPos()); // Extend new extract to original result type. diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index e6c9adba62f34..2bea9de83143a 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -212,7 +212,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) { rewriter.setInsertionPointToEnd(source); - rewriter.create(loc, dest, args); + cf::BranchOp::create(rewriter, loc, dest, args); }; for (auto condBranch : worklist) { @@ -255,7 +255,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter, for (OpOperand &operand : terminator->getOpOperands()) { if (isValidSMETileVectorType(operand.get().getType())) { auto copy = - rewriter.create(terminator->getLoc(), operand.get()); + CopyTileOp::create(rewriter, terminator->getLoc(), operand.get()); rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); }); } } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 1e8e1265affa0..3362d0ce13533 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -82,13 +82,13 @@ SmallVector addConstantScalableOffset(OpBuilder &builder, Location loc, ValueRange indices, ArrayRef scalableOffsets) { - auto vscale = builder.create(loc); + auto vscale = vector::VectorScaleOp::create(builder, loc); return llvm::map_to_vector( llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value { auto [index, base] = pair; - auto 
offset = builder.create( - loc, builder.create(loc, base), vscale); - return builder.create(loc, index, offset); + auto offset = arith::MulIOp::create(builder, + loc, arith::ConstantIndexOp::create(builder, loc, base), vscale); + return arith::AddIOp::create(builder, loc, index, offset); }); } @@ -132,7 +132,7 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask, // from the mask operands to get the parameters for this sub-tile. auto smeTileMaskDims = addConstantScalableOffset( builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col}); - auto smeTileCreateMask = builder.create( + auto smeTileCreateMask = vector::CreateMaskOp::create(builder, loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); return smeTileCreateMask.getResult(); } @@ -190,7 +190,7 @@ struct LegalizeArithConstantOpsByDecomposition auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); auto tileCount = getNumberOfSMETilesForVectorType(vectorType); - auto tileSplat = rewriter.create( + auto tileSplat = arith::ConstantOp::create(rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); SmallVector repl(tileCount, tileSplat); rewriter.replaceOpWithMultiple(constantOp, {repl}); @@ -237,11 +237,11 @@ struct LegalizeVectorOuterProductOpsByDecomposition decomposeToSMETiles(rewriter, vectorType, smeTileType))) { auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); - auto lhs = rewriter.create( + auto lhs = vector::ScalableExtractOp::create(rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row); - auto rhs = rewriter.create( + auto rhs = vector::ScalableExtractOp::create(rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col); - auto smeOuterProduct = rewriter.create( + auto smeOuterProduct = vector::OuterProductOp::create(rewriter, loc, smeTileType, lhs, rhs, !accSMETiles.empty() ? 
accSMETiles[index] : Value{}, outerProductOp.getKind()); @@ -314,7 +314,7 @@ struct LegalizeTransferReadOpsByDecomposition for (SMESubTile smeTile : decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) { auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); - auto smeRead = rewriter.create( + auto smeRead = vector::TransferReadOp::create(rewriter, loc, smeTileType, readOp.getBase(), getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile), readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask, @@ -363,7 +363,7 @@ struct LegalizeTransferWriteOpsByDecomposition for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles( rewriter, vectorType, smeTileType, transposed))) { auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); - auto smeWrite = rewriter.create( + auto smeWrite = vector::TransferWriteOp::create(rewriter, loc, inputSMETiles[index], destTensorOrMemref, getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile), writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr()); @@ -456,11 +456,11 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop VectorType::get(minTileSlices, rewriter.getI1Type(), true); // Create loop over all tile slices. - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upperBound = createVscaleMultiple(minTileSlices); - auto step = rewriter.create(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto storeLoop = - rewriter.create(loc, lowerBound, upperBound, step); + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); rewriter.setInsertionPointToStart(storeLoop.getBody()); // For each sub-tile of the multi-tile `vectorType`. @@ -474,29 +474,29 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop // The current slice of `vectorType` we are processing. 
auto sliceIndex = - rewriter.create(loc, tileRow, tileSliceIndex); + arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex); // Where in the destination memref the current slice will be stored. - auto storeRow = rewriter.create(loc, sliceIndex, + auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex, writeOp.getIndices()[0]); auto storeCol = - rewriter.create(loc, tileCol, writeOp.getIndices()[1]); + arith::AddIOp::create(rewriter, loc, tileCol, writeOp.getIndices()[1]); // Extract the mask for the current slice. Value sliceMask = nullptr; if (mask) { - sliceMask = rewriter.create( + sliceMask = vector::ExtractOp::create(rewriter, loc, mask, OpFoldResult(sliceIndex)); if (sliceMaskType != sliceMask.getType()) - sliceMask = rewriter.create( + sliceMask = vector::ScalableExtractOp::create(rewriter, loc, sliceMaskType, sliceMask, smeTile.col); } // Extract and store the current slice. Value tile = inputSMETiles[index]; auto slice = - rewriter.create(loc, tile, tileSliceIndex); - rewriter.create( + vector::ExtractOp::create(rewriter, loc, tile, tileSliceIndex); + vector::TransferWriteOp::create(rewriter, loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol}, AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)), sliceMask, @@ -567,13 +567,13 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks extractOp, "constant vector.create_masks dims should be folded elsewhere"); - auto zero = rewriter.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); auto extractionIndex = getValueOrCreateConstantIndexOp( rewriter, loc, extractOp.getMixedPosition()[0]); - auto extractionInTrueRegion = rewriter.create( + auto extractionInTrueRegion = arith::CmpIOp::create(rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex, frontMaskDim); - auto newMaskFrontDim = rewriter.create( + auto newMaskFrontDim = arith::SelectOp::create(rewriter, loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero); 
rewriter.replaceOpWithNewOp( @@ -660,8 +660,8 @@ struct LiftIllegalVectorTransposeToMemory illegalRead, "expected read to have identity permutation map"); auto loc = transposeOp.getLoc(); - auto zero = rewriter.create(loc, 0); - auto one = rewriter.create(loc, 1); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); // Create a subview that matches the size of the illegal read vector type. auto readType = illegalRead.getVectorType(); @@ -669,14 +669,14 @@ struct LiftIllegalVectorTransposeToMemory llvm::zip_equal(readType.getShape(), readType.getScalableDims()), [&](auto dim) -> Value { auto [size, isScalable] = dim; - auto dimSize = rewriter.create(loc, size); + auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size); if (!isScalable) return dimSize; - auto vscale = rewriter.create(loc); - return rewriter.create(loc, vscale, dimSize); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); + return arith::MulIOp::create(rewriter, loc, vscale, dimSize); }); SmallVector strides(readType.getRank(), Value(one)); - auto readSubview = rewriter.create( + auto readSubview = memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes, strides); @@ -686,13 +686,13 @@ struct LiftIllegalVectorTransposeToMemory if (mask) { // Note: The transpose for the mask should fold into the // vector.create_mask/constant_mask op, which will then become legal. 
- mask = rewriter.create(loc, mask, + mask = vector::TransposeOp::create(rewriter, loc, mask, transposeOp.getPermutation()); } // - The source memref mlir::AffineMap transposeMap = AffineMap::getPermutationMap( transposeOp.getPermutation(), getContext()); - auto transposedSubview = rewriter.create( + auto transposedSubview = memref::TransposeOp::create(rewriter, loc, readSubview, AffineMapAttr::get(transposeMap)); ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); // - The `in_bounds` attribute @@ -706,7 +706,7 @@ struct LiftIllegalVectorTransposeToMemory VectorType legalReadType = resultType.clone(readType.getElementType()); // Note: The indices are all zero as the subview is already offset. SmallVector readIndices(illegalRead.getIndices().size(), zero); - auto legalRead = rewriter.create( + auto legalRead = vector::TransferReadOp::create(rewriter, loc, legalReadType, transposedSubview, readIndices, illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask, inBoundsAttr); @@ -797,12 +797,12 @@ struct LowerIllegalTransposeStoreViaZA AffineMap::getPermutationMap(ArrayRef{1, 0}, getContext())); // Note: We need to use `get_tile` as there's no vector-level `undef`. - Value undefTile = rewriter.create(loc, smeTileType); + Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType); Value destTensorOrMemref = writeOp.getBase(); auto numSlicesPerTile = std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0)); auto numSlices = - rewriter.create(loc, numSlicesPerTile); + arith::ConstantIndexOp::create(rewriter, loc, numSlicesPerTile); for (auto [index, smeTile] : llvm::enumerate( decomposeToSMETiles(rewriter, sourceType, smeTileType))) { // 1. _Deliberately_ drop a scalable dimension and insert a fixed number @@ -811,46 +811,46 @@ struct LowerIllegalTransposeStoreViaZA // rows of the tile after 1*vscale rows. 
Value tile = undefTile; for (int d = 0; d < numSlicesPerTile; ++d) { - Value vector = rewriter.create( + Value vector = vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(), rewriter.getIndexAttr(d + smeTile.row)); if (vector.getType() != smeSliceType) { - vector = rewriter.create( + vector = vector::ScalableExtractOp::create(rewriter, loc, smeSliceType, vector, smeTile.col); } - tile = rewriter.create(loc, vector, tile, d); + tile = vector::InsertOp::create(rewriter, loc, vector, tile, d); } // 2. Transpose the tile position. auto transposedRow = createVscaleMultiple(smeTile.col); auto transposedCol = - rewriter.create(loc, smeTile.row); + arith::ConstantIndexOp::create(rewriter, loc, smeTile.row); // 3. Compute mask for tile store. Value maskRows; Value maskCols; if (auto mask = writeOp.getMask()) { auto createMask = mask.getDefiningOp(); - maskRows = rewriter.create(loc, createMask.getOperand(0), + maskRows = arith::SubIOp::create(rewriter, loc, createMask.getOperand(0), transposedRow); - maskCols = rewriter.create(loc, createMask.getOperand(1), + maskCols = arith::SubIOp::create(rewriter, loc, createMask.getOperand(1), transposedCol); - maskCols = rewriter.create(loc, maskCols, numSlices); + maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices); } else { maskRows = createVscaleMultiple(smeTileType.getDimSize(0)); maskCols = numSlices; } - auto subMask = rewriter.create( + auto subMask = vector::CreateMaskOp::create(rewriter, loc, smeTileType.clone(rewriter.getI1Type()), ValueRange{maskRows, maskCols}); // 4. Emit a transposed tile write. 
auto writeIndices = writeOp.getIndices(); Value destRow = - rewriter.create(loc, transposedRow, writeIndices[0]); + arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]); Value destCol = - rewriter.create(loc, transposedCol, writeIndices[1]); - auto smeWrite = rewriter.create( + arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]); + auto smeWrite = vector::TransferWriteOp::create(rewriter, loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, transposeMap, subMask, writeOp.getInBounds()); @@ -934,41 +934,41 @@ struct LowerColumnTransferReadToLoops // Create a loop over all rows and load one element at a time. auto loc = readOp.getLoc(); - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto createVscaleMultiple = vector::makeVscaleConstantBuilder(rewriter, loc); auto upperBound = createVscaleMultiple(numRows); - auto step = rewriter.create(loc, 1); - Value init = rewriter.create( + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value init = arith::ConstantOp::create(rewriter, loc, newResType, DenseElementsAttr::get(newResType, 0.0f)); scf::ForOp loadLoop; { OpBuilder::InsertionGuard g(rewriter); - loadLoop = rewriter.create(loc, lowerBound, upperBound, step, + loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, ValueRange{init}); rewriter.setInsertionPointToStart(loadLoop.getBody()); auto tileSliceIndex = loadLoop.getInductionVar(); - auto idx0 = rewriter.create(loc, tileSliceIndex, + auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex, readOp.getIndices()[0]); auto idx1 = readOp.getIndices()[1]; - Value scalar = rewriter.create( + Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(), SmallVector({idx0, idx1})); - Operation *updateInit = rewriter.create( + Operation *updateInit = vector::InsertOp::create(rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex); - rewriter.create(loc, 
updateInit->getResult(0)); + scf::YieldOp::create(rewriter, loc, updateInit->getResult(0)); } // The read operation has been "legalized", but since the original result // type was a 2D vector, we need to cast before returning the result. This // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a // no-op). - auto sc = rewriter.create( + auto sc = vector::ShapeCastOp::create(rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0)); rewriter.replaceOp(readOp, sc); diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp index 594c9b4c270f2..bd7130b667b6b 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp index b2ca4fc1eaa8c..b50679239a840 100644 --- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" using namespace mlir; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 006332b48325f..20f2ad1c3dcec 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -89,7 +89,7 @@ struct SvboolConversionOpLowering : public 
ConvertOpToLLVMPattern { VectorType sourceType = source.getType(); VectorType resultType = convertOp.getResult().getType(); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resultType, rewriter.getZeroAttr(resultType)); // We want to iterate over the input vector in steps of the trailing @@ -102,14 +102,14 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern { for (SmallVector index : StaticTileOffsetRange(sourceType.getShape(), tileShape)) { auto extractOrInsertPosition = ArrayRef(index).drop_back(); - auto sourceVector = rewriter.create( + auto sourceVector = vector::ExtractOp::create(rewriter, loc, source, extractOrInsertPosition); VectorType convertedType = VectorType::Builder(llvm::cast(sourceVector.getType())) .setDim(0, resultType.getShape().back()); auto convertedVector = - rewriter.create(loc, TypeRange{convertedType}, sourceVector); - result = rewriter.create(loc, convertedVector, result, + IntrOp::create(rewriter, loc, TypeRange{convertedType}, sourceVector); + result = vector::InsertOp::create(rewriter, loc, convertedVector, result, extractOrInsertPosition); } @@ -137,11 +137,11 @@ struct PselOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); auto loc = pselOp.getLoc(); - auto svboolP1 = rewriter.create(loc, svboolType, + auto svboolP1 = ConvertToSvboolIntrOp::create(rewriter, loc, svboolType, adaptor.getP1()); - auto indexI32 = rewriter.create( + auto indexI32 = arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), pselOp.getIndex()); - auto pselIntr = rewriter.create(loc, svboolType, svboolP1, + auto pselIntr = PselIntrOp::create(rewriter, loc, svboolType, svboolP1, pselOp.getP2(), indexI32); rewriter.replaceOpWithNewOp( pselOp, adaptor.getP1().getType(), pselIntr); @@ -176,7 +176,7 @@ struct CreateMaskOpLowering "not SVE predicate-sized"); auto loc = 
createMaskOp.getLoc(); - auto zero = rewriter.create(loc, rewriter.getI64Type()); + auto zero = LLVM::ZeroOp::create(rewriter, loc, rewriter.getI64Type()); rewriter.replaceOpWithNewOp(createMaskOp, maskType, zero, adaptor.getOperands()[0]); return success(); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index 3dbb93b8a0669..33de12ca17cfc 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -71,7 +71,7 @@ void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op, TLegalizerCallback callback) { replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) { // Mark our `unrealized_conversion_casts` with a pass label. - return rewriter.create( + return UnrealizedConversionCastOp::create(rewriter, op.getLoc(), TypeRange{op.getResult().getType()}, ValueRange{callback(newOp)}, NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag), @@ -239,7 +239,7 @@ struct LegalizeSVEMaskStoreConversion auto legalMaskType = widenScalableMaskTypeToSvbool( llvm::cast(valueToStore.getType())); - auto convertToSvbool = rewriter.create( + auto convertToSvbool = arm_sve::ConvertToSvboolOp::create(rewriter, loc, legalMaskType, valueToStore); // Replace this store with a conversion to a storable svbool mask [1], // followed by a wider store. 
@@ -290,7 +290,7 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) { newLoadOp.setMemRef(*legalMemref); newLoadOp.getResult().setType(legalMaskType); - return rewriter.create( + return arm_sve::ConvertFromSvboolOp::create(rewriter, loc, loadedMask.getType(), newLoadOp); }); @@ -408,7 +408,7 @@ struct LegalizeTransferRead : public OpRewritePattern { reassoc.back().push_back(i); if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc)) return failure(); - Value collapsedMem = rewriter.create( + Value collapsedMem = memref::CollapseShapeOp::create(rewriter, readOp.getLoc(), readOp.getBase(), reassoc); // Get a vector type with collapsed trailing dimensions. @@ -424,13 +424,13 @@ struct LegalizeTransferRead : public OpRewritePattern { auto indices = readOp.getIndices().drop_back(numCollapseDims - 1); // Create the new `transfer_read`. - auto newReadOp = rewriter.create( + auto newReadOp = vector::TransferReadOp::create(rewriter, readOp.getLoc(), collapsedVT, collapsedMem, indices, readOp.getPadding(), ArrayRef(origInBounds).drop_back(numCollapseDims - 1)); // Cast back to the original vector type. 
- auto toOrigShape = rewriter.create(readOp.getLoc(), + auto toOrigShape = vector::ShapeCastOp::create(rewriter, readOp.getLoc(), origVT, newReadOp); rewriter.replaceOp(readOp, toOrigShape); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp index b7703ff0393eb..959c5d7291c31 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp @@ -90,15 +90,15 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, mlir::VectorType accType, Value acc, Value lhs, Value rhs) { switch (op) { case MMLA::Signed: - return rewriter.create(loc, accType, acc, lhs, rhs); + return arm_sve::SmmlaOp::create(rewriter, loc, accType, acc, lhs, rhs); case MMLA::Unsigned: - return rewriter.create(loc, accType, acc, lhs, rhs); + return arm_sve::UmmlaOp::create(rewriter, loc, accType, acc, lhs, rhs); case MMLA::Mixed: - return rewriter.create(loc, accType, acc, lhs, rhs); + return arm_sve::UsmmlaOp::create(rewriter, loc, accType, acc, lhs, rhs); case MMLA::MixedSwapped: // The accumulator comes transposed and the result will be transposed // later, so all we have to do here is swap the operands. - return rewriter.create(loc, accType, acc, rhs, lhs); + return arm_sve::UsmmlaOp::create(rewriter, loc, accType, acc, rhs, lhs); } } @@ -236,25 +236,25 @@ class LowerContractionToSVEI8MMPattern SmallVector lhsTile; for (int64_t i = 0; i < M; i += 2) { // Extract two consecutive rows of the LHS tile. - auto r0 = rewriter.create(loc, *maybeLhs, + auto r0 = vector::ExtractOp::create(rewriter, loc, *maybeLhs, ArrayRef{i}); - auto r1 = rewriter.create(loc, *maybeLhs, + auto r1 = vector::ExtractOp::create(rewriter, loc, *maybeLhs, ArrayRef{i + 1}); // Concatenate to obtain a 16 x i8 flattened sub-tile. 
- auto t = rewriter.create( + auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, llvm::ArrayRef{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); // Turn it into a scalable vector. - auto s = rewriter.create( - loc, t, rewriter.create(loc, nxv16i8), 0); + auto s = vector::ScalableInsertOp::create(rewriter, + loc, t, ub::PoisonOp::create(rewriter, loc, nxv16i8), 0); // Replicate the sub-tile VSCALE times to fill the entire vector. - auto r = rewriter.create(loc, s, 0); + auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0); lhsTile.push_back(r); } // "Flatten" the RHS tile from <[N]x8> to <[8*N]>. - auto rhs = rewriter.create( + auto rhs = vector::ShapeCastOp::create(rewriter, maybeRhs->getLoc(), VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(), /*scalableDims=*/{true}), @@ -264,7 +264,7 @@ class LowerContractionToSVEI8MMPattern SmallVector rhsTile; for (int64_t j = 0; j < N; j += 2) rhsTile.push_back( - rewriter.create(loc, nxv16i8, rhs, j * 8)); + vector::ScalableExtractOp::create(rewriter, loc, nxv16i8, rhs, j * 8)); // Handy types for packing/unpacking of the accumulator tile. auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(), @@ -280,33 +280,33 @@ class LowerContractionToSVEI8MMPattern SmallVector accTile; for (int64_t i = 0; i < M; i += 2) { // Extract two consecutive rows of the accumulator tile. - auto r0 = rewriter.create(loc, op.getAcc(), + auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), ArrayRef{i}); - auto r1 = rewriter.create(loc, op.getAcc(), + auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), ArrayRef{i + 1}); Value accTileVec; if (mmlaOp == MMLA::MixedSwapped) { // We need to swap the positions of the LHS and RHS (since we don't have // a signed * unsigned operation), but then each individual 2x2 tile of // the acumulator and (later) the result need to be transposed. 
- accTileVec = rewriter.create(loc, r0, r1); + accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1); } else { // Bitcast them to 64-bit elements, so subsequent // interleave/deinterleave work on pairs of 32-bit numbers. - auto r0I64 = rewriter.create(loc, accRow64Ty, r0); - auto r1I64 = rewriter.create(loc, accRow64Ty, r1); + auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0); + auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1); // Interleave the rows, effectively flattening each 2x2 tile into 4 // consecutive elements. - auto intrI64 = rewriter.create(loc, r0I64, r1I64); + auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64); // Bitcast back to 32-bit elements. accTileVec = - rewriter.create(loc, accRowX2Ty, intrI64); + vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64); } // Extract ACC sub-tiles. for (int64_t j = 0; j < N; j += 2) - accTile.push_back(rewriter.create( + accTile.push_back(vector::ScalableExtractOp::create(rewriter, loc, nxv4i32, accTileVec, j * 2)); } @@ -320,12 +320,12 @@ class LowerContractionToSVEI8MMPattern } // Unpack the OUT sub-tiles and insert into the result. - Value result = rewriter.create(loc, op.getResultType()); + Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType()); for (int64_t i = 0; i < M / 2; ++i) { // Collect a number of sub-tiles in a row. - Value row = rewriter.create(loc, accRowX2Ty); + Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty); for (int64_t j = 0; j < N / 2; ++j) - row = rewriter.create( + row = vector::ScalableInsertOp::create(rewriter, loc, outTile[i * N / 2 + j], row, j * 4); // Unpack the row to obtain two rows of the output. If we have the out @@ -334,22 +334,22 @@ class LowerContractionToSVEI8MMPattern // Otherwise, the interleave is by pairs. 
Value out0, out1; if (mmlaOp == MMLA::MixedSwapped) { - auto tmp = rewriter.create(loc, row); + auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row); out0 = tmp.getRes1(); out1 = tmp.getRes2(); } else { // Deinterleave by pairs. - auto row64 = rewriter.create(loc, accRowX264Ty, row); - auto deintr64 = rewriter.create(loc, row64); + auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row); + auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64); // Bitcast back into 32-bit elements and insert into the result. - out0 = rewriter.create(loc, accRowTy, + out0 = vector::BitCastOp::create(rewriter, loc, accRowTy, deintr64.getRes1()); - out1 = rewriter.create(loc, accRowTy, + out1 = vector::BitCastOp::create(rewriter, loc, accRowTy, deintr64.getRes2()); } - result = rewriter.create(loc, out0, result, i * 2); - result = rewriter.create(loc, out1, result, i * 2 + 1); + result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2); + result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index b834afef7da79..e7d95575c6edf 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -99,7 +100,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result, // expected result is empty. Otherwise, leave this to the caller // because we don't know which values to return from the execute op. 
if (resultTypes.empty() && !bodyBuilder) { - builder.create(result.location, ValueRange()); + async::YieldOp::create(builder, result.location, ValueRange()); } else if (bodyBuilder) { bodyBuilder(builder, result.location, bodyBlock->getArguments()); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 27fa92cee79c2..09367ab544ff8 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -191,8 +191,8 @@ static SmallVector delinearize(ImplicitLocOpBuilder &b, Value index, assert(!tripCounts.empty() && "tripCounts must be not empty"); for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) { - coords[i] = b.create(index, tripCounts[i]); - index = b.create(index, tripCounts[i]); + coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]); + index = arith::DivSIOp::create(b, index, tripCounts[i]); } return coords; @@ -276,15 +276,15 @@ static ParallelComputeFunction createParallelComputeFunction( BlockArgument blockSize = args.blockSize(); // Constants used below. - Value c0 = b.create(0); - Value c1 = b.create(1); + Value c0 = arith::ConstantIndexOp::create(b, 0); + Value c1 = arith::ConstantIndexOp::create(b, 1); // Materialize known constants as constant operation in the function body. auto values = [&](ArrayRef args, ArrayRef attrs) { return llvm::to_vector( llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { if (IntegerAttr attr = std::get<1>(tuple)) - return b.create(attr); + return arith::ConstantOp::create(b, attr); return std::get<0>(tuple); })); }; @@ -303,17 +303,17 @@ static ParallelComputeFunction createParallelComputeFunction( // one-dimensional iteration space. 
Value tripCount = tripCounts[0]; for (unsigned i = 1; i < tripCounts.size(); ++i) - tripCount = b.create(tripCount, tripCounts[i]); + tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]); // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]: // blockFirstIndex = blockIndex * blockSize - Value blockFirstIndex = b.create(blockIndex, blockSize); + Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize); // The last one-dimensional index in the block defined by the `blockIndex`: // blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1 - Value blockEnd0 = b.create(blockFirstIndex, blockSize); - Value blockEnd1 = b.create(blockEnd0, tripCount); - Value blockLastIndex = b.create(blockEnd1, c1); + Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize); + Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount); + Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1); // Convert one-dimensional indices to multi-dimensional coordinates. auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts); @@ -326,7 +326,7 @@ static ParallelComputeFunction createParallelComputeFunction( // dimension when inner compute dimension contains multiple blocks. SmallVector blockEndCoord(op.getNumLoops()); for (size_t i = 0; i < blockLastCoord.size(); ++i) - blockEndCoord[i] = b.create(blockLastCoord[i], c1); + blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1); // Construct a loop nest out of scf.for operations that will iterate over // all coordinates in [blockFirstCoord, blockLastCoord] range. @@ -369,21 +369,22 @@ static ParallelComputeFunction createParallelComputeFunction( ImplicitLocOpBuilder b(loc, nestedBuilder); // Compute induction variable for `loopIdx`. 
- computeBlockInductionVars[loopIdx] = b.create( - lowerBounds[loopIdx], b.create(iv, steps[loopIdx])); + computeBlockInductionVars[loopIdx] = + arith::AddIOp::create(b, lowerBounds[loopIdx], + arith::MulIOp::create(b, iv, steps[loopIdx])); // Check if we are inside first or last iteration of the loop. - isBlockFirstCoord[loopIdx] = b.create( - arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]); - isBlockLastCoord[loopIdx] = b.create( - arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]); + isBlockFirstCoord[loopIdx] = arith::CmpIOp::create( + b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]); + isBlockLastCoord[loopIdx] = arith::CmpIOp::create( + b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]); // Check if the previous loop is in its first or last iteration. if (loopIdx > 0) { - isBlockFirstCoord[loopIdx] = b.create( - isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]); - isBlockLastCoord[loopIdx] = b.create( - isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]); + isBlockFirstCoord[loopIdx] = arith::AndIOp::create( + b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]); + isBlockLastCoord[loopIdx] = arith::AndIOp::create( + b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]); } // Keep building loop nest. @@ -391,24 +392,24 @@ static ParallelComputeFunction createParallelComputeFunction( if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) { // For block aligned loops we always iterate starting from 0 up to // the loop trip counts. - b.create(c0, tripCounts[loopIdx + 1], c1, ValueRange(), - workLoopBuilder(loopIdx + 1)); + scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1, ValueRange(), + workLoopBuilder(loopIdx + 1)); } else { // Select nested loop lower/upper bounds depending on our position in // the multi-dimensional iteration space. 
- auto lb = b.create(isBlockFirstCoord[loopIdx], - blockFirstCoord[loopIdx + 1], c0); + auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx], + blockFirstCoord[loopIdx + 1], c0); - auto ub = b.create(isBlockLastCoord[loopIdx], - blockEndCoord[loopIdx + 1], - tripCounts[loopIdx + 1]); + auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx], + blockEndCoord[loopIdx + 1], + tripCounts[loopIdx + 1]); - b.create(lb, ub, c1, ValueRange(), - workLoopBuilder(loopIdx + 1)); + scf::ForOp::create(b, lb, ub, c1, ValueRange(), + workLoopBuilder(loopIdx + 1)); } - b.create(loc); + scf::YieldOp::create(b, loc); return; } @@ -419,13 +420,13 @@ static ParallelComputeFunction createParallelComputeFunction( for (auto &bodyOp : op.getRegion().front().without_terminator()) b.clone(bodyOp, mapping); - b.create(loc); + scf::YieldOp::create(b, loc); }; }; - b.create(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(), - workLoopBuilder(0)); - b.create(ValueRange()); + scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(), + workLoopBuilder(0)); + func::ReturnOp::create(b, ValueRange()); return {op.getNumLoops(), func, std::move(computeFuncType.captures)}; } @@ -485,8 +486,8 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, b.setInsertionPointToEnd(block); Type indexTy = b.getIndexType(); - Value c1 = b.create(1); - Value c2 = b.create(2); + Value c1 = arith::ConstantIndexOp::create(b, 1); + Value c2 = arith::ConstantIndexOp::create(b, 2); // Get the async group that will track async dispatch completion. Value group = block->getArgument(0); @@ -501,7 +502,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, SmallVector locations = {loc, loc}; // Create a recursive dispatch loop. 
- scf::WhileOp whileOp = b.create(types, operands); + scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands); Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations); Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations); @@ -511,10 +512,10 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, b.setInsertionPointToEnd(before); Value start = before->getArgument(0); Value end = before->getArgument(1); - Value distance = b.create(end, start); + Value distance = arith::SubIOp::create(b, end, start); Value dispatch = - b.create(arith::CmpIPredicate::sgt, distance, c1); - b.create(dispatch, before->getArguments()); + arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1); + scf::ConditionOp::create(b, dispatch, before->getArguments()); } // Setup the async dispatch loop body: recursively call dispatch function @@ -523,9 +524,9 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, b.setInsertionPointToEnd(after); Value start = after->getArgument(0); Value end = after->getArgument(1); - Value distance = b.create(end, start); - Value halfDistance = b.create(distance, c2); - Value midIndex = b.create(start, halfDistance); + Value distance = arith::SubIOp::create(b, end, start); + Value halfDistance = arith::DivSIOp::create(b, distance, c2); + Value midIndex = arith::AddIOp::create(b, start, halfDistance); // Call parallel compute function inside the async.execute region. 
auto executeBodyBuilder = [&](OpBuilder &executeBuilder, @@ -536,16 +537,16 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, operands[1] = midIndex; operands[2] = end; - executeBuilder.create(executeLoc, func.getSymName(), - func.getResultTypes(), operands); - executeBuilder.create(executeLoc, ValueRange()); + func::CallOp::create(executeBuilder, executeLoc, func.getSymName(), + func.getResultTypes(), operands); + async::YieldOp::create(executeBuilder, executeLoc, ValueRange()); }; // Create async.execute operation to dispatch half of the block range. - auto execute = b.create(TypeRange(), ValueRange(), ValueRange(), - executeBodyBuilder); - b.create(indexTy, execute.getToken(), group); - b.create(ValueRange({start, midIndex})); + auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(), + executeBodyBuilder); + AddToGroupOp::create(b, indexTy, execute.getToken(), group); + scf::YieldOp::create(b, ValueRange({start, midIndex})); } // After dispatching async operations to process the tail of the block range @@ -557,10 +558,9 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, SmallVector computeFuncOperands = {blockStart}; computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end()); - b.create(computeFunc.func.getSymName(), - computeFunc.func.getResultTypes(), - computeFuncOperands); - b.create(ValueRange()); + func::CallOp::create(b, computeFunc.func.getSymName(), + computeFunc.func.getResultTypes(), computeFuncOperands); + func::ReturnOp::create(b, ValueRange()); return func; } @@ -578,8 +578,8 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, func::FuncOp asyncDispatchFunction = createAsyncDispatchFunction(parallelComputeFunction, rewriter); - Value c0 = b.create(0); - Value c1 = b.create(1); + Value c0 = arith::ConstantIndexOp::create(b, 0); + Value c1 = arith::ConstantIndexOp::create(b, 1); // Appends operands shared by async dispatch and parallel compute 
functions to // the given operands vector. @@ -595,7 +595,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, // completely. If this will be known statically, then canonicalization will // erase async group operations. Value isSingleBlock = - b.create(arith::CmpIPredicate::eq, blockCount, c1); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1); auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) { ImplicitLocOpBuilder b(loc, nestedBuilder); @@ -604,10 +604,10 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, SmallVector operands = {c0, blockSize}; appendBlockComputeOperands(operands); - b.create(parallelComputeFunction.func.getSymName(), - parallelComputeFunction.func.getResultTypes(), - operands); - b.create(); + func::CallOp::create(b, parallelComputeFunction.func.getSymName(), + parallelComputeFunction.func.getResultTypes(), + operands); + scf::YieldOp::create(b); }; auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) { @@ -616,24 +616,24 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, // Create an async.group to wait on all async tokens from the concurrent // execution of multiple parallel compute function. First block will be // executed synchronously in the caller thread. - Value groupSize = b.create(blockCount, c1); - Value group = b.create(GroupType::get(ctx), groupSize); + Value groupSize = arith::SubIOp::create(b, blockCount, c1); + Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize); // Launch async dispatch function for [0, blockCount) range. 
SmallVector operands = {group, c0, blockCount, blockSize}; appendBlockComputeOperands(operands); - b.create(asyncDispatchFunction.getSymName(), - asyncDispatchFunction.getResultTypes(), operands); + func::CallOp::create(b, asyncDispatchFunction.getSymName(), + asyncDispatchFunction.getResultTypes(), operands); // Wait for the completion of all parallel compute operations. - b.create(group); + AwaitAllOp::create(b, group); - b.create(); + scf::YieldOp::create(b); }; // Dispatch either single block compute function, or launch async dispatch. - b.create(isSingleBlock, syncDispatch, asyncDispatch); + scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch); } // Dispatch parallel compute functions by submitting all async compute tasks @@ -647,14 +647,14 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, func::FuncOp compute = parallelComputeFunction.func; - Value c0 = b.create(0); - Value c1 = b.create(1); + Value c0 = arith::ConstantIndexOp::create(b, 0); + Value c1 = arith::ConstantIndexOp::create(b, 1); // Create an async.group to wait on all async tokens from the concurrent // execution of multiple parallel compute function. First block will be // executed synchronously in the caller thread. - Value groupSize = b.create(blockCount, c1); - Value group = b.create(GroupType::get(ctx), groupSize); + Value groupSize = arith::SubIOp::create(b, blockCount, c1); + Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize); // Call parallel compute function for all blocks. using LoopBodyBuilder = @@ -681,28 +681,27 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, // Call parallel compute function inside the async.execute region. 
auto executeBodyBuilder = [&](OpBuilder &executeBuilder, Location executeLoc, ValueRange executeArgs) { - executeBuilder.create(executeLoc, compute.getSymName(), - compute.getResultTypes(), - computeFuncOperands(iv)); - executeBuilder.create(executeLoc, ValueRange()); + func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(), + compute.getResultTypes(), computeFuncOperands(iv)); + async::YieldOp::create(executeBuilder, executeLoc, ValueRange()); }; // Create async.execute operation to launch parallel computate function. - auto execute = b.create(TypeRange(), ValueRange(), ValueRange(), - executeBodyBuilder); - b.create(rewriter.getIndexType(), execute.getToken(), group); - b.create(); + auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(), + executeBodyBuilder); + AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group); + scf::YieldOp::create(b); }; // Iterate over all compute blocks and launch parallel compute operations. - b.create(c1, blockCount, c1, ValueRange(), loopBuilder); + scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder); // Call parallel compute function for the first block in the caller thread. - b.create(compute.getSymName(), compute.getResultTypes(), - computeFuncOperands(c0)); + func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(), + computeFuncOperands(c0)); // Wait for the completion of all async compute operations. - b.create(group); + AwaitAllOp::create(b, group); } LogicalResult @@ -738,17 +737,17 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, // for the scf.parallel operation. Value tripCount = tripCounts[0]; for (size_t i = 1; i < tripCounts.size(); ++i) - tripCount = b.create(tripCount, tripCounts[i]); + tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]); // Short circuit no-op parallel loops (zero iterations) that can arise from // the memrefs with dynamic dimension(s) equal to zero. 
- Value c0 = b.create(0); + Value c0 = arith::ConstantIndexOp::create(b, 0); Value isZeroIterations = - b.create(arith::CmpIPredicate::eq, tripCount, c0); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0); // Do absolutely nothing if the trip count is zero. auto noOp = [&](OpBuilder &nestedBuilder, Location loc) { - nestedBuilder.create(loc); + scf::YieldOp::create(nestedBuilder, loc); }; // Compute the parallel block size and dispatch concurrent tasks computing @@ -798,9 +797,9 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, Value numWorkerThreadsVal; if (numWorkerThreads >= 0) - numWorkerThreadsVal = b.create(numWorkerThreads); + numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads); else - numWorkerThreadsVal = b.create(); + numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b); // With large number of threads the value of creating many compute blocks // is reduced because the problem typically becomes memory bound. For this @@ -819,38 +818,38 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}}; const float initialOvershardingFactor = 8.0f; - Value scalingFactor = b.create( - b.getF32Type(), llvm::APFloat(initialOvershardingFactor)); + Value scalingFactor = arith::ConstantFloatOp::create( + b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor)); for (const std::pair &p : overshardingBrackets) { - Value bracketBegin = b.create(p.first); - Value inBracket = b.create( - arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin); - Value bracketScalingFactor = b.create( - b.getF32Type(), llvm::APFloat(p.second)); - scalingFactor = b.create(inBracket, bracketScalingFactor, - scalingFactor); + Value bracketBegin = arith::ConstantIndexOp::create(b, p.first); + Value inBracket = arith::CmpIOp::create( + b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin); + Value bracketScalingFactor = 
arith::ConstantFloatOp::create( + b, b.getF32Type(), llvm::APFloat(p.second)); + scalingFactor = arith::SelectOp::create( + b, inBracket, bracketScalingFactor, scalingFactor); } Value numWorkersIndex = - b.create(b.getI32Type(), numWorkerThreadsVal); + arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal); Value numWorkersFloat = - b.create(b.getF32Type(), numWorkersIndex); + arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex); Value scaledNumWorkers = - b.create(scalingFactor, numWorkersFloat); + arith::MulFOp::create(b, scalingFactor, numWorkersFloat); Value scaledNumInt = - b.create(b.getI32Type(), scaledNumWorkers); + arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers); Value scaledWorkers = - b.create(b.getIndexType(), scaledNumInt); + arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt); - Value maxComputeBlocks = b.create( - b.create(1), scaledWorkers); + Value maxComputeBlocks = arith::MaxSIOp::create( + b, arith::ConstantIndexOp::create(b, 1), scaledWorkers); // Compute parallel block size from the parallel problem size: // blockSize = min(tripCount, // max(ceil_div(tripCount, maxComputeBlocks), // minTaskSize)) - Value bs0 = b.create(tripCount, maxComputeBlocks); - Value bs1 = b.create(bs0, minTaskSize); - Value blockSize = b.create(tripCount, bs1); + Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks); + Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize); + Value blockSize = arith::MinSIOp::create(b, tripCount, bs1); // Dispatch parallel compute function using async recursive work splitting, // or by submitting compute task sequentially from a caller thread. @@ -860,7 +859,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, // the parallel operation body for a subset of iteration space. // Compute the number of parallel compute blocks. 
- Value blockCount = b.create(tripCount, blockSize); + Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize); // Dispatch parallel compute function without hints to unroll inner loops. auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) { @@ -869,7 +868,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, ImplicitLocOpBuilder b(loc, nestedBuilder); doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts); - b.create(); + scf::YieldOp::create(b); }; // Dispatch parallel compute function with hints for unrolling inner loops. @@ -880,34 +879,34 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, ImplicitLocOpBuilder b(loc, nestedBuilder); // Align the block size to be a multiple of the statically known // number of iterations in the inner loops. - Value numIters = b.create( - numIterations[op.getNumLoops() - numUnrollableLoops]); - Value alignedBlockSize = b.create( - b.create(blockSize, numIters), numIters); + Value numIters = arith::ConstantIndexOp::create( + b, numIterations[op.getNumLoops() - numUnrollableLoops]); + Value alignedBlockSize = arith::MulIOp::create( + b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters); doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount, tripCounts); - b.create(); + scf::YieldOp::create(b); }; // Dispatch to block aligned compute function only if the computed block // size is larger than the number of iterations in the unrollable inner // loops, because otherwise it can reduce the available parallelism. 
if (numUnrollableLoops > 0) { - Value numIters = b.create( - numIterations[op.getNumLoops() - numUnrollableLoops]); - Value useBlockAlignedComputeFn = b.create( - arith::CmpIPredicate::sge, blockSize, numIters); - - b.create(useBlockAlignedComputeFn, dispatchBlockAligned, - dispatchDefault); - b.create(); + Value numIters = arith::ConstantIndexOp::create( + b, numIterations[op.getNumLoops() - numUnrollableLoops]); + Value useBlockAlignedComputeFn = arith::CmpIOp::create( + b, arith::CmpIPredicate::sge, blockSize, numIters); + + scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned, + dispatchDefault); + scf::YieldOp::create(b); } else { dispatchDefault(b, loc); } }; // Replace the `scf.parallel` operation with the parallel compute function. - b.create(isZeroIterations, noOp, dispatch); + scf::IfOp::create(b, isZeroIterations, noOp, dispatch); // Parallel operation was replaced with a block iteration loop. rewriter.eraseOp(op); @@ -922,7 +921,7 @@ void AsyncParallelForPass::runOnOperation() { populateAsyncParallelForPatterns( patterns, asyncDispatch, numWorkerThreads, [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { - return builder.create(minTaskSize); + return arith::ConstantIndexOp::create(builder, minTaskSize); }); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp index d11cd8444636a..6bcc5849eff1b 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -51,7 +51,7 @@ static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { else b.setInsertionPointToStart(value.getParentBlock()); - b.create(value.getLoc(), value, b.getI64IntegerAttr(1)); + RuntimeDropRefOp::create(b, value.getLoc(), value, b.getI64IntegerAttr(1)); return success(); } @@ -312,7 +312,7 @@ 
LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { // Add a drop_ref immediately after the last user. builder.setInsertionPointAfter(lastUser); - builder.create(loc, value, builder.getI64IntegerAttr(1)); + RuntimeDropRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1)); } return success(); @@ -330,7 +330,7 @@ AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { // Add a reference before the function call to pass the value at `+1` // reference to the function entry block. builder.setInsertionPoint(user); - builder.create(loc, value, builder.getI64IntegerAttr(1)); + RuntimeAddRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1)); } return success(); @@ -414,11 +414,11 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( refCountingBlock = &successor->getParent()->emplaceBlock(); refCountingBlock->moveBefore(successor); OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); - builder.create(value.getLoc(), successor); + cf::BranchOp::create(builder, value.getLoc(), successor); } OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); - builder.create(value.getLoc(), value, + RuntimeDropRefOp::create(builder, value.getLoc(), value, builder.getI64IntegerAttr(1)); // No need to update the terminator operation. @@ -510,13 +510,13 @@ AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { // Create `add_ref` operation before the operand owner. if (cnt > 0) { b.setInsertionPoint(operand.getOwner()); - b.create(loc, value, b.getI64IntegerAttr(cnt)); + RuntimeAddRefOp::create(b, loc, value, b.getI64IntegerAttr(cnt)); } // Create `drop_ref` operation after the operand owner. 
if (cnt < 0) { b.setInsertionPointAfter(operand.getOwner()); - b.create(loc, value, b.getI64IntegerAttr(-cnt)); + RuntimeDropRefOp::create(b, loc, value, b.getI64IntegerAttr(-cnt)); } } } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp index 8601bb5aaada9..8b9feb8214923 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -188,22 +188,22 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { std::optional retToken; if (isStateful) - retToken.emplace(builder.create(TokenType::get(ctx))); + retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx))); llvm::SmallVector retValues; ArrayRef resValueTypes = isStateful ? func.getResultTypes().drop_front() : func.getResultTypes(); for (auto resType : resValueTypes) retValues.emplace_back( - builder.create(resType).getResult()); + RuntimeCreateOp::create(builder, resType).getResult()); // ------------------------------------------------------------------------ // // Initialize coroutine: get coroutine id and coroutine handle. 
// ------------------------------------------------------------------------ // - auto coroIdOp = builder.create(CoroIdType::get(ctx)); + auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx)); auto coroHdlOp = - builder.create(CoroHandleType::get(ctx), coroIdOp.getId()); - builder.create(originalEntryBlock); + CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId()); + cf::BranchOp::create(builder, originalEntryBlock); Block *cleanupBlock = func.addBlock(); Block *cleanupBlockForDestroy = func.addBlock(); @@ -214,10 +214,10 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { // ------------------------------------------------------------------------ // auto buildCleanupBlock = [&](Block *cb) { builder.setInsertionPointToStart(cb); - builder.create(coroIdOp.getId(), coroHdlOp.getHandle()); + CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle()); // Branch into the suspend block. - builder.create(suspendBlock); + cf::BranchOp::create(builder, suspendBlock); }; buildCleanupBlock(cleanupBlock); buildCleanupBlock(cleanupBlockForDestroy); @@ -229,7 +229,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { builder.setInsertionPointToStart(suspendBlock); // Mark the end of a coroutine: async.coro.end - builder.create(coroHdlOp.getHandle()); + CoroEndOp::create(builder, coroHdlOp.getHandle()); // Return created optional `async.token` and `async.values` from the suspend // block. This will be the return value of a coroutine ramp function. @@ -237,7 +237,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { if (retToken) ret.push_back(*retToken); llvm::append_range(ret, retValues); - builder.create(ret); + func::ReturnOp::create(builder, ret); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. 
@@ -274,13 +274,13 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) { // Coroutine set_error block: set error on token and all returned values. if (coro.asyncToken) - builder.create(*coro.asyncToken); + RuntimeSetErrorOp::create(builder, *coro.asyncToken); for (Value retValue : coro.returnValues) - builder.create(retValue); + RuntimeSetErrorOp::create(builder, retValue); // Branch into the cleanup block. - builder.create(coro.cleanup); + cf::BranchOp::create(builder, coro.cleanup); return *coro.setError; } @@ -335,13 +335,13 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { // Await on all dependencies before starting to execute the body region. for (size_t i = 0; i < numDependencies; ++i) - builder.create(func.getArgument(i)); + AwaitOp::create(builder, func.getArgument(i)); // Await on all async value operands and unwrap the payload. SmallVector unwrappedOperands(numOperands); for (size_t i = 0; i < numOperands; ++i) { Value operand = func.getArgument(numDependencies + i); - unwrappedOperands[i] = builder.create(loc, operand).getResult(); + unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult(); } // Map from function inputs defined above the execute op to the function @@ -368,14 +368,14 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { // Save the coroutine state: async.coro.save auto coroSaveOp = - builder.create(CoroStateType::get(ctx), coro.coroHandle); + CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle); // Pass coroutine to the runtime to be resumed on a runtime managed // thread. - builder.create(coro.coroHandle); + RuntimeResumeOp::create(builder, coro.coroHandle); // Add async.coro.suspend as a suspended block terminator. 
- builder.create(coroSaveOp.getState(), coro.suspend, + CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend, branch.getDest(), coro.cleanupForDestroy); branch.erase(); @@ -384,7 +384,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { // Replace the original `async.execute` with a call to outlined function. { ImplicitLocOpBuilder callBuilder(loc, execute); - auto callOutlinedFunc = callBuilder.create( + auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); execute.replaceAllUsesWith(callOutlinedFunc.getResults()); execute.erase(); @@ -453,7 +453,7 @@ class AsyncFuncOpLowering : public OpConversionPattern { Location loc = op->getLoc(); auto newFuncOp = - rewriter.create(loc, op.getName(), op.getFunctionType()); + func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType()); SymbolTable::setSymbolVisibility(newFuncOp, SymbolTable::getSymbolVisibility(op)); @@ -523,16 +523,16 @@ class AsyncReturnOpLowering : public OpConversionPattern { for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value returnValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); - rewriter.create(loc, returnValue, asyncValue); - rewriter.create(loc, asyncValue); + RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue); + RuntimeSetAvailableOp::create(rewriter, loc, asyncValue); } if (coro.asyncToken) // Switch the coroutine completion token to available state. - rewriter.create(loc, *coro.asyncToken); + RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken); rewriter.eraseOp(op); - rewriter.create(loc, coro.cleanup); + cf::BranchOp::create(rewriter, loc, coro.cleanup); return success(); } @@ -583,15 +583,15 @@ class AwaitOpLoweringBase : public OpConversionPattern { // the async object (token, value or group) to become available. 
if (!isInCoroutine) { ImplicitLocOpBuilder builder(loc, rewriter); - builder.create(loc, operand); + RuntimeAwaitOp::create(builder, loc, operand); // Assert that the awaited operands is not in the error state. - Value isError = builder.create(i1, operand); - Value notError = builder.create( - isError, builder.create( + Value isError = RuntimeIsErrorOp::create(builder, i1, operand); + Value notError = arith::XOrIOp::create(builder, + isError, arith::ConstantOp::create(builder, loc, i1, builder.getIntegerAttr(i1, 1))); - builder.create(notError, + cf::AssertOp::create(builder, notError, "Awaited async operand is in error state"); } @@ -607,15 +607,15 @@ class AwaitOpLoweringBase : public OpConversionPattern { // Save the coroutine state and resume on a runtime managed thread when // the operand becomes available. auto coroSaveOp = - builder.create(CoroStateType::get(ctx), coro.coroHandle); - builder.create(operand, coro.coroHandle); + CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle); + RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle); // Split the entry block before the await operation. Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); // Add async.coro.suspend as a suspended block terminator. builder.setInsertionPointToEnd(suspended); - builder.create(coroSaveOp.getState(), coro.suspend, resume, + CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend, resume, coro.cleanupForDestroy); // Split the resume block into error checking and continuation. @@ -623,8 +623,8 @@ class AwaitOpLoweringBase : public OpConversionPattern { // Check if the awaited value is in the error state. 
builder.setInsertionPointToStart(resume); - auto isError = builder.create(loc, i1, operand); - builder.create(isError, + auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand); + cf::CondBranchOp::create(builder, isError, /*trueDest=*/setupSetErrorBlock(coro), /*trueArgs=*/ArrayRef(), /*falseDest=*/continuation, @@ -674,7 +674,7 @@ class AwaitValueOpLowering : public AwaitOpLoweringBase { ConversionPatternRewriter &rewriter) const override { // Load from the async value storage. auto valueType = cast(operand.getType()).getValueType(); - return rewriter.create(op->getLoc(), valueType, operand); + return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand); } }; @@ -715,15 +715,15 @@ class YieldOpLowering : public OpConversionPattern { for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value yieldValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); - rewriter.create(loc, yieldValue, asyncValue); - rewriter.create(loc, asyncValue); + RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue); + RuntimeSetAvailableOp::create(rewriter, loc, asyncValue); } if (coro.asyncToken) // Switch the coroutine completion token to available state. 
- rewriter.create(loc, *coro.asyncToken); + RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken); - rewriter.create(loc, coro.cleanup); + cf::BranchOp::create(rewriter, loc, coro.cleanup); rewriter.eraseOp(op); return success(); @@ -757,7 +757,7 @@ class AssertOpLowering : public OpConversionPattern { Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create(loc, adaptor.getArg(), + cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), /*trueDest=*/cont, /*trueArgs=*/ArrayRef(), /*falseDest=*/setupSetErrorBlock(coro), diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp index ca914df8b7890..59ebdbfdc727d 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp @@ -36,7 +36,7 @@ using namespace bufferization; //===----------------------------------------------------------------------===// static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { - return builder.create(loc, builder.getBoolAttr(value)); + return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value)); } static bool isMemref(Value v) { return isa(v.getType()); } @@ -151,7 +151,7 @@ DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, // ownerships more intelligently to not end up with an 'Unknown' ownership in // the first place. 
auto cloneOp = - builder.create(memref.getLoc(), memref); + bufferization::CloneOp::create(builder, memref.getLoc(), memref); Value condition = buildBoolValue(builder, memref.getLoc(), true); Value newMemref = cloneOp.getResult(); updateOwnership(newMemref, condition); @@ -197,7 +197,7 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such // that we can call extract_strided_metadata on it. if (auto unrankedMemRefTy = dyn_cast(memref.getType())) - memref = builder.create( + memref = memref::ReinterpretCastOp::create(builder, loc, memref, /*offset=*/builder.getIndexAttr(0), /*sizes=*/ArrayRef{}, @@ -208,7 +208,7 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( // alloc operation has to be passed to the dealloc operation. Passing // subviews, etc. to a dealloc operation is not allowed. memrefs.push_back( - builder.create(loc, memref) + memref::ExtractStridedMetadataOp::create(builder, loc, memref) .getResult(0)); conditions.push_back(ownership.getIndicator()); } @@ -297,7 +297,7 @@ FailureOr deallocation_impl::insertDeallocOpForReturnLike( if (memrefs.empty() && toRetain.empty()) return op; - auto deallocOp = builder.create( + auto deallocOp = bufferization::DeallocOp::create(builder, op->getLoc(), memrefs, conditions, toRetain); // We want to replace the current ownership of the retained values with the diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 8f17a82fabe03..05a1367a6b126 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -170,7 +170,7 @@ FailureOr bufferization::allocateTensorForShapedValue( if (llvm::isa(shapedValue.getType())) { tensor = shapedValue; } else if (llvm::isa(shapedValue.getType())) { - tensor = b.create( + tensor = ToTensorOp::create(b, loc, 
memref::getTensorTypeFromMemRefType(shapedValue.getType()), shapedValue); } else if (llvm::isa(shapedValue.getType()) || @@ -209,7 +209,7 @@ FailureOr bufferization::allocateTensorForShapedValue( } // Create AllocTensorOp. - auto allocTensorOp = b.create(loc, tensorType, dynamicSizes, + auto allocTensorOp = AllocTensorOp::create(b, loc, tensorType, dynamicSizes, copy ? tensor : Value()); // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. @@ -753,7 +753,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually // loose all of its users and eventually DCE away. rewriter.setInsertionPointAfter(op); - replacement = rewriter.create( + replacement = bufferization::ToTensorOp::create(rewriter, replacement.getLoc(), opResult.getType(), replacement); } replacements.push_back(replacement); @@ -779,7 +779,7 @@ FailureOr BufferizationOptions::createAlloc(OpBuilder &b, Location loc, .create(loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)) .getResult(); - return b.create(loc, type, dynShape).getResult(); + return memref::AllocOp::create(b, loc, type, dynShape).getResult(); } /// Create a memory copy between two memref buffers. 
@@ -788,7 +788,7 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, if (memCpyFn) return (*memCpyFn)(b, loc, from, to); - b.create(loc, from, to); + memref::CopyOp::create(b, loc, from, to); return success(); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 5c1d42db18c47..b4910d45e1ae6 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include @@ -59,7 +60,7 @@ FailureOr mlir::bufferization::castOrReallocMemRefValue( // a fix extra conditions in `isGuaranteedCastCompatible`. if (memref::CastOp::areCastCompatible(srcType, destType) && isGuaranteedCastCompatible(srcType, destType)) { - Value casted = b.create(value.getLoc(), destType, value); + Value casted = memref::CastOp::create(b, value.getLoc(), destType, value); return casted; } @@ -68,7 +69,7 @@ FailureOr mlir::bufferization::castOrReallocMemRefValue( for (int i = 0; i < destType.getRank(); ++i) { if (destType.getShape()[i] != ShapedType::kDynamic) continue; - Value size = b.create(loc, value, i); + Value size = memref::DimOp::create(b, loc, value, i); dynamicOperands.push_back(size); } @@ -135,10 +136,10 @@ void mlir::bufferization::populateDynamicDimSizes( for (int64_t i = 0; i < shapedType.getRank(); ++i) { if (shapedType.isDynamicDim(i)) { if (llvm::isa(shapedType)) { - dynamicDims.push_back(b.create(loc, shapedValue, i)); + dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i)); } else { assert(llvm::isa(shapedType) && "expected tensor"); - dynamicDims.push_back(b.create(loc, shapedValue, i)); + dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i)); } } } @@ 
-322,8 +323,8 @@ struct ReplaceStaticShapeDims : OpRewritePattern { newShape, op.getType().getElementType(), op.getType().getEncoding()); if (newType == op.getType()) return failure(); - auto newOp = rewriter.create( - op.getLoc(), newType, newDynamicSizes, /*copy=*/Value()); + auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType, + newDynamicSizes, /*copy=*/Value()); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); return success(); } @@ -428,7 +429,7 @@ void AllocTensorOp::print(OpAsmPrinter &p) { Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) { assert(isDynamicDim(idx) && "expected dynamic dim"); if (getCopy()) - return b.create(getLoc(), getCopy(), idx); + return tensor::DimOp::create(b, getLoc(), getCopy(), idx); return getOperand(getIndexOfDynamicSize(idx)); } @@ -514,8 +515,8 @@ struct SimplifyClones : public OpRewritePattern { } if (source.getType() != cloneOp.getType()) - source = rewriter.create(cloneOp.getLoc(), - cloneOp.getType(), source); + source = memref::CastOp::create(rewriter, cloneOp.getLoc(), + cloneOp.getType(), source); rewriter.replaceOp(cloneOp, source); rewriter.eraseOp(redundantDealloc); return success(); @@ -539,7 +540,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, FailureOr buffer = getBuffer(rewriter, getTensor(), options, state); if (failed(buffer)) return failure(); - rewriter.create(getLoc(), *buffer); + memref::DeallocOp::create(rewriter, getLoc(), *buffer); rewriter.eraseOp(getOperation()); return success(); } @@ -644,8 +645,9 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, assert(getRestrict() && "expected that ops with memrefs dest have 'restrict'"); setRestrict(false); - return builder.create( - loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(), + return ToTensorOp::create( + builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()), + getDest(), /*restrict=*/true, getWritable()); } @@ -807,8 +809,8 @@ 
struct ToBufferOfCast : public OpRewritePattern { return failure(); auto memrefType = MemRefType::get(srcTensorType.getShape(), srcTensorType.getElementType()); - Value memref = rewriter.create(toBuffer.getLoc(), memrefType, - tensorCastOperand.getOperand()); + Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType, + tensorCastOperand.getOperand()); rewriter.replaceOpWithNewOp(toBuffer, toBuffer.getType(), memref); return success(); @@ -881,12 +883,12 @@ LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter, std::optional CloneOp::buildDealloc(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc) + return memref::DeallocOp::create(builder, alloc.getLoc(), alloc) .getOperation(); } std::optional CloneOp::buildClone(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc).getResult(); + return CloneOp::create(builder, alloc.getLoc(), alloc).getResult(); } //===----------------------------------------------------------------------===// @@ -960,7 +962,7 @@ struct DeallocRemoveDuplicateDeallocMemrefs Value &newCond = newConditions[memrefToCondition[memref]]; if (newCond != cond) newCond = - rewriter.create(deallocOp.getLoc(), newCond, cond); + arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond); } else { memrefToCondition.insert({memref, newConditions.size()}); newMemrefs.push_back(memref); @@ -1015,8 +1017,8 @@ struct DeallocRemoveDuplicateRetainedMemrefs // We need to create a new op because the number of results is always the // same as the number of condition operands. 
auto newDeallocOp = - rewriter.create(deallocOp.getLoc(), deallocOp.getMemrefs(), - deallocOp.getConditions(), newRetained); + DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(), + deallocOp.getConditions(), newRetained); SmallVector replacements( llvm::map_range(resultReplacementIdx, [&](unsigned idx) { return newDeallocOp.getUpdatedConditions()[idx]; @@ -1037,8 +1039,8 @@ struct EraseEmptyDealloc : public OpRewritePattern { LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { if (deallocOp.getMemrefs().empty()) { - Value constFalse = rewriter.create( - deallocOp.getLoc(), rewriter.getBoolAttr(false)); + Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(), + rewriter.getBoolAttr(false)); rewriter.replaceOp( deallocOp, SmallVector(deallocOp.getUpdatedConditions().size(), constFalse)); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index db1eb20512033..7f495b0ac164c 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -70,12 +70,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, *getFunctionBoundaryTypeConversion()); if (getMemcpyOp() == "memref.copy") { options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { - b.create(loc, from, to); + memref::CopyOp::create(b, loc, from, to); return success(); }; } else if (getMemcpyOp() == "linalg.copy") { options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { - b.create(loc, from, to); + linalg::CopyOp::create(b, loc, from, to); return success(); }; } else { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 
c5fab80ecaa08..835ad61f54da2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -167,7 +167,7 @@ struct RemoveDeallocMemrefsContainedInRetained std::optional analysisResult = analysis.isSameAllocation(retained, memref); if (analysisResult == true) { - auto disjunction = rewriter.create( + auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(), updatedCondition, cond); rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(), disjunction); @@ -247,14 +247,14 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias continue; } - replacements.push_back(rewriter.create( + replacements.push_back(arith::ConstantOp::create(rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false))); } if (newRetainedMemrefs.size() == deallocOp.getRetained().size()) return failure(); - auto newDeallocOp = rewriter.create( + auto newDeallocOp = DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(), newRetainedMemrefs); int i = 0; @@ -326,7 +326,7 @@ struct SplitDeallocWhenNotAliasingAnyOther } // Create new bufferization.dealloc op for `memref`. - auto newDeallocOp = rewriter.create(loc, memref, cond, + auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond, deallocOp.getRetained()); updatedConditions.push_back( llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()))); @@ -337,7 +337,7 @@ struct SplitDeallocWhenNotAliasingAnyOther return failure(); // Create bufferization.dealloc op for all remaining memrefs. - auto newDeallocOp = rewriter.create( + auto newDeallocOp = DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions, deallocOp.getRetained()); // Bit-or all conditions. 
@@ -347,7 +347,7 @@ struct SplitDeallocWhenNotAliasingAnyOther assert(replacements.size() == additionalConditions.size() && "expected same number of updated conditions"); for (int64_t i = 0, e = replacements.size(); i < e; ++i) { - replacements[i] = rewriter.create( + replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i], additionalConditions[i]); } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 15e03fbefe9c5..9f11f23a5dfeb 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -133,7 +133,7 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, return WalkResult::interrupt(); } } - builder.create(op.getLoc(), keepAsReturnOperands); + func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands); op.erase(); return WalkResult::advance(); }); @@ -191,7 +191,7 @@ updateCalls(ModuleOp module, assert(hasFullyDynamicLayoutMap(memrefType) && "layout map not supported"); outParam = - builder.create(op.getLoc(), memrefType, outParam); + memref::CastOp::create(builder, op.getLoc(), memrefType, outParam); } memref.replaceAllUsesWith(outParam); outParams.push_back(outParam); @@ -201,7 +201,7 @@ updateCalls(ModuleOp module, newOperands.append(outParams.begin(), outParams.end()); auto newResultTypes = llvm::to_vector<6>(llvm::map_range( replaceWithNewCallResults, [](Value v) { return v.getType(); })); - auto newCall = builder.create(op.getLoc(), op.getCalleeAttr(), + auto newCall = func::CallOp::create(builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands); for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp 
index ff2c83d228dbb..1e02d9bfa7ef6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -146,7 +146,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, cast(getMemRefTypeWithStaticIdentityLayout(type)); if (memorySpace) memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); - auto global = globalBuilder.create( + auto global = memref::GlobalOp::create(globalBuilder, constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/memrefType, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 6472ef3eff2ac..c7343c62e2cc8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -436,7 +436,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, // Replace all uses of the original tensor bbArg. rewriter.setInsertionPointToStart(block); if (!bbArgUses.empty()) { - Value toTensorOp = rewriter.create( + Value toTensorOp = bufferization::ToTensorOp::create(rewriter, bbArg.getLoc(), tensorType, bbArg); for (OpOperand *use : bbArgUses) use->set(toTensorOp); @@ -468,12 +468,12 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, if (failed(operandBufferType)) return failure(); rewriter.setInsertionPointAfterValue(operand); - Value bufferizedOperand = rewriter.create( + Value bufferizedOperand = bufferization::ToBufferOp::create(rewriter, operand.getLoc(), *operandBufferType, operand); // A cast is needed if the operand and the block argument have different // bufferized types. 
if (type != *operandBufferType) - bufferizedOperand = rewriter.create( + bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(), type, bufferizedOperand); newOperands.push_back(bufferizedOperand); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index ae011904cb972..9014b8b8ecdf0 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -120,7 +120,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { // Update function calls. for (func::CallOp callOp : callerMap[funcOp]) { rewriter.setInsertionPoint(callOp); - auto newCallOp = rewriter.create(callOp.getLoc(), funcOp, + auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp, callOp.getOperands()); SmallVector newResults; int64_t nextResult = 0; @@ -136,7 +136,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { Type expectedType = callOp.getResult(i).getType(); if (replacement.getType() != expectedType) { // A cast must be inserted at the call site. 
- replacement = rewriter.create( + replacement = memref::CastOp::create(rewriter, callOp.getLoc(), expectedType, replacement); } newResults.push_back(replacement); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 6f27563a45548..202b9eb0c69dc 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -169,7 +169,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( cast(v.getType()).getElementType()) continue; rewriter.setInsertionPointAfterValue(replacement); - replacement = rewriter.create(v.getLoc(), v.getType(), + replacement = tensor::CastOp::create(rewriter, v.getLoc(), v.getType(), replacement); } // Replace the specific use of the tensor::EmptyOp. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index bd2aebca68079..4d25ca87da888 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -320,7 +320,7 @@ struct CallOpInterface } // 3. Create the new CallOp. - Operation *newCallOp = rewriter.create( + Operation *newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands); newCallOp->setAttrs(callOp->getAttrs()); @@ -484,7 +484,7 @@ struct FuncOpInterface // Note: If `inferFunctionResultLayout = true`, casts are later folded // away. 
- Value toBufferOp = rewriter.create( + Value toBufferOp = bufferization::ToBufferOp::create(rewriter, returnOp.getLoc(), bufferizedType, returnVal); returnValues.push_back(toBufferOp); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp index 2a17ae4f6a249..304c073a26c73 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp @@ -65,8 +65,8 @@ class DeallocOpConversion rewriter.replaceOpWithNewOp( op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) { - builder.create(loc, adaptor.getMemrefs()[0]); - builder.create(loc); + memref::DeallocOp::create(builder, loc, adaptor.getMemrefs()[0]); + scf::YieldOp::create(builder, loc); }); return success(); } @@ -109,13 +109,13 @@ class DeallocOpConversion // Compute the base pointer indices, compare all retained indices to the // memref index to check if they alias. SmallVector doesNotAliasList; - Value memrefAsIdx = rewriter.create( + Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, op->getLoc(), adaptor.getMemrefs()[0]); for (Value retained : adaptor.getRetained()) { Value retainedAsIdx = - rewriter.create(op->getLoc(), + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, op->getLoc(), retained); - Value doesNotAlias = rewriter.create( + Value doesNotAlias = arith::CmpIOp::create(rewriter, op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx); doesNotAliasList.push_back(doesNotAlias); } @@ -123,17 +123,17 @@ class DeallocOpConversion // AND-reduce the list of booleans from above. 
Value prev = doesNotAliasList.front(); for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front()) - prev = rewriter.create(op->getLoc(), prev, doesNotAlias); + prev = arith::AndIOp::create(rewriter, op->getLoc(), prev, doesNotAlias); // Also consider the condition given by the dealloc operation and perform a // conditional deallocation guarded by that value. - Value shouldDealloc = rewriter.create( + Value shouldDealloc = arith::AndIOp::create(rewriter, op->getLoc(), prev, adaptor.getConditions()[0]); - rewriter.create( + scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { - builder.create(loc, adaptor.getMemrefs()[0]); - builder.create(loc); + memref::DeallocOp::create(builder, loc, adaptor.getMemrefs()[0]); + scf::YieldOp::create(builder, loc); }); // Compute the replacement values for the dealloc operation results. This @@ -141,12 +141,12 @@ class DeallocOpConversion // `select(does_alias_with_memref(r), memref_cond, false)` for each retained // value r. SmallVector replacements; - Value trueVal = rewriter.create( + Value trueVal = arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getBoolAttr(true)); for (Value doesNotAlias : doesNotAliasList) { Value aliases = - rewriter.create(op->getLoc(), doesNotAlias, trueVal); - Value result = rewriter.create(op->getLoc(), aliases, + arith::XOrIOp::create(rewriter, op->getLoc(), doesNotAlias, trueVal); + Value result = arith::AndIOp::create(rewriter, op->getLoc(), aliases, adaptor.getConditions()[0]); replacements.push_back(result); } @@ -231,18 +231,18 @@ class DeallocOpConversion // Without storing them to memrefs, we could not use for-loops but only a // completely unrolled version of it, potentially leading to code-size // blow-up. 
- Value toDeallocMemref = rewriter.create( + Value toDeallocMemref = memref::AllocOp::create(rewriter, op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, rewriter.getIndexType())); - Value conditionMemref = rewriter.create( + Value conditionMemref = memref::AllocOp::create(rewriter, op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()}, rewriter.getI1Type())); - Value toRetainMemref = rewriter.create( + Value toRetainMemref = memref::AllocOp::create(rewriter, op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, rewriter.getIndexType())); auto getConstValue = [&](uint64_t value) -> Value { - return rewriter.create(op.getLoc(), + return arith::ConstantOp::create(rewriter, op.getLoc(), rewriter.getIndexAttr(value)); }; @@ -250,58 +250,58 @@ class DeallocOpConversion // at runtime. for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) { Value memrefAsIdx = - rewriter.create(op.getLoc(), + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, op.getLoc(), toDealloc); - rewriter.create(op.getLoc(), memrefAsIdx, + memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx, toDeallocMemref, getConstValue(i)); } for (auto [i, cond] : llvm::enumerate(adaptor.getConditions())) - rewriter.create(op.getLoc(), cond, conditionMemref, + memref::StoreOp::create(rewriter, op.getLoc(), cond, conditionMemref, getConstValue(i)); for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) { Value memrefAsIdx = - rewriter.create(op.getLoc(), + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, op.getLoc(), toRetain); - rewriter.create(op.getLoc(), memrefAsIdx, toRetainMemref, + memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx, toRetainMemref, getConstValue(i)); } // Cast the allocated memrefs to dynamic shape because we want only one // helper function no matter how many operands the bufferization.dealloc // has. 
- Value castedDeallocMemref = rewriter.create( + Value castedDeallocMemref = memref::CastOp::create(rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), toDeallocMemref); - Value castedCondsMemref = rewriter.create( + Value castedCondsMemref = memref::CastOp::create(rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), conditionMemref); - Value castedRetainMemref = rewriter.create( + Value castedRetainMemref = memref::CastOp::create(rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), toRetainMemref); - Value deallocCondsMemref = rewriter.create( + Value deallocCondsMemref = memref::AllocOp::create(rewriter, op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, rewriter.getI1Type())); - Value retainCondsMemref = rewriter.create( + Value retainCondsMemref = memref::AllocOp::create(rewriter, op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, rewriter.getI1Type())); - Value castedDeallocCondsMemref = rewriter.create( + Value castedDeallocCondsMemref = memref::CastOp::create(rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), deallocCondsMemref); - Value castedRetainCondsMemref = rewriter.create( + Value castedRetainCondsMemref = memref::CastOp::create(rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), retainCondsMemref); Operation *symtableOp = op->getParentWithTrait(); - rewriter.create( + func::CallOp::create(rewriter, op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), SmallVector{castedDeallocMemref, castedRetainMemref, castedCondsMemref, castedDeallocCondsMemref, @@ -309,30 +309,30 @@ class DeallocOpConversion for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) { Value idxValue = getConstValue(i); - Value shouldDealloc = rewriter.create( + Value shouldDealloc = memref::LoadOp::create(rewriter, op.getLoc(), deallocCondsMemref, idxValue); - 
rewriter.create( + scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { - builder.create(loc, adaptor.getMemrefs()[i]); - builder.create(loc); + memref::DeallocOp::create(builder, loc, adaptor.getMemrefs()[i]); + scf::YieldOp::create(builder, loc); }); } SmallVector replacements; for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) { Value idxValue = getConstValue(i); - Value ownership = rewriter.create( + Value ownership = memref::LoadOp::create(rewriter, op.getLoc(), retainCondsMemref, idxValue); replacements.push_back(ownership); } // Deallocate above allocated memrefs again to avoid memory leaks. // Deallocation will not be run on code after this stage. - rewriter.create(op.getLoc(), toDeallocMemref); - rewriter.create(op.getLoc(), toRetainMemref); - rewriter.create(op.getLoc(), conditionMemref); - rewriter.create(op.getLoc(), deallocCondsMemref); - rewriter.create(op.getLoc(), retainCondsMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), toDeallocMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), toRetainMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), conditionMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), deallocCondsMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), retainCondsMemref); rewriter.replaceOp(op, replacements); return success(); @@ -350,7 +350,7 @@ class DeallocOpConversion ConversionPatternRewriter &rewriter) const override { // Lower the trivial case. if (adaptor.getMemrefs().empty()) { - Value falseVal = rewriter.create( + Value falseVal = arith::ConstantOp::create(rewriter, op.getLoc(), rewriter.getBoolAttr(false)); rewriter.replaceOp( op, SmallVector(adaptor.getRetained().size(), falseVal)); @@ -450,30 +450,30 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( Value retainCondsMemref = helperFuncOp.getArguments()[4]; // Insert some prerequisites. 
- Value c0 = builder.create(loc, builder.getIndexAttr(0)); - Value c1 = builder.create(loc, builder.getIndexAttr(1)); + Value c0 = arith::ConstantOp::create(builder, loc, builder.getIndexAttr(0)); + Value c1 = arith::ConstantOp::create(builder, loc, builder.getIndexAttr(1)); Value trueValue = - builder.create(loc, builder.getBoolAttr(true)); + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); Value falseValue = - builder.create(loc, builder.getBoolAttr(false)); - Value toDeallocSize = builder.create(loc, toDeallocMemref, c0); - Value toRetainSize = builder.create(loc, toRetainMemref, c0); + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(false)); + Value toDeallocSize = memref::DimOp::create(builder, loc, toDeallocMemref, c0); + Value toRetainSize = memref::DimOp::create(builder, loc, toRetainMemref, c0); - builder.create( + scf::ForOp::create(builder, loc, c0, toRetainSize, c1, ValueRange(), [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - builder.create(loc, falseValue, retainCondsMemref, i); - builder.create(loc); + memref::StoreOp::create(builder, loc, falseValue, retainCondsMemref, i); + scf::YieldOp::create(builder, loc); }); - builder.create( + scf::ForOp::create(builder, loc, c0, toDeallocSize, c1, ValueRange(), [&](OpBuilder &builder, Location loc, Value outerIter, ValueRange iterArgs) { Value toDealloc = - builder.create(loc, toDeallocMemref, outerIter); + memref::LoadOp::create(builder, loc, toDeallocMemref, outerIter); Value cond = - builder.create(loc, conditionMemref, outerIter); + memref::LoadOp::create(builder, loc, conditionMemref, outerIter); // Build the first for loop that computes aliasing with retained // memrefs. 
@@ -483,31 +483,31 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( loc, c0, toRetainSize, c1, trueValue, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - Value retainValue = builder.create( + Value retainValue = memref::LoadOp::create(builder, loc, toRetainMemref, i); - Value doesAlias = builder.create( + Value doesAlias = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, retainValue, toDealloc); - builder.create( + scf::IfOp::create(builder, loc, doesAlias, [&](OpBuilder &builder, Location loc) { Value retainCondValue = - builder.create( + memref::LoadOp::create(builder, loc, retainCondsMemref, i); Value aggregatedRetainCond = - builder.create( + arith::OrIOp::create(builder, loc, retainCondValue, cond); - builder.create( + memref::StoreOp::create(builder, loc, aggregatedRetainCond, retainCondsMemref, i); - builder.create(loc); + scf::YieldOp::create(builder, loc); }); - Value doesntAlias = builder.create( + Value doesntAlias = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, retainValue, toDealloc); - Value yieldValue = builder.create( + Value yieldValue = arith::AndIOp::create(builder, loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); + scf::YieldOp::create(builder, loc, yieldValue); }) .getResult(0); @@ -519,24 +519,24 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( loc, c0, outerIter, c1, noRetainAlias, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - Value prevDeallocValue = builder.create( + Value prevDeallocValue = memref::LoadOp::create(builder, loc, toDeallocMemref, i); - Value doesntAlias = builder.create( + Value doesntAlias = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, prevDeallocValue, toDealloc); - Value yieldValue = builder.create( + Value yieldValue = arith::AndIOp::create(builder, loc, iterArgs[0], doesntAlias); - builder.create(loc, yieldValue); + scf::YieldOp::create(builder, loc, yieldValue); }) 
.getResult(0); - Value shouldDealoc = builder.create(loc, noAlias, cond); - builder.create(loc, shouldDealoc, deallocCondsMemref, + Value shouldDealoc = arith::AndIOp::create(builder, loc, noAlias, cond); + memref::StoreOp::create(builder, loc, shouldDealoc, deallocCondsMemref, outerIter); - builder.create(loc); + scf::YieldOp::create(builder, loc); }); - builder.create(loc); + func::ReturnOp::create(builder, loc); return helperFuncOp; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 1eeafc4df8cf1..71890bfea4e5d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -43,7 +43,7 @@ using namespace mlir::bufferization; //===----------------------------------------------------------------------===// static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { - return builder.create(loc, builder.getBoolAttr(value)); + return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value)); } static bool isMemref(Value v) { return isa(v.getType()); } @@ -755,12 +755,12 @@ Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership( .create( memref.getLoc(), condition, [&](OpBuilder &builder, Location loc) { - builder.create(loc, newMemref); + scf::YieldOp::create(builder, loc, newMemref); }, [&](OpBuilder &builder, Location loc) { Value clone = - builder.create(loc, newMemref); - builder.create(loc, clone); + bufferization::CloneOp::create(builder, loc, newMemref); + scf::YieldOp::create(builder, loc, clone); }) .getResult(0); Value trueVal = buildBoolValue(builder, memref.getLoc(), true); @@ -797,7 +797,7 @@ BufferDeallocation::handleInterface(BranchOpInterface op) { state.getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands, toRetain); - auto deallocOp = builder.create( + auto 
deallocOp = bufferization::DeallocOp::create(builder, op.getLoc(), memrefs, conditions, toRetain); // We want to replace the current ownership of the retained values with the @@ -885,10 +885,10 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) { builder.setInsertionPoint(op); Ownership ownership = state.getOwnership(operand, block); if (ownership.isUnique()) { - Value ownershipInverted = builder.create( + Value ownershipInverted = arith::XOrIOp::create(builder, op.getLoc(), ownership.getIndicator(), buildBoolValue(builder, op.getLoc(), true)); - builder.create( + cf::AssertOp::create(builder, op.getLoc(), ownershipInverted, "expected that the block does not have ownership"); } diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp index 0bdcf434e062f..b679bbe05cb2f 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp @@ -49,7 +49,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder, Type type, Location loc) { if (complex::ConstantOp::isBuildableWith(value, type)) { - return builder.create(loc, type, + return complex::ConstantOp::create(builder, loc, type, llvm::cast(value)); } return arith::ConstantOp::materialize(builder, value, type, loc); diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index fb97045687d65..34008aa5c4d6a 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -20,9 +21,7 @@ using namespace mlir::complex; // ConstantOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { - return 
getValue(); -} +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index 0c11c76cf1f71..66593783a14f1 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -89,7 +90,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { return failure(); } -// This side effect models "program termination". +// This side effect models "program termination". void AssertOp::getEffects( SmallVectorImpl> &effects) { @@ -312,8 +313,9 @@ struct SimplifyCondBranchIdenticalSuccessors if (std::get<0>(it) == std::get<1>(it)) mergedOperands.push_back(std::get<0>(it)); else - mergedOperands.push_back(rewriter.create( - condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); + mergedOperands.push_back( + arith::SelectOp::create(rewriter, condbr.getLoc(), condition, + std::get<0>(it), std::get<1>(it))); } rewriter.replaceOpWithNewOp(condbr, trueDest, mergedOperands); @@ -412,8 +414,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern { replaced = true; if (!constantTrue) - constantTrue = rewriter.create( - condbr.getLoc(), ty, rewriter.getBoolAttr(true)); + constantTrue = arith::ConstantOp::create( + rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true)); rewriter.modifyOpInPlace(use.getOwner(), [&] { use.set(constantTrue); }); @@ -427,8 +429,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern { replaced = true; if (!constantFalse) - constantFalse = rewriter.create( - condbr.getLoc(), ty, 
rewriter.getBoolAttr(false)); + constantFalse = arith::ConstantOp::create( + rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false)); rewriter.modifyOpInPlace(use.getOwner(), [&] { use.set(constantFalse); }); diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp index a077f56f4f472..23f50930cb3b3 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -87,7 +87,7 @@ struct CondBranchOpInterface destOperands.getAsOperandRange(), toRetain); SmallVector adaptedConditions( llvm::map_range(conditions, conditionModifier)); - auto deallocOp = builder.create( + auto deallocOp = bufferization::DeallocOp::create(builder, condBr.getLoc(), memrefs, adaptedConditions, toRetain); state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock()); for (auto [retained, ownership] : llvm::zip( @@ -115,18 +115,18 @@ struct CondBranchOpInterface DeallocOp thenTakenDeallocOp = insertDeallocForBranch( condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(), [&](Value cond) { - return builder.create(condBr.getLoc(), cond, + return arith::AndIOp::create(builder, condBr.getLoc(), cond, condBr.getCondition()); }, thenMapping); DeallocOp elseTakenDeallocOp = insertDeallocForBranch( condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(), [&](Value cond) { - Value trueVal = builder.create( + Value trueVal = arith::ConstantOp::create(builder, condBr.getLoc(), builder.getBoolAttr(true)); - Value negation = builder.create( + Value negation = arith::XOrIOp::create(builder, condBr.getLoc(), trueVal, condBr.getCondition()); - return builder.create(condBr.getLoc(), cond, negation); + return arith::AndIOp::create(builder, condBr.getLoc(), cond, negation); }, elseMapping); @@ -143,7 +143,7 @@ struct CondBranchOpInterface for (Value retained : commonValues) 
{ state.resetOwnerships(retained, condBr->getBlock()); - Value combinedOwnership = builder.create( + Value combinedOwnership = arith::SelectOp::create(builder, condBr.getLoc(), condBr.getCondition(), thenMapping[retained], elseMapping[retained]); state.updateOwnership(retained, combinedOwnership, condBr->getBlock()); diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp index 02c41b4fe8113..4dbd1659e24dc 100644 --- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp +++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Utils/Utils.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" using namespace mlir; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 27298e892e599..432c90836c66d 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/STLExtras.h" @@ -50,13 +51,13 @@ void EmitCDialect::initialize() { Operation *EmitCDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return emitc::ConstantOp::create(builder, loc, type, value); } /// Default callback for builders of ops carrying a region. Inserts a yield /// without arguments. 
void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { - builder.create(loc); + emitc::YieldOp::create(builder, loc); } bool mlir::emitc::isSupportedEmitCType(Type type) { diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp index a578a86b499a6..7e04b4b196a6d 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -25,7 +25,7 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) { Location loc = op->getLoc(); builder.setInsertionPointAfter(op); - auto expressionOp = builder.create(loc, resultType); + auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType); // Replace all op's uses with the new expression's result. result.replaceAllUsesWith(expressionOp.getResult()); @@ -34,7 +34,7 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) { Region &region = expressionOp.getRegion(); Block &block = region.emplaceBlock(); builder.setInsertionPointToEnd(&block); - auto yieldOp = builder.create(loc, result); + auto yieldOp = emitc::YieldOp::create(builder, loc, result); // Move op into the new expression. 
op->moveBefore(yieldOp); diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp index 72c8fd0f32485..ab7be8d6cedd9 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp @@ -21,7 +21,7 @@ Value materializeAsUnrealizedCast(OpBuilder &builder, Type resultType, if (inputs.size() != 1) return Value(); - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index 17d436f6df028..72c331ad96607 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -50,7 +50,7 @@ class WrapFuncInClass : public OpRewritePattern { PatternRewriter &rewriter) const override { auto className = funcOp.getSymNameAttr().str() + "Class"; - ClassOp newClassOp = rewriter.create(funcOp.getLoc(), className); + ClassOp newClassOp = ClassOp::create(rewriter, funcOp.getLoc(), className); SmallVector> fields; rewriter.createBlock(&newClassOp.getBody()); @@ -67,7 +67,7 @@ class WrapFuncInClass : public OpRewritePattern { TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - rewriter.create(funcOp.getLoc(), fieldName, typeAttr, + emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr, argAttr); } @@ -75,7 +75,7 @@ class WrapFuncInClass : public OpRewritePattern { FunctionType funcType = funcOp.getFunctionType(); Location loc = funcOp.getLoc(); FuncOp newFuncOp = - rewriter.create(loc, ("execute"), funcType); + emitc::FuncOp::create(rewriter, loc, ("execute"), funcType); rewriter.createBlock(&newFuncOp.getBody()); newFuncOp.getBody().takeBody(funcOp.getBody()); @@ -85,7 +85,7 @@ class WrapFuncInClass : public OpRewritePattern { 
newArguments.reserve(fields.size()); for (auto &[fieldName, attr] : fields) { GetFieldOp arg = - rewriter.create(loc, attr.getValue(), fieldName); + emitc::GetFieldOp::create(rewriter, loc, attr.getValue(), fieldName); newArguments.push_back(arg); } diff --git a/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp b/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp index 3328d58551bff..7485ba40b4344 100644 --- a/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp +++ b/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp @@ -61,7 +61,7 @@ struct FuncInlinerInterface : public DialectInlinerInterface { // Replace the return with a branch to the dest. OpBuilder builder(op); - builder.create(op->getLoc(), newDest, returnOp.getOperands()); + cf::BranchOp::create(builder, op->getLoc(), newDest, returnOp.getOperands()); op->erase(); } diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index 9d317f20521fb..1ac11f6a774f3 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -54,8 +55,8 @@ void FuncDialect::initialize() { Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, - llvm::cast(value)); + return ConstantOp::create(builder, loc, type, + llvm::cast(value)); return nullptr; } @@ -143,9 +144,7 @@ LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { - return getValueAttr(); -} +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void ConstantOp::getAsmResultNames( 
function_ref setNameFn) { diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp index 3adbf092742be..822c21e90ca21 100644 --- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -171,8 +172,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, } } - auto callOp = rewriter.create(insertionPoint->getLoc(), - targetFunction, inputs); + auto callOp = func::CallOp::create(rewriter, insertionPoint->getLoc(), + targetFunction, inputs); // Cast the call results back to the expected types. If any conversions fail // this is a definite failure as the call has been constructed at this point. diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index a3638c8766a5c..4b70cf172b407 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -46,7 +46,7 @@ struct CallOpSignatureConversion : public OpConversionPattern { // Substitute with the new result types from the corresponding FuncType // conversion. 
- auto newCallOp = rewriter.create( + auto newCallOp = CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(), convertedResults, flattenValues(adaptor.getOperands())); SmallVector replacements; diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index 0e9662689ef78..1eba777921828 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -44,7 +44,7 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, for (unsigned int idx : newResultsOrder) newOutputTypes.push_back(origOutputTypes[idx]); rewriter.setInsertionPoint(funcOp); - auto newFuncOp = rewriter.create( + auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(newInputTypes, newOutputTypes)); @@ -80,7 +80,7 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, newReturnValues.push_back(returnOp.getOperand(idx)); rewriter.setInsertionPoint(returnOp); auto newReturnOp = - rewriter.create(newFuncOp.getLoc(), newReturnValues); + func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues); newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary()); rewriter.eraseOp(returnOp); @@ -109,7 +109,7 @@ func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp, // Replace the kernel call operation with a new one that has the // reordered arguments. 
rewriter.setInsertionPoint(callOp); - auto newCallOp = rewriter.create( + auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(), newResultTypes, newArgsOrderValues); newCallOp.setNoInlineAttr(callOp.getNoInlineAttr()); for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder)) diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 30b5ac9809139..8d6962d2d2be8 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -1137,7 +1138,7 @@ struct FoldLaunchArguments : public OpRewritePattern { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&op.getBody().front()); zero = - rewriter.create(op.getLoc(), /*value=*/0); + arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0); } rewriter.replaceAllUsesWith(id, zero); simplified = true; @@ -1381,10 +1382,10 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value, int32_t offset, int32_t width, ShuffleMode mode) { build(builder, result, value, - builder.create(result.location, - builder.getI32IntegerAttr(offset)), - builder.create(result.location, - builder.getI32IntegerAttr(width)), + arith::ConstantOp::create(builder, result.location, + builder.getI32IntegerAttr(offset)), + arith::ConstantOp::create(builder, result.location, + builder.getI32IntegerAttr(width)), mode); } @@ -1395,10 +1396,10 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value, void RotateOp::build(OpBuilder &builder, OperationState &result, Value value, int32_t offset, int32_t width) { build(builder, 
result, value, - builder.create(result.location, - builder.getI32IntegerAttr(offset)), - builder.create(result.location, - builder.getI32IntegerAttr(width))); + arith::ConstantOp::create(builder, result.location, + builder.getI32IntegerAttr(offset)), + arith::ConstantOp::create(builder, result.location, + builder.getI32IntegerAttr(width))); } LogicalResult RotateOp::verify() { diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index c9e91535df946..26adfade879d8 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -31,6 +31,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Visitors.h" @@ -560,8 +561,8 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( Value predicate; if (originalBasisWasProvided) { for (Value tmpPredicate : builderResult.predicateOps) { - predicate = predicate ? rewriter.create(loc, predicate, - tmpPredicate) + predicate = predicate ? arith::AndIOp::create(rewriter, loc, predicate, + tmpPredicate) : tmpPredicate; } } @@ -573,8 +574,8 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( Block::iterator insertionPoint; if (predicate) { // Step 6.a. If predicated, move at the beginning. - auto ifOp = rewriter.create(loc, predicate, - /*withElseRegion=*/false); + auto ifOp = scf::IfOp::create(rewriter, loc, predicate, + /*withElseRegion=*/false); targetBlock = ifOp.thenBlock(); insertionPoint = ifOp.thenBlock()->begin(); } else { @@ -632,7 +633,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( // the insertion point. 
OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(parentBlock); - zero = rewriter.create(loc, 0); + zero = arith::ConstantIndexOp::create(rewriter, loc, 0); } ForallRewriteResult rewriteResult; @@ -884,7 +885,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl( return diag; // Add a syncthreads if needed. TODO: warpsync if (syncAfterDistribute) - rewriter.create(loc); + BarrierOp::create(rewriter, loc); return DiagnosedSilenceableFailure::success(); } @@ -901,7 +902,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( // Create an early zero index value for replacements. Location loc = target->getLoc(); - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) { diag = mlir::transform::gpu::mapOneForallToThreadsImpl( diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp index 12b7f39390967..fc4556df9d6f2 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp @@ -82,8 +82,8 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef activeIds, } if (activeMappingSize == availableMappingSize) continue; - Value idx = rewriter.create(loc, activeMappingSize); - Value pred = rewriter.create(loc, arith::CmpIPredicate::ult, + Value idx = arith::ConstantIndexOp::create(rewriter, loc, activeMappingSize); + Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, activeId, idx); predicateOps.push_back(pred); } @@ -104,11 +104,11 @@ static Value buildLinearId(RewriterBase &rewriter, Location loc, bindDims(rewriter.getContext(), tx, ty, tz); bindSymbols(rewriter.getContext(), bdx, bdy); SmallVector vals{ - rewriter.create(loc, indexType, Dimension::x) + 
ThreadOrBlockIdOp::create(rewriter, loc, indexType, Dimension::x) .getResult(), - rewriter.create(loc, indexType, Dimension::y) + ThreadOrBlockIdOp::create(rewriter, loc, indexType, Dimension::y) .getResult(), - rewriter.create(loc, indexType, Dimension::z) + ThreadOrBlockIdOp::create(rewriter, loc, indexType, Dimension::z) .getResult(), originalBasisOfr[0], originalBasisOfr[1]}; OpFoldResult ofr = affine::makeComposedFoldedAffineApply( @@ -215,9 +215,9 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) { ArrayRef originalBasis) { IndexType indexType = rewriter.getIndexType(); SmallVector ids{ - rewriter.create(loc, indexType, Dimension::x), - rewriter.create(loc, indexType, Dimension::y), - rewriter.create(loc, indexType, Dimension::z)}; + ThreadOrBlockIdOp::create(rewriter, loc, indexType, Dimension::x), + ThreadOrBlockIdOp::create(rewriter, loc, indexType, Dimension::y), + ThreadOrBlockIdOp::create(rewriter, loc, indexType, Dimension::z)}; // In the 3-D mapping case, scale the first dimension by the multiplicity. SmallVector scaledIds = ids; AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); @@ -417,7 +417,7 @@ DiagnosedSilenceableFailure createGpuLaunch( return diag; auto createConst = [&](int dim) { - return rewriter.create(loc, dim); + return arith::ConstantIndexOp::create(rewriter, loc, dim); }; OpBuilder::InsertionGuard guard(rewriter); Value one = createConst(1); @@ -427,10 +427,10 @@ DiagnosedSilenceableFailure createGpuLaunch( Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one; Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one; Value blkSizeZ = blockDimZ.has_value() ? 
createConst(blockDimZ.value()) : one; - launchOp = rewriter.create(loc, gridSizeX, gridSizeY, gridSizeZ, + launchOp = LaunchOp::create(rewriter, loc, gridSizeX, gridSizeY, gridSizeZ, blkSizeX, blkSizeY, blkSizeZ); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); - rewriter.create(loc); + TerminatorOp::create(rewriter, loc); return DiagnosedSilenceableFailure::success(); } @@ -451,7 +451,7 @@ DiagnosedSilenceableFailure alterGpuLaunch( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfterValue(currentBlockdim.x); auto createConstValue = [&](int dim) { - return rewriter.create(currentBlockdim.x.getLoc(), + return arith::ConstantIndexOp::create(rewriter, currentBlockdim.x.getLoc(), dim); }; diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp index d35f72e5a9e26..66e011a793485 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -147,7 +147,7 @@ struct GpuAllReduceRewriter { // Shortcut to create an op from rewriter using loc as the first argument. template T create(Args... args) { - return rewriter.create(loc, std::forward(args)...); + return T::create(rewriter, loc, std::forward(args)...); } // Creates dimension op of type T, with the result casted to int32. 
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index 99a91ecd5642c..ab3dcc6a45d11 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -132,7 +132,7 @@ struct GpuAsyncRegionPass::ThreadTokenCallback { } Value createWaitOp(Location loc, Type resultType, ValueRange operands) { - return builder.create(loc, resultType, operands) + return gpu::WaitOp::create(builder, loc, resultType, operands) .getAsyncToken(); } @@ -168,7 +168,7 @@ async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, // Clone executeOp with the extra results. OpBuilder builder(executeOp); - auto newOp = builder.create( + auto newOp = async::ExecuteOp::create(builder, executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/, executeOp.getDependencies(), executeOp.getBodyOperands()); IRMapping mapper; @@ -250,7 +250,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback { builder.setInsertionPointAfter(op); for (auto asyncToken : asyncTokens) tokens.push_back( - builder.create(loc, asyncToken).getResult()); + async::AwaitOp::create(builder, loc, asyncToken).getResult()); // Set `it` after the inserted async.await ops. it = builder.getInsertionPoint(); }) @@ -282,7 +282,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback { // Otherwise, insert a gpu.wait before 'it'. builder.setInsertionPoint(it->getBlock(), it); - auto waitOp = builder.create(loc, Type{}, tokens); + auto waitOp = gpu::WaitOp::create(builder, loc, Type{}, tokens); // If the new waitOp is at the end of an async.execute region, add it to the // worklist. 'operator()(executeOp)' would do the same, but this is faster. 
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp index f63af8da28087..1c7fdf357af20 100644 --- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp @@ -64,7 +64,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, OpBuilder::InsertionGuard g(rewriter); setInsertionPointToStart(rewriter, source); newExtractStridedMetadata = - rewriter.create(loc, source); + memref::ExtractStridedMetadataOp::create(rewriter, loc, source); } auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); @@ -110,7 +110,7 @@ static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, auto &&[base, offset, ignore] = getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp); MemRefType retType = inferCastResultType(base, offset); - return rewriter.create(loc, retType, base, offset, + return memref::ReinterpretCastOp::create(rewriter, loc, retType, base, offset, ArrayRef(), ArrayRef()); } diff --git a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp index c40ddd9b15afc..8dfd52e5803f0 100644 --- a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp @@ -28,11 +28,11 @@ struct GpuGlobalIdRewriter : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto dim = op.getDimension(); - auto blockId = rewriter.create(loc, dim); - auto blockDim = rewriter.create(loc, dim); + auto blockId = gpu::BlockIdOp::create(rewriter, loc, dim); + auto blockDim = gpu::BlockDimOp::create(rewriter, loc, dim); // Compute blockId.x * blockDim.x - auto tmp = rewriter.create(op.getLoc(), blockId, blockDim); - auto threadId = rewriter.create(loc, dim); + auto tmp = index::MulOp::create(rewriter, op.getLoc(), blockId, blockDim); + auto threadId = 
gpu::ThreadIdOp::create(rewriter, loc, dim); // Compute threadId.x + blockId.x * blockDim.x rewriter.replaceOpWithNewOp(op, threadId, tmp); return success(); diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index a64ec8d52daf0..af8b097ecf6a9 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -41,7 +41,7 @@ template static void createForAllDimensions(OpBuilder &builder, Location loc, SmallVectorImpl &values) { for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) - values.push_back(builder.create(loc, builder.getIndexType(), dim)); + values.push_back(OpTy::create(builder, loc, builder.getIndexType(), dim)); } /// Adds operations generating block/thread ids and grid/block dimensions at the @@ -196,7 +196,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp, } FunctionType type = FunctionType::get(launchOp.getContext(), kernelOperandTypes, {}); - auto outlinedFunc = builder.create( + auto outlinedFunc = gpu::GPUFuncOp::create(builder, loc, kernelFnName, type, TypeRange(ValueRange(launchOp.getWorkgroupAttributions())), TypeRange(ValueRange(launchOp.getPrivateAttributions()))); @@ -248,7 +248,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp, if (!terminator) continue; OpBuilder replacer(terminator); - replacer.create(terminator->getLoc()); + gpu::ReturnOp::create(replacer, terminator->getLoc()); terminator->erase(); } @@ -288,7 +288,7 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp, Value asyncToken = launchOp.getAsyncToken(); std::optional clusterSize = launchOp.getClusterSizeOperandValues(); - auto launchFunc = builder.create( + auto launchFunc = gpu::LaunchFuncOp::create(builder, launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), launchOp.getBlockSizeOperandValues(), launchOp.getDynamicSharedMemorySize(), operands, @@ -416,7 +416,7 @@ 
class GpuKernelOutliningPass // Check if the module already exists in the symbol table if (!kernelModule) { // If not found, create a new GPU module - kernelModule = builder.create(kernelFunc.getLoc(), + kernelModule = gpu::GPUModuleOp::create(builder, kernelFunc.getLoc(), kernelModuleName); } diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp index 14c44f27a6249..4a670373272a7 100644 --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" @@ -34,8 +35,8 @@ static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { auto rank = memRefType.getRank(); SmallVector lbs, ubs, steps; - Value zero = b.create(0); - Value one = b.create(1); + Value zero = arith::ConstantIndexOp::create(b, 0); + Value one = arith::ConstantIndexOp::create(b, 1); // Make sure we have enough loops to use all thread dimensions, these trivial // loops should be outermost and therefore inserted first. @@ -59,8 +60,8 @@ static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { auto indexType = b.getIndexType(); SmallVector threadIds, blockDims; for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) { - threadIds.push_back(b.create(indexType, dim)); - blockDims.push_back(b.create(indexType, dim)); + threadIds.push_back(gpu::ThreadIdOp::create(b, indexType, dim)); + blockDims.push_back(gpu::BlockDimOp::create(b, indexType, dim)); } // Produce the loop nest with copies. 
@@ -70,8 +71,8 @@ static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { [&](OpBuilder &b, Location loc, ValueRange loopIvs) { ivs.assign(loopIvs.begin(), loopIvs.end()); auto activeIvs = llvm::ArrayRef(ivs).take_back(rank); - Value loaded = b.create(loc, from, activeIvs); - b.create(loc, loaded, to, activeIvs); + Value loaded = memref::LoadOp::create(b, loc, from, activeIvs); + memref::StoreOp::create(b, loc, loaded, to, activeIvs); }); // Map the innermost loops to threads in reverse order. @@ -131,10 +132,10 @@ static void insertCopies(Region ®ion, Location loc, Value from, Value to) { auto b = ImplicitLocOpBuilder::atBlockBegin(loc, ®ion.front()); insertCopyLoops(b, from, to); - b.create(); + gpu::BarrierOp::create(b); b.setInsertionPoint(®ion.front().back()); - b.create(); + gpu::BarrierOp::create(b); insertCopyLoops(b, to, from); } diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp index f8a548af6b3e8..5bc85f236575d 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp @@ -115,7 +115,7 @@ LogicalResult moduleSerializer(GPUModuleOp op, !handler && moduleHandler) handler = moduleHandler; builder.setInsertionPointAfter(op); - builder.create(op.getLoc(), op.getName(), handler, + gpu::BinaryOp::create(builder, op.getLoc(), op.getName(), handler, builder.getArrayAttr(objects)); op->erase(); return success(); diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp index 171e64346f155..5a55a39a7748e 100644 --- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp @@ -48,10 +48,10 @@ struct PromoteShuffleToSwizzlePattern "offset must be in the range [0, 31]"); Location loc = op.getLoc(); - Value res = rewriter.create( + Value res = 
amdgpu::SwizzleBitModeOp::create(rewriter, loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31, /*orMask=*/0, /*xorMask=*/offsetValue); - Value valid = rewriter.create(loc, 1, /*width*/ 1); + Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1); rewriter.replaceOp(op, {res, valid}); return success(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp index 9f2900214e8b1..08dedf065f900 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp @@ -48,16 +48,16 @@ struct GpuShuffleRewriter : public OpRewritePattern { // Float types must be converted to i64 to extract the bits. if (isa(valueType)) - value = rewriter.create(valueLoc, i64, value); + value = arith::BitcastOp::create(rewriter, valueLoc, i64, value); // Get the low bits by trunc(value). - lo = rewriter.create(valueLoc, i32, value); + lo = arith::TruncIOp::create(rewriter, valueLoc, i32, value); // Get the high bits by trunc(value >> 32). - auto c32 = rewriter.create( + auto c32 = arith::ConstantOp::create(rewriter, valueLoc, rewriter.getIntegerAttr(i64, 32)); - hi = rewriter.create(valueLoc, value, c32); - hi = rewriter.create(valueLoc, i32, hi); + hi = arith::ShRUIOp::create(rewriter, valueLoc, value, c32); + hi = arith::TruncIOp::create(rewriter, valueLoc, i32, hi); // Shuffle the values. ValueRange loRes = @@ -72,21 +72,21 @@ struct GpuShuffleRewriter : public OpRewritePattern { .getResults(); // Convert lo back to i64. - lo = rewriter.create(valueLoc, i64, loRes[0]); + lo = arith::ExtUIOp::create(rewriter, valueLoc, i64, loRes[0]); // Convert hi back to i64. - hi = rewriter.create(valueLoc, i64, hiRes[0]); - hi = rewriter.create(valueLoc, hi, c32); + hi = arith::ExtUIOp::create(rewriter, valueLoc, i64, hiRes[0]); + hi = arith::ShLIOp::create(rewriter, valueLoc, hi, c32); // Obtain the shuffled bits hi | lo. 
- value = rewriter.create(loc, hi, lo); + value = arith::OrIOp::create(rewriter, loc, hi, lo); // Convert the value back to float. if (isa(valueType)) - value = rewriter.create(valueLoc, valueType, value); + value = arith::BitcastOp::create(rewriter, valueLoc, valueType, value); // Obtain the shuffle validity by combining both validities. - auto validity = rewriter.create(loc, loRes[1], hiRes[1]); + auto validity = arith::AndIOp::create(rewriter, loc, loRes[1], hiRes[1]); // Replace the op. rewriter.replaceOp(op, {value, validity}); diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp index d80578235f3c3..0dbd9713a2a19 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp @@ -55,22 +55,22 @@ struct GpuSubgroupIdRewriter final : OpRewritePattern { Location loc = op->getLoc(); Type indexType = rewriter.getIndexType(); - Value dimX = rewriter.create(loc, gpu::Dimension::x); - Value dimY = rewriter.create(loc, gpu::Dimension::y); - Value tidX = rewriter.create(loc, gpu::Dimension::x); - Value tidY = rewriter.create(loc, gpu::Dimension::y); - Value tidZ = rewriter.create(loc, gpu::Dimension::z); + Value dimX = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x); + Value dimY = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::y); + Value tidX = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x); + Value tidY = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::y); + Value tidZ = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::z); - Value dimYxIdZ = rewriter.create(loc, indexType, dimY, tidZ); + Value dimYxIdZ = arith::MulIOp::create(rewriter, loc, indexType, dimY, tidZ); Value dimYxIdZPlusIdY = - rewriter.create(loc, indexType, dimYxIdZ, tidY); + arith::AddIOp::create(rewriter, loc, indexType, dimYxIdZ, tidY); Value dimYxIdZPlusIdYTimesDimX = - rewriter.create(loc, indexType, dimX, 
dimYxIdZPlusIdY); - Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create( + arith::MulIOp::create(rewriter, loc, indexType, dimX, dimYxIdZPlusIdY); + Value IdXPlusDimYxIdZPlusIdYTimesDimX = arith::AddIOp::create(rewriter, loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX); - Value subgroupSize = rewriter.create( + Value subgroupSize = gpu::SubgroupSizeOp::create(rewriter, loc, rewriter.getIndexType(), /*upper_bound = */ nullptr); - Value subgroupIdOp = rewriter.create( + Value subgroupIdOp = arith::DivUIOp::create(rewriter, loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize); rewriter.replaceOp(op, {subgroupIdOp}); return success(); diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index af9be4cccecfc..d5470de5a4faa 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -80,7 +80,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern { Location loc = op.getLoc(); Value res = - rewriter.create(loc, rewriter.getZeroAttr(vecTy)); + arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(vecTy)); for (unsigned i = 0; i != numNewReductions; ++i) { int64_t startIdx = i * elementsPerShuffle; @@ -91,22 +91,22 @@ struct BreakDownSubgroupReduce final : OpRewritePattern { Value extracted; if (numElems == 1) { extracted = - rewriter.create(loc, op.getValue(), startIdx); + vector::ExtractOp::create(rewriter, loc, op.getValue(), startIdx); } else { - extracted = rewriter.create( + extracted = vector::ExtractStridedSliceOp::create(rewriter, loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems, /*strides=*/1); } - Value reduce = rewriter.create( + Value reduce = gpu::SubgroupReduceOp::create(rewriter, loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(), op.getClusterStride()); if (numElems == 1) { - res = rewriter.create(loc, reduce, res, startIdx); + res = 
vector::InsertOp::create(rewriter, loc, reduce, res, startIdx); continue; } - res = rewriter.create( + res = vector::InsertStridedSliceOp::create(rewriter, loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1); } @@ -139,8 +139,8 @@ struct ScalarizeSingleElementReduce final assert(vecTy.getRank() == 1 && "Unexpected vector type"); assert(!vecTy.isScalable() && "Unexpected vector type"); Location loc = op.getLoc(); - Value extracted = rewriter.create(loc, op.getValue(), 0); - Value reduce = rewriter.create( + Value extracted = vector::ExtractOp::create(rewriter, loc, op.getValue(), 0); + Value reduce = gpu::SubgroupReduceOp::create(rewriter, loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(), op.getClusterStride()); rewriter.replaceOpWithNewOp(op, vecTy, reduce); @@ -255,14 +255,14 @@ struct ScalarSubgroupReduceToShuffles final auto packFn = [loc, &rewriter, equivIntType, shuffleIntType](Value unpackedVal) -> Value { auto asInt = - rewriter.create(loc, equivIntType, unpackedVal); - return rewriter.create(loc, shuffleIntType, asInt); + arith::BitcastOp::create(rewriter, loc, equivIntType, unpackedVal); + return arith::ExtUIOp::create(rewriter, loc, shuffleIntType, asInt); }; auto unpackFn = [loc, &rewriter, equivIntType, valueTy](Value packedVal) -> Value { auto asInt = - rewriter.create(loc, equivIntType, packedVal); - return rewriter.create(loc, valueTy, asInt); + arith::TruncIOp::create(rewriter, loc, equivIntType, packedVal); + return arith::BitcastOp::create(rewriter, loc, valueTy, asInt); }; rewriter.replaceOp( @@ -327,9 +327,9 @@ struct VectorSubgroupReduceToShuffles final static_cast(elementsPerShuffle), vecTy.getElementType()); Value extendedInput = op.getValue(); if (vecBitwidth < shuffleBitwidth) { - auto zero = rewriter.create( + auto zero = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(extendedVecTy)); - extendedInput = rewriter.create( + extendedInput = vector::InsertStridedSliceOp::create(rewriter, loc, extendedInput, 
zero, /*offsets=*/0, /*strides=*/1); } @@ -338,21 +338,21 @@ struct VectorSubgroupReduceToShuffles final auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value { auto asIntVec = - rewriter.create(loc, shuffleVecType, unpackedVal); - return rewriter.create(loc, asIntVec, 0); + vector::BitCastOp::create(rewriter, loc, shuffleVecType, unpackedVal); + return vector::ExtractOp::create(rewriter, loc, asIntVec, 0); }; auto unpackFn = [loc, &rewriter, shuffleVecType, extendedVecTy](Value packedVal) -> Value { auto asIntVec = - rewriter.create(loc, shuffleVecType, packedVal); - return rewriter.create(loc, extendedVecTy, asIntVec); + vector::BroadcastOp::create(rewriter, loc, shuffleVecType, packedVal); + return vector::BitCastOp::create(rewriter, loc, extendedVecTy, asIntVec); }; Value res = createSubgroupShuffleReduction( rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn); if (vecBitwidth < shuffleBitwidth) { - res = rewriter.create( + res = vector::ExtractStridedSliceOp::create(rewriter, loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(), /*strides=*/1); } @@ -379,7 +379,7 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, const bool boundCtrl = true; if (ci.clusterSize >= 2) { // Perform reduction between all lanes N <-> N+1. - dpp = rewriter.create( + dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm, rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl); res = vector::makeArithReduction(rewriter, loc, @@ -388,7 +388,7 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, if (ci.clusterSize >= 4) { // Perform reduction between all lanes N <-> N+2. 
- dpp = rewriter.create( + dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm, rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl); res = vector::makeArithReduction(rewriter, loc, @@ -397,7 +397,7 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, if (ci.clusterSize >= 8) { // Perform reduction between all lanes N <-> 7-N, // e.g lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4]. - dpp = rewriter.create( + dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror, rewriter.getUnitAttr(), allRows, allBanks, boundCtrl); res = vector::makeArithReduction(rewriter, loc, @@ -406,7 +406,7 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, if (ci.clusterSize >= 16) { // Perform reduction between all lanes N <-> 15-N, // e.g lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8]. - dpp = rewriter.create( + dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror, rewriter.getUnitAttr(), allRows, allBanks, boundCtrl); res = vector::makeArithReduction(rewriter, loc, @@ -416,7 +416,7 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, if (chipset.majorVersion <= 9) { // Broadcast last value from each row to next row. // Use row mask to avoid polluting rows 1 and 3. - dpp = rewriter.create( + dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15, rewriter.getUnitAttr(), 0xa, allBanks, /*bound_ctrl*/ false); @@ -424,9 +424,9 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, rewriter, loc, gpu::convertReductionKind(mode), res, dpp); } else if (chipset.majorVersion <= 12) { // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2). 
- Value uint32Max = rewriter.create( + Value uint32Max = arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1)); - dpp = rewriter.create(loc, res.getType(), res, res, + dpp = ROCDL::PermlaneX16Op::create(rewriter, loc, res.getType(), res, res, uint32Max, uint32Max, /*fi=*/true, /*bound_ctrl=*/false); @@ -438,37 +438,37 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, "this device."); } if (ci.subgroupSize == 32) { - Value lane31 = rewriter.create( + Value lane31 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31)); - res = rewriter.create(loc, res.getType(), res, lane31); + res = ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31); } } if (ci.clusterSize >= 64) { if (chipset.majorVersion <= 9) { // Broadcast 31st lane value to rows 2 and 3. - dpp = rewriter.create( + dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31, rewriter.getUnitAttr(), 0xf, allBanks, /*bound_ctrl*/ true); res = vector::makeArithReduction( rewriter, loc, gpu::convertReductionKind(mode), dpp, res); // Obtain reduction from last rows, the previous rows are polluted. - Value lane63 = rewriter.create( + Value lane63 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63)); - res = rewriter.create(loc, res.getType(), res, lane63); + res = ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63); } else if (chipset.majorVersion <= 12) { // Assume reduction across 32 lanes has been done. // Perform final reduction manually by summing values in lane 0 and // lane 32. 
- Value lane31 = rewriter.create( + Value lane31 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31)); - Value lane63 = rewriter.create( + Value lane63 = arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63)); lane31 = - rewriter.create(loc, res.getType(), res, lane31); + ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31); lane63 = - rewriter.create(loc, res.getType(), res, lane63); + ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63); res = vector::makeArithReduction( rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63); } else { diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp index 29f6f32892f72..473a2e515c12b 100644 --- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp @@ -27,7 +27,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndReplaceReturns( // Create a new op before the existing one, with the extra operands. 
OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(warpOp); - auto newWarpOp = rewriter.create( + auto newWarpOp = WarpExecuteOnLane0Op::create(rewriter, warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(), warpOp.getArgs(), warpOp.getBody()->getArgumentTypes()); @@ -124,7 +124,7 @@ bool WarpDistributionPattern::delinearizeLaneId( int64_t usedThreads = 1; - Value zero = builder.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); delinearizedIds.assign(sizes.size(), zero); for (int i = sizes.size() - 1; i >= 0; --i) { diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp index bcc9f0b109ac2..d7bc15b866686 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp index 7ec3aa2741023..c62e5c9c50183 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/Dialect/IRDL/IRDLSymbols.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/ValueRange.h" #include diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp index 5c935c5f4b53e..61089d7665e6c 100644 --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -37,7 +38,7 
@@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, if (auto boolValue = dyn_cast(value)) { if (!type.isSignlessInteger(1)) return nullptr; - return b.create(loc, type, boolValue); + return BoolConstantOp::create(b, loc, type, boolValue); } // Materialize integer attributes as `index`. @@ -47,7 +48,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, return nullptr; assert(indexValue.getValue().getBitWidth() == IndexType::kInternalStorageBitWidth); - return b.create(loc, indexValue); + return ConstantOp::create(b, loc, indexValue); } return nullptr; @@ -130,15 +131,18 @@ canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, auto lhsOp = op.getLhs().template getDefiningOp(); if (!lhsOp) - return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp"); + return rewriter.notifyMatchFailure(op.getLoc(), + "LHS is not the same BinaryOp"); if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant())) - return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant"); + return rewriter.notifyMatchFailure(op.getLoc(), + "RHS of LHS op is not a constant"); Value c = rewriter.createOrFold(op->getLoc(), op.getRhs(), - lhsOp.getRhs()); + lhsOp.getRhs()); if (c.getDefiningOp()) - return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded"); + return rewriter.notifyMatchFailure(op.getLoc(), + "new BinaryOp was not folded"); rewriter.replaceOpWithNewOp(op, lhsOp.getLhs(), c); return success(); @@ -716,11 +720,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { index::CmpOp newCmp; if (rhsIsZero) - newCmp = rewriter.create(op.getLoc(), op.getPred(), - subOp.getLhs(), subOp.getRhs()); + newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(), + subOp.getLhs(), subOp.getRhs()); else - newCmp = rewriter.create(op.getLoc(), op.getPred(), - subOp.getRhs(), subOp.getLhs()); + newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(), + 
subOp.getRhs(), subOp.getLhs()); rewriter.replaceOp(op, newCmp); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index c17ef1029faf6..b1d376c7c03db 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -90,7 +90,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { } for (auto [idx, t] : llvm::enumerate(stype.getBody())) { if (itype != PTXRegisterMod::Write) { - Value extractValue = rewriter.create( + Value extractValue = LLVM::ExtractValueOp::create(rewriter, interfaceOp->getLoc(), v, idx); addValue(extractValue); } @@ -132,7 +132,7 @@ LLVM::InlineAsmOp PtxBuilder::build() { // Replace all % with $ llvm::replace(ptxInstruction, '%', '$'); - return rewriter.create( + return LLVM::InlineAsmOp::create(rewriter, interfaceOp->getLoc(), /*result types=*/resultTypes, /*operands=*/ptxOperands, diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 89f765dacda35..7d3f59d128367 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -89,7 +89,7 @@ mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, OpBuilder::InsertionGuard g(b); assert(!moduleOp->getRegion(0).empty() && "expected non-empty region"); b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); - auto funcOp = b.create( + auto funcOp = LLVM::LLVMFuncOp::create(b, moduleOp->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 6dcd94e6eea17..e25a90b1d276e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinOps.h" #include 
"mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -3318,7 +3319,7 @@ bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) { ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value, Type type, Location loc) { if (isBuildableWith(value, type)) - return builder.create(loc, cast(value)); + return LLVM::ConstantOp::create(builder, loc, cast(value)); return nullptr; } @@ -4284,13 +4285,13 @@ Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, // a builtin zero attribute and thus will materialize as a llvm.mlir.constant. if (auto symbol = dyn_cast(value)) if (isa(type)) - return builder.create(loc, type, symbol); + return LLVM::AddressOfOp::create(builder, loc, type, symbol); if (isa(value)) - return builder.create(loc, type); + return LLVM::UndefOp::create(builder, loc, type); if (isa(value)) - return builder.create(loc, type); + return LLVM::PoisonOp::create(builder, loc, type); if (isa(value)) - return builder.create(loc, type); + return LLVM::ZeroOp::create(builder, loc, type); // Otherwise try materializing it as a regular llvm.mlir.constant op. 
return LLVM::ConstantOp::materialize(builder, value, type, loc); } @@ -4313,16 +4314,16 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); MLIRContext *ctx = builder.getContext(); auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); - auto global = moduleBuilder.create( - loc, type, /*isConstant=*/true, linkage, name, + auto global = LLVM::GlobalOp::create( + moduleBuilder, loc, type, /*isConstant=*/true, linkage, name, builder.getStringAttr(value), /*alignment=*/0); LLVMPointerType ptrType = LLVMPointerType::get(ctx); // Get the pointer to the first character in the global string. Value globalPtr = - builder.create(loc, ptrType, global.getSymNameAttr()); - return builder.create(loc, ptrType, type, globalPtr, - ArrayRef{0, 0}); + LLVM::AddressOfOp::create(builder, loc, ptrType, global.getSymNameAttr()); + return LLVM::GEPOp::create(builder, loc, ptrType, type, globalPtr, + ArrayRef{0, 0}); } bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index bc451f8b028fc..98fb9b09f0cd5 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -37,7 +37,7 @@ llvm::SmallVector LLVM::AllocaOp::getPromotableSlots() { Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - return builder.create(getLoc(), slot.elemType); + return LLVM::UndefOp::create(builder, getLoc(), slot.elemType); } void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot, @@ -45,7 +45,7 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot, OpBuilder &builder) { for (Operation *user : getOperation()->getUsers()) if (auto declareOp = llvm::dyn_cast(user)) - builder.create(declareOp.getLoc(), argument, + LLVM::DbgValueOp::create(builder, declareOp.getLoc(), argument, 
declareOp.getVarInfo(), declareOp.getLocationExpr()); } @@ -89,7 +89,7 @@ DenseMap LLVM::AllocaOp::destructure( for (Attribute index : usedIndices) { Type elemType = destructurableType.getTypeAtIndex(index); assert(elemType && "used index must exist"); - auto subAlloca = builder.create( + auto subAlloca = LLVM::AllocaOp::create(builder, getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType, getArraySize()); newAllocators.push_back(subAlloca); @@ -260,13 +260,13 @@ static Value createExtractAndCast(OpBuilder &builder, Location loc, // Truncate the integer if the size of the target is less than the value. if (isBigEndian(dataLayout)) { uint64_t shiftAmount = srcTypeSize - targetTypeSize; - auto shiftConstant = builder.create( + auto shiftConstant = LLVM::ConstantOp::create(builder, loc, builder.getIntegerAttr(srcType, shiftAmount)); replacement = builder.createOrFold(loc, srcValue, shiftConstant); } - replacement = builder.create( + replacement = LLVM::TruncOp::create(builder, loc, builder.getIntegerType(targetTypeSize), replacement); // Now cast the integer to the actual target type if required. @@ -304,7 +304,7 @@ static Value createInsertAndCast(OpBuilder &builder, Location loc, // On big endian systems, a store to the base pointer overwrites the most // significant bits. To accomodate for this, the stored value needs to be // shifted into the according position. - Value bigEndianShift = builder.create( + Value bigEndianShift = LLVM::ConstantOp::create(builder, loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference)); valueAsInt = builder.createOrFold(loc, valueAsInt, bigEndianShift); @@ -325,7 +325,7 @@ static Value createInsertAndCast(OpBuilder &builder, Location loc, } // Mask out the affected bits ... 
- Value mask = builder.create( + Value mask = LLVM::ConstantOp::create(builder, loc, builder.getIntegerAttr(defAsInt.getType(), maskValue)); Value masked = builder.createOrFold(loc, defAsInt, mask); @@ -644,7 +644,7 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses( // debug local variable info. This allows the debugger to inform the user that // the variable has been optimized out. auto undef = - builder.create(getValue().getLoc(), getValue().getType()); + UndefOp::create(builder, getValue().getLoc(), getValue().getType()); getValueMutable().assign(undef); return DeletionKind::Keep; } @@ -655,7 +655,7 @@ void LLVM::DbgDeclareOp::visitReplacedValues( ArrayRef> definitions, OpBuilder &builder) { for (auto [op, value] : definitions) { builder.setInsertionPointAfter(op); - builder.create(getLoc(), value, getVarInfo(), + LLVM::DbgValueOp::create(builder, getLoc(), value, getVarInfo(), getLocationExpr()); } } @@ -978,7 +978,7 @@ void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace, IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize)) .getResult(); - builder.create(toReplace.getLoc(), subslots.at(index).ptr, + LLVM::MemsetOp::create(builder, toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(), newMemsetSizeValue, toReplace.getIsVolatile()); } @@ -991,7 +991,7 @@ void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace, auto newMemsetSizeValue = IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize); - builder.create( + LLVM::MemsetInlineOp::create(builder, toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(), newMemsetSizeValue, toReplace.getIsVolatile()); } @@ -1063,7 +1063,7 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, APInt memsetVal(/*numBits=*/width, /*val=*/0); for (unsigned loBit = 0; loBit < width; loBit += 8) memsetVal.insertBits(constantPattern.getValue(), loBit); - return builder.create( + return LLVM::ConstantOp::create(builder, op.getLoc(), IntegerAttr::get(intType, 
memsetVal)); } @@ -1075,14 +1075,14 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, // value and or-ing it with the previous value. uint64_t coveredBits = 8; Value currentValue = - builder.create(op.getLoc(), intType, op.getVal()); + LLVM::ZExtOp::create(builder, op.getLoc(), intType, op.getVal()); while (coveredBits < width) { Value shiftBy = - builder.create(op.getLoc(), intType, coveredBits); + LLVM::ConstantOp::create(builder, op.getLoc(), intType, coveredBits); Value shifted = - builder.create(op.getLoc(), currentValue, shiftBy); + LLVM::ShlOp::create(builder, op.getLoc(), currentValue, shiftBy); currentValue = - builder.create(op.getLoc(), currentValue, shifted); + LLVM::OrOp::create(builder, op.getLoc(), currentValue, shifted); coveredBits *= 2; } @@ -1094,7 +1094,7 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, }) .Case([&](FloatType type) -> Value { Value intVal = buildMemsetValue(type.getWidth()); - return builder.create(op.getLoc(), type, intVal); + return LLVM::BitcastOp::create(builder, op.getLoc(), type, intVal); }) .Default([](Type) -> Value { llvm_unreachable( @@ -1282,7 +1282,7 @@ static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) { template static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, OpBuilder &builder) { - return builder.create(op.getLoc(), slot.elemType, op.getSrc()); + return LLVM::LoadOp::create(builder, op.getLoc(), slot.elemType, op.getSrc()); } template @@ -1309,7 +1309,7 @@ memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, OpBuilder &builder, Value reachingDefinition) { if (op.loadsFrom(slot)) - builder.create(op.getLoc(), reachingDefinition, op.getDst()); + LLVM::StoreOp::create(builder, op.getLoc(), reachingDefinition, op.getDst()); return DeletionKind::Delete; } @@ -1354,10 +1354,10 @@ template void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout, MemcpyLike toReplace, Value dst, 
Value src, Type toCpy, bool isVolatile) { - Value memcpySize = builder.create( + Value memcpySize = LLVM::ConstantOp::create(builder, toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(), layout.getTypeSize(toCpy))); - builder.create(toReplace.getLoc(), dst, src, memcpySize, + MemcpyLike::create(builder, toReplace.getLoc(), dst, src, memcpySize, isVolatile); } @@ -1367,7 +1367,7 @@ void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout, Value src, Type toCpy, bool isVolatile) { Type lenType = IntegerType::get(toReplace->getContext(), toReplace.getLen().getBitWidth()); - builder.create( + LLVM::MemcpyInlineOp::create(builder, toReplace.getLoc(), dst, src, IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile); } @@ -1409,7 +1409,7 @@ memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallVector gepIndices{ 0, static_cast( cast(index).getValue().getZExtValue())}; - Value subslotPtrInOther = builder.create( + Value subslotPtrInOther = LLVM::GEPOp::create(builder, op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType, isDst ? 
op.getSrc() : op.getDst(), gepIndices); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index aaf6b0593c2e6..d95b146d27e5c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 9671afd52fa77..0bb8ecf98c533 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp index bd9d3528ceb74..b1676b4e58c39 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp index 9e497829ba723..fbedc714c200b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/DialectImplementation.h" +#include 
"mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MathExtras.h" diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp index 6fbb0d24826d0..1b1268da15646 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp @@ -28,13 +28,13 @@ static void addComdat(LLVM::LLVMFuncOp &op, OpBuilder &builder, PatternRewriter::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); comdatOp = - builder.create(module.getLoc(), comdatName); + mlir::LLVM::ComdatOp::create(builder, module.getLoc(), comdatName); symbolTable.insert(comdatOp); } PatternRewriter::InsertionGuard guard(builder); builder.setInsertionPointToStart(&comdatOp.getBody().back()); - auto selectorOp = builder.create( + auto selectorOp = mlir::LLVM::ComdatSelectorOp::create(builder, comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any); op.setComdatAttr(mlir::SymbolRefAttr::get( builder.getContext(), comdatName, diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 7f3afffc9645e..2ecef4b4c8a51 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -113,18 +113,18 @@ handleInlinedAllocas(Operation *call, // scope if some are already present in the body of the caller. This is not // invalid IR, but LLVM cleans these up in InstCombineCalls.cpp, along with // other cases where the stacksave/stackrestore is redundant. 
- stackPtr = builder.create( + stackPtr = LLVM::StackSaveOp::create(builder, call->getLoc(), LLVM::LLVMPointerType::get(call->getContext())); } builder.setInsertionPointToStart(callerEntryBlock); for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { - auto newConstant = builder.create( + auto newConstant = LLVM::ConstantOp::create(builder, allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize); // Insert a lifetime start intrinsic where the alloca was before moving it. if (shouldInsertLifetime) { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(allocaOp); - builder.create( + LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), allocaOp.getResult()); } @@ -139,10 +139,10 @@ handleInlinedAllocas(Operation *call, continue; builder.setInsertionPoint(block.getTerminator()); if (hasDynamicAlloca) - builder.create(call->getLoc(), stackPtr); + LLVM::StackRestoreOp::create(builder, call->getLoc(), stackPtr); for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { if (shouldInsertLifetime) - builder.create( + LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), allocaOp.getResult()); } @@ -603,15 +603,15 @@ static Value handleByValArgumentInit(OpBuilder &builder, Location loc, OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = &(*argument.getParentRegion()->begin()); builder.setInsertionPointToStart(entryBlock); - Value one = builder.create(loc, builder.getI64Type(), + Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), builder.getI64IntegerAttr(1)); - allocaOp = builder.create( + allocaOp = LLVM::AllocaOp::create(builder, loc, argument.getType(), elementType, one, targetAlignment); } // Copy the pointee to the newly allocated value. 
- Value copySize = builder.create( + Value copySize = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize)); - builder.create(loc, allocaOp, argument, copySize, + LLVM::MemcpyOp::create(builder, loc, allocaOp, argument, copySize, /*isVolatile=*/false); return allocaOp; } @@ -747,7 +747,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { // Replace the return with a branch to the dest. OpBuilder builder(op); - builder.create(op->getLoc(), returnOp.getOperands(), newDest); + LLVM::BrOp::create(builder, op->getLoc(), returnOp.getOperands(), newDest); op->erase(); } @@ -801,7 +801,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { // and is extremely unlikely to exist in the code prior to inlining, using // this to communicate between this method and `processInlinedCallBlocks`. // TODO: Fix this by refactoring the inliner interface. - auto copyOp = builder.create(call->getLoc(), argument); + auto copyOp = LLVM::SSACopyOp::create(builder, call->getLoc(), argument); if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName())) copyOp->setDiscardableAttr( builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()), diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp index 8dd0c28d98522..32fe2edd6ea76 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp @@ -59,7 +59,7 @@ static void ensureDistinctSuccessors(Block &bb) { terminator->setSuccessor(dummyBlock, position); for (BlockArgument arg : successor.first->getArguments()) dummyBlock->addArgument(arg.getType(), arg.getLoc()); - builder.create(terminator->getLoc(), + LLVM::BrOp::create(builder, terminator->getLoc(), dummyBlock->getArguments(), successor.first); } } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp 
b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp index 8db32ec1526c4..de6c3043bbcdd 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp @@ -59,32 +59,32 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, Type i32Type = rewriter.getI32Type(); // Extend lhs and rhs to fp32. - Value lhs = rewriter.create(loc, f32Type, op.getLhs()); - Value rhs = rewriter.create(loc, f32Type, op.getRhs()); + Value lhs = LLVM::FPExtOp::create(rewriter, loc, f32Type, op.getLhs()); + Value rhs = LLVM::FPExtOp::create(rewriter, loc, f32Type, op.getRhs()); // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. - Value rcp = rewriter.create(loc, f32Type, rhs); - Value approx = rewriter.create(loc, lhs, rcp); + Value rcp = NVVM::RcpApproxFtzF32Op::create(rewriter, loc, f32Type, rhs); + Value approx = LLVM::FMulOp::create(rewriter, loc, lhs, rcp); // Refine the approximation with one Newton iteration: // float refined = approx + (lhs - approx * rhs) * rcp; - Value err = rewriter.create( - loc, approx, rewriter.create(loc, rhs), lhs); - Value refined = rewriter.create(loc, err, rcp, approx); + Value err = LLVM::FMAOp::create(rewriter, + loc, approx, LLVM::FNegOp::create(rewriter, loc, rhs), lhs); + Value refined = LLVM::FMAOp::create(rewriter, loc, err, rcp, approx); // Use refined value if approx is normal (exponent neither all 0 or all 1). 
- Value mask = rewriter.create( + Value mask = LLVM::ConstantOp::create(rewriter, loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); - Value cast = rewriter.create(loc, i32Type, approx); - Value exp = rewriter.create(loc, i32Type, cast, mask); - Value zero = rewriter.create( + Value cast = LLVM::BitcastOp::create(rewriter, loc, i32Type, approx); + Value exp = LLVM::AndOp::create(rewriter, loc, i32Type, cast, mask); + Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, rewriter.getUI32IntegerAttr(0)); - Value pred = rewriter.create( + Value pred = LLVM::OrOp::create(rewriter, loc, - rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, zero), - rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, mask)); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, exp, zero), + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, exp, mask)); Value result = - rewriter.create(loc, f32Type, pred, approx, refined); + LLVM::SelectOp::create(rewriter, loc, f32Type, pred, approx, refined); // Replace with trucation back to fp16. 
rewriter.replaceOpWithNewOp(op, op.getType(), result); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3aa6ac3ea0918..51cc1232df891 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -62,10 +62,10 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, return getAsOpFoldResult( TypeSwitch(v.getType()) .Case([&](RankedTensorType t) -> Value { - return builder.create(loc, v, dim); + return tensor::DimOp::create(builder, loc, v, dim); }) .Case([&](MemRefType t) -> Value { - return builder.create(loc, v, dim); + return memref::DimOp::create(builder, loc, v, dim); })); } @@ -77,11 +77,11 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source, ArrayRef strides) { return TypeSwitch(source.getType()) .Case([&](RankedTensorType t) -> Operation * { - return b.create(loc, source, offsets, sizes, + return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes, strides); }) .Case([&](MemRefType type) -> Operation * { - return b.create(loc, source, offsets, sizes, + return memref::SubViewOp::create(b, loc, source, offsets, sizes, strides); }) .Default([&](Type t) -> Operation * { return nullptr; }); @@ -453,35 +453,35 @@ class RegionBuilderHelper { builder.setInsertionPointToEnd(&block); switch (unaryFn) { case UnaryFn::exp: - return builder.create(arg.getLoc(), arg); + return math::ExpOp::create(builder, arg.getLoc(), arg); case UnaryFn::log: - return builder.create(arg.getLoc(), arg); + return math::LogOp::create(builder, arg.getLoc(), arg); case UnaryFn::abs: - return builder.create(arg.getLoc(), arg); + return math::AbsFOp::create(builder, arg.getLoc(), arg); case UnaryFn::ceil: - return builder.create(arg.getLoc(), arg); + return math::CeilOp::create(builder, arg.getLoc(), arg); case UnaryFn::floor: - return builder.create(arg.getLoc(), arg); + return math::FloorOp::create(builder, arg.getLoc(), arg); case UnaryFn::negf: - 
return builder.create(arg.getLoc(), arg); + return arith::NegFOp::create(builder, arg.getLoc(), arg); case UnaryFn::reciprocal: { Attribute oneAttr = builder.getOneAttr(arg.getType()); - auto one = builder.create(arg.getLoc(), + auto one = arith::ConstantOp::create(builder, arg.getLoc(), ::cast(oneAttr)); - return builder.create(arg.getLoc(), one, arg); + return arith::DivFOp::create(builder, arg.getLoc(), one, arg); } case UnaryFn::round: - return builder.create(arg.getLoc(), arg); + return math::RoundOp::create(builder, arg.getLoc(), arg); case UnaryFn::sqrt: - return builder.create(arg.getLoc(), arg); + return math::SqrtOp::create(builder, arg.getLoc(), arg); case UnaryFn::rsqrt: - return builder.create(arg.getLoc(), arg); + return math::RsqrtOp::create(builder, arg.getLoc(), arg); case UnaryFn::square: - return builder.create(arg.getLoc(), arg, arg); + return arith::MulFOp::create(builder, arg.getLoc(), arg, arg); case UnaryFn::tanh: - return builder.create(arg.getLoc(), arg); + return math::TanhOp::create(builder, arg.getLoc(), arg); case UnaryFn::erf: - return builder.create(arg.getLoc(), arg); + return math::ErfOp::create(builder, arg.getLoc(), arg); } if (emitError) { emitError() << "unsupported unary function"; @@ -516,17 +516,17 @@ class RegionBuilderHelper { switch (binaryFn) { case BinaryFn::add: if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); + return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::sub: if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); + return complex::SubOp::create(builder, 
arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) { if (emitError) { emitError() << "unsupported operation: sub with bools"; @@ -534,20 +534,20 @@ class RegionBuilderHelper { } llvm_unreachable("unsupported operation: sub with bools"); } - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::mul: if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); + return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::div: if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); + return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) { if (emitError) { emitError() << "unsupported operation: div with bools"; @@ -555,7 +555,7 @@ class RegionBuilderHelper { } llvm_unreachable("unsupported operation: div with bools"); } - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::div_unsigned: if (!allInteger || allBool) { if (emitError) { @@ -564,30 +564,30 @@ class RegionBuilderHelper { } llvm_unreachable("unsupported operation: unsigned div not on uint"); } - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1); case 
BinaryFn::max_signed: assert(!allComplex); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::min_signed: assert(!allComplex); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::max_unsigned: assert(!allComplex); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::min_unsigned: assert(!allComplex); if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::powf: assert(allFloatingPoint); - return builder.create(arg0.getLoc(), arg0, arg1); + return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1); } if (emitError) { emitError() << "unsupported binary function"; @@ -610,7 +610,7 @@ class RegionBuilderHelper { case TernaryFn::select: if (!headBool && !(tailFloatingPoint || tailInteger)) llvm_unreachable("unsupported non numeric type"); - return builder.create(arg0.getLoc(), arg0, arg1, arg2); + return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2); } if (emitError) { emitError() << "unsupported ternary function"; @@ -639,7 +639,7 @@ class RegionBuilderHelper { OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); 
Location loc = builder.getUnknownLoc(); - builder.create(loc, values); + YieldOp::create(builder, loc, values); } Value constant(const std::string &value) { @@ -647,13 +647,13 @@ class RegionBuilderHelper { builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); - return builder.create(loc, ::cast(valueAttr)); + return arith::ConstantOp::create(builder, loc, ::cast(valueAttr)); } Value index(int64_t dim) { OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); - return builder.create(builder.getUnknownLoc(), dim); + return IndexOp::create(builder, builder.getUnknownLoc(), dim); } Type getIntegerType(unsigned width) { @@ -749,12 +749,12 @@ struct FoldFillWithTensorReshape : OpRewritePattern { TensorReshapeOp newInit; if constexpr (std::is_same::value) { - newInit = rewriter.create( + newInit = TensorReshapeOp::create(rewriter, loc, reshapeOp.getResultType(), oldFill.output(), reshapeOp.getReassociation(), reshapeOp.getOutputShape(), reshapeOp.getStaticOutputShape()); } else { - newInit = rewriter.create(loc, reshapeOp.getResultType(), + newInit = TensorReshapeOp::create(rewriter, loc, reshapeOp.getResultType(), oldFill.output(), reshapeOp.getReassociation()); } @@ -786,7 +786,7 @@ struct FoldFillWithPad final : public OpRewritePattern { return rewriter.notifyMatchFailure( padOp, "failed to reify tensor.pad op result shape"); - auto emptyTensor = rewriter.create( + auto emptyTensor = tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(), padOp.getResultType().getElementType()); Value replacement = @@ -795,7 +795,7 @@ struct FoldFillWithPad final : public OpRewritePattern { ValueRange{emptyTensor}) .getResult(0); if (replacement.getType() != padOp.getResultType()) { - replacement = rewriter.create( + replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(), padOp.getResultType(), replacement); } rewriter.replaceOp(padOp, 
replacement); @@ -889,7 +889,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern { for (int i = 0, e = srcPadType.getRank(); i < e; ++i) { if (srcPadType.isDynamicDim(i)) { newSizes.push_back( - rewriter.create(loc, srcPadOp.getSource(), i) + tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i) .getResult()); } else { newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i))); @@ -942,7 +942,7 @@ static FailureOr foldFillPackIntoFillOp(RewriterBase &rewriter, if (!packOpDest.hasOneUse()) return failure(); - return rewriter.create(packOp.getLoc(), fillOp.getInputs(), + return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(), packOp.getDest()); } @@ -1042,7 +1042,7 @@ struct FoldConcatsOfFill : public OpRewritePattern { concatOp, "not all operands are defined by a compatible fill op"); } - Value outsConcat = rewriter.create( + Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getDim(), allOuts); rewriter.replaceOpWithNewOp( concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat); @@ -1407,13 +1407,13 @@ struct EraseIdentityLinalgOp : public OpRewritePattern { // TODO: unify the two ops? 
if (sparse_tensor::getSparseTensorEncoding(returnType) || sparse_tensor::getSparseTensorEncoding(resultType)) - returnedArg = rewriter.create( + returnedArg = sparse_tensor::ConvertOp::create(rewriter, linalgOp.getLoc(), resultType, returnedArg); else { if (!tensor::CastOp::areCastCompatible(returnedArg.getType(), resultType)) return failure(); - returnedArg = rewriter.create( + returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(), resultType, returnedArg); } } @@ -1528,7 +1528,7 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, TypeRange{llvm::cast(result.operands.back().getType()) .getElementType()}, payloadOpAttrs); - b.create(result.location, payloadOp->getResults()); + YieldOp::create(b, result.location, payloadOp->getResults()); } ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1945,7 +1945,7 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc, buildGenericRegion(builder, loc, region, inputs, outputs, [](OpBuilder &b, Location loc, ValueRange args) { if (!args.empty()) - b.create(loc, args[0]); + linalg::YieldOp::create(b, loc, args[0]); }); } @@ -2138,7 +2138,7 @@ struct SwapTransposeWithBroadcast : OpRewritePattern { unsigned inputRank = broadcastInputTy.getRank(); for (unsigned i = 0; i < inputRank; ++i) { if (broadcastInputTy.isDynamicDim(i)) { - dims.push_back(rewriter.create(loc, broadcastInput, i) + dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i) ->getResult(0)); } else { dims.push_back(IntegerAttr::get(IndexType::get(ctx), @@ -2147,7 +2147,7 @@ struct SwapTransposeWithBroadcast : OpRewritePattern { } SmallVector transposeResultShapes = applyPermutation(dims, resultPerms); - Value transposeInit = rewriter.create( + Value transposeInit = tensor::EmptyOp::create(rewriter, transposeOp.getLoc(), transposeResultShapes, broadcastInputTy.getElementType()); @@ -2547,7 +2547,7 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern { // continue to 
propagate as far up the stack as it can go. OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); Value newOperand = - rewriter.create(loc, resultType, outOperand->get()); + tensor::CastOp::create(rewriter, loc, resultType, outOperand->get()); SmallVector newOperands = linalgOp.getDpsInputs(); SmallVector outputOperands(linalgOp.getDpsInits().begin(), linalgOp.getDpsInits().end()); @@ -2560,7 +2560,7 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern { Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); // Create a tensor.cast operation back to the original type. - Value castBack = rewriter.create( + Value castBack = tensor::CastOp::create(rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber)); SmallVector results(newOp->result_begin(), newOp->result_end()); @@ -2653,7 +2653,7 @@ static void createNewOperandWithStaticSizes( changeNeeded = true; // Get the new operand value given its size and element type by // casting it. - Value newOperand = rewriter.create(loc, resultType, src); + Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src); unsigned index = opOperand->getOperandNumber(); newOperands[index] = newOperand; } @@ -2718,7 +2718,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern { Type oldType = oldResult.getType(); replacements.push_back( (newType != oldType) - ? rewriter.create(loc, oldType, newResult) + ? 
tensor::CastOp::create(rewriter, loc, oldType, newResult) : newResult); } rewriter.replaceOp(linalgOp, replacements); @@ -2756,8 +2756,8 @@ SmallVector SoftmaxOp::getIterationDomain(OpBuilder &builder) { int64_t operandRank = getInputOperandRank(); SmallVector loopBounds(operandRank); Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); Value source = getInput(); for (auto dim : llvm::seq(0, operandRank)) { loopBounds[dim].offset = zero; @@ -2924,11 +2924,11 @@ static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, "We should have two maps: 1 for the input, 1 for the output"); assert(indexingMaps[0].isIdentity() && "input map should be identity"); - auto genericOp = builder.create( + auto genericOp = linalg::GenericOp::create(builder, loc, output.getType(), input, output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value result = b.create(loc, args[0], args[1]); - b.create(loc, result); + Value result = T::create(b, loc, args[0], args[1]); + linalg::YieldOp::create(b, loc, result); }); return genericOp.getResult(0); } @@ -2947,12 +2947,12 @@ static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, assert(indexingMaps[0].isIdentity() && "input map should be identity"); // Add the affine map for the output argument. 
indexingMaps.push_back(indexingMaps[0]); - auto genericOp = builder.create( + auto genericOp = linalg::GenericOp::create(builder, loc, input.getType(), ValueRange{input, max}, output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value diff = b.create(loc, args[0], args[1]); - Value result = b.create(loc, diff); - b.create(loc, result); + Value diff = arith::SubFOp::create(b, loc, args[0], args[1]); + Value result = math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, result); }); return genericOp.getResult(0); } @@ -2974,12 +2974,12 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, assert(indexingMaps[0].isIdentity() && "Numerator map should be identity"); // Add the affine map for the output tensor. indexingMaps.push_back(indexingMaps[0]); - auto genericOp = builder.create( + auto genericOp = linalg::GenericOp::create(builder, loc, numerator.getType(), ValueRange{numerator, denominator}, output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value result = b.create(loc, args[0], args[1]); - b.create(loc, result); + Value result = arith::DivFOp::create(b, loc, args[0], args[1]); + linalg::YieldOp::create(b, loc, result); }); return genericOp.getResult(0); } @@ -3015,12 +3015,12 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { Value output = getOutput(); dims.erase(dims.begin() + reductionDim); // Step 1: Compute max along dim. 
- Value outputReduce = b.create(loc, dims, elementType); + Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType); Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf, elementType, b, loc, /*useOnlyFiniteValue=*/true); Value neutralForMaxFInit = - b.create(loc, Value{neutralForMaxF}, outputReduce) + linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce) .result(); Value max = reduce(b, loc, input, neutralForMaxFInit, reductionDim); @@ -3032,7 +3032,7 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc, /*useOnlyFiniteValue=*/true); Value zeroInit = - b.create(loc, Value{zero}, outputReduce).result(); + linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result(); Value denominator = reduce(b, loc, numerator, zeroInit, reductionDim); @@ -3153,7 +3153,7 @@ FailureOr WinogradFilterTransformOp::getTiledImplementation( int64_t filterRank = getFilterOperandRank(); SmallVector filterStrides(filterRank, oneAttr); Location loc = getLoc(); - auto filterSlice = builder.create( + auto filterSlice = tensor::ExtractSliceOp::create(builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides); tiledOperands.emplace_back(filterSlice); @@ -3164,7 +3164,7 @@ FailureOr WinogradFilterTransformOp::getTiledImplementation( int64_t outputRank = getOutputOperandRank(); SmallVector outputStrides(outputRank, oneAttr); - auto outputSlice = builder.create( + auto outputSlice = tensor::ExtractSliceOp::create(builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides); tiledOperands.emplace_back(outputSlice); @@ -3333,7 +3333,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]}); int64_t inputRank = getInputOperandRank(); SmallVector inputStrides(inputRank, oneAttr); - auto inputSlice = builder.create( + auto inputSlice = 
tensor::ExtractSliceOp::create(builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides); tiledOperands.emplace_back(inputSlice); @@ -3344,7 +3344,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, int64_t outputRank = getOutputOperandRank(); SmallVector outputStrides(outputRank, oneAttr); - auto outputSlice = builder.create( + auto outputSlice = tensor::ExtractSliceOp::create(builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides); tiledOperands.emplace_back(outputSlice); @@ -3504,7 +3504,7 @@ FailureOr WinogradOutputTransformOp::getTiledImplementation( sizes[getValueFDim()]}); int64_t valueRank = getValueOperandRank(); SmallVector sliceStrides(valueRank, oneAttr); - auto valueSlice = builder.create( + auto valueSlice = tensor::ExtractSliceOp::create(builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides); tiledOperands.emplace_back(valueSlice); @@ -3515,7 +3515,7 @@ FailureOr WinogradOutputTransformOp::getTiledImplementation( int64_t outputRank = getOutputOperandRank(); SmallVector strides(outputRank, oneAttr); - auto outputSlice = builder.create( + auto outputSlice = tensor::ExtractSliceOp::create(builder, loc, getOutput(), resultOffsets, resultSizes, strides); tiledOperands.emplace_back(outputSlice); @@ -4971,7 +4971,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, llvm::cast(source.getType()).getShape())) { if (ShapedType::isDynamic(value)) mixedSizes.push_back( - b.create(loc, source, index).getResult()); + tensor::DimOp::create(b, loc, source, index).getResult()); else mixedSizes.push_back(b.getIndexAttr(value)); } @@ -4985,7 +4985,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end()); auto elemType = llvm::cast(source.getType()).getElementType(); - return b.create(loc, mixedSizes, elemType); + return tensor::EmptyOp::create(b, loc, mixedSizes, elemType); } PackOp 
PackOp::createTransposedClone(OpBuilder &b, Location loc, @@ -4996,7 +4996,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, Value transposedDest = createDestinationTensor(b, loc, getSource(), metadata.innerTiles, metadata.innerDimsPos, metadata.outerDimsPerm); - return b.create(loc, getSource(), transposedDest, + return PackOp::create(b, loc, getSource(), transposedDest, metadata.innerDimsPos, metadata.innerTiles, getPaddingValue(), metadata.outerDimsPerm); } @@ -5138,7 +5138,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { if (srcShape != packOp.getSourceType().getShape()) { auto newSrcType = packOp.getSourceType().clone(srcShape); source = - rewriter.create(loc, newSrcType, packOp.getSource()); + tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); RankedTensorType originalResultType = packOp.getDestType(); @@ -5146,7 +5146,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); dest = - rewriter.create(loc, newDestType, packOp.getDest()); + tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest()); } rewriter.modifyOpInPlace(packOp, [&] { packOp.getSourceMutable().assign(source); @@ -5157,7 +5157,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { if (needUpdateDestType) { rewriter.setInsertionPointAfter(packOp); auto castOp = - rewriter.create(loc, originalResultType, packOp); + tensor::CastOp::create(rewriter, loc, originalResultType, packOp); rewriter.replaceAllUsesExcept(packOp, castOp, castOp); } return success(); @@ -5250,7 +5250,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern { // TODO: Strictly speaking, discardable attributes should be _discarded_ at // this point. However, in practice, we use them for things that we'd like // to preserve. Implement a better abstraction. 
- PackOp newOp = rewriter.create( + PackOp newOp = PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); @@ -5259,7 +5259,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern { Value oldResult = op.getResult(); Value newResult = newOp.getResult(); Value replacement = (newResult.getType() != oldResult.getType()) - ? rewriter.create( + ? tensor::CastOp::create(rewriter, op->getLoc(), oldResult.getType(), newResult) : newResult; @@ -5358,7 +5358,7 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc, for (auto i : llvm::seq(0, srcType.getRank() - innerTileSizes.size())) { if (srcType.isDynamicDim(i)) - mixedSizes.push_back(b.create(loc, source, i).getResult()); + mixedSizes.push_back(tensor::DimOp::create(b, loc, source, i).getResult()); else mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i))); } @@ -5371,7 +5371,7 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc, mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize); auto elemType = srcType.getElementType(); - return b.create(loc, mixedSizes, elemType); + return tensor::EmptyOp::create(b, loc, mixedSizes, elemType); } UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc, @@ -5380,7 +5380,7 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc, ArrayRef outerPermutation) { PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp( *this, innerPermutation, outerPermutation); - return b.create(loc, transposedSource, getDest(), + return UnPackOp::create(b, loc, transposedSource, getDest(), metadata.innerDimsPos, metadata.innerTiles, metadata.outerDimsPerm); } @@ -5454,7 +5454,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, extractSliceUser.getResultType().getRank()) { OpBuilder::InsertionGuard g(rewriter); 
rewriter.setInsertionPoint(unPackOp); - auto newDest = rewriter.create( + auto newDest = tensor::ExtractSliceOp::create(rewriter, unPackOp->getLoc(), unPackOp.getDest(), extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(), extractSliceUser.getMixedStrides()); @@ -5474,16 +5474,16 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, Value source = unPackOp.getSource(); if (srcShape != unPackOp.getSourceType().getShape()) { auto newSrcType = unPackOp.getSourceType().clone(srcShape); - source = rewriter.create(loc, newSrcType, + source = tensor::CastOp::create(rewriter, loc, newSrcType, unPackOp.getSource()); } Value dest = unPackOp.getDest(); if (destShape != unPackOp.getDestType().getShape()) { auto newDestType = unPackOp.getDestType().clone(destShape); dest = - rewriter.create(loc, newDestType, unPackOp.getDest()); + tensor::CastOp::create(rewriter, loc, newDestType, unPackOp.getDest()); } - Value newOp = rewriter.create( + Value newOp = UnPackOp::create(rewriter, loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm()); rewriter.replaceOpWithNewOp( @@ -5542,7 +5542,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { // TODO: Strictly speaking, discardable attributes should be _discarded_ at // this point. However, in practice, we use them for things that we'd like // to preserve. Implement a better abstraction. - UnPackOp newOp = rewriter.create( + UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getOuterDimsPerm()); newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); @@ -5551,7 +5551,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { Value oldResult = op.getResult(); Value newResult = newOp.getResult(); Value replacement = (newResult.getType() != oldResult.getType()) - ? rewriter.create( + ? 
tensor::CastOp::create(rewriter, op->getLoc(), oldResult.getType(), newResult) : newResult; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 5d5f9de465561..ba03d09e5072e 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -672,7 +672,7 @@ static Operation *replaceForAllWithNewSignature( newOuts.push_back(outputs[resultNumber]); // Create new scf.forall op - auto newforallOp = rewriter.create( + auto newforallOp = scf::ForallOp::create(rewriter, loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, forallOp.getMapping()); rewriter.eraseBlock(newforallOp.getBody()); @@ -699,7 +699,7 @@ static Operation *replaceForAllWithNewSignature( Value src = tileAndFuseResult.tiledValues[0]; Value dst = newforallOp.getRegionIterArgs().back(); SmallVector strides(offsets.size(), rewriter.getIndexAttr(1)); - rewriter.create(firstYieldOp->getLoc(), src, + tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src, dst, offsets, sizes, strides); for (auto result : llvm::enumerate(forallOp.getResults())) { @@ -3410,12 +3410,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { if (auto attr = llvm::dyn_cast_if_present(ofr)) { if (scalableSizes[ofrIdx]) { - auto val = b.create( + auto val = arith::ConstantIndexOp::create(b, getLoc(), cast(attr).getInt()); Value vscale = - b.create(getLoc(), b.getIndexType()); + vector::VectorScaleOp::create(b, getLoc(), b.getIndexType()); sizes.push_back( - b.create(getLoc(), val, vscale).getResult()); + arith::MulIOp::create(b, getLoc(), val, vscale).getResult()); } else { sizes.push_back(attr); } @@ -3626,7 +3626,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, SmallVector 
normalizedSteps(normalizedUbs.size(), rewriter.getIndexAttr(1)); - auto normalizedForallOp = rewriter.create( + auto normalizedForallOp = scf::ForallOp::create(rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(), loop.getMapping(), [](OpBuilder &, Location, ValueRange) {}); @@ -4128,7 +4128,7 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, target->template getParentOfType()); } - Value extracted = rewriter.create( + Value extracted = tensor::ExtractSliceOp::create(rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(), target.getMixedSizes(), target.getMixedStrides()); Value copied = rewriter diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index 1f6d96ca0f81f..2cb8ad367f93c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -184,7 +184,7 @@ struct SoftmaxOpInterface getBuffer(rewriter, softmaxOp.getOutput(), options, state); if (failed(outputBuffer)) return failure(); - rewriter.create(softmaxOp.getLoc(), + linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(), /*result=*/TypeRange(), *inputBuffer, *outputBuffer, softmaxOp.getDimension()); replaceOpWithBufferizedValues(rewriter, op, *outputBuffer); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp index a7732b939e70d..61c779850366f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -30,10 +30,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { if (isa(x.getType())) - return builder.create(loc, x, y); + return arith::AddIOp::create(builder, loc, x, y); if (isa(x.getType())) - 
return builder.create(loc, x, y); - return builder.create(loc, x, y); + return complex::AddOp::create(builder, loc, x, y); + return arith::AddFOp::create(builder, loc, x, y); } static Value createMul(Location loc, Value x, Value y, Type accType, @@ -44,10 +44,10 @@ static Value createMul(Location loc, Value x, Value y, Type accType, Value yConvert = convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false); if (isa(accType)) - return builder.create(loc, xConvert, yConvert); + return complex::MulOp::create(builder, loc, xConvert, yConvert); if (isa(accType)) - return builder.create(loc, xConvert, yConvert); - return builder.create(loc, xConvert, yConvert); + return arith::MulIOp::create(builder, loc, xConvert, yConvert); + return arith::MulFOp::create(builder, loc, xConvert, yConvert); } // Delinearizes the given composite `index` by the basis specified in `factors`. @@ -56,7 +56,7 @@ static SmallVector unrollIndex(OpBuilder &b, Location loc, Value index, assert(!factors.empty() && "empty factor list"); SmallVector basis; for (int64_t f : factors) - basis.push_back(b.create(loc, b.getIndexAttr(f))); + basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f))); FailureOr> multiIndex = affine::delinearizeIndex(b, loc, index, basis); assert(!failed(multiIndex) && "Failed to linearize img2col index"); @@ -115,17 +115,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { SmallVector filterReassocIndices = {{0, 1, 2}, {3}}; auto reshapedFilterType = RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType()); - Value reshapedFilter = rewriter.create( + Value reshapedFilter = tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterType, filter, filterReassocIndices); SmallVector outputReassocIndices = {{0}, {1, 2}, {3}}; RankedTensorType reshapedOutputType = RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); - Value reshapedOutput = rewriter.create( + Value reshapedOutput = 
tensor::CollapseShapeOp::create(rewriter, loc, reshapedOutputType, output, outputReassocIndices); SmallVector colTensorShape = {n, oh * ow, fh * fw * ic}; - Value colTensor = rewriter.create( + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, inputType.getElementType()); // Convert the input to a (BMK) column tensor. @@ -138,15 +138,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { SmallVector img2colIndexingMaps = { AffineMap::getMultiDimIdentityMap(nloops, context)}; - auto img2ColTensor = rewriter.create( + auto img2ColTensor = linalg::GenericOp::create(rewriter, loc, colTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = nestedBuilder.create(loc, 0); - Value mIndex = nestedBuilder.create(loc, 1); - Value kIndex = nestedBuilder.create(loc, 2); + Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); + Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); // Recover the original iteration indices from the problem/input sizes. 
SmallVector mIndices = unrollIndex( @@ -170,9 +170,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] SmallVector extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = nestedBuilder.create( + Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, extractionIndices); - nestedBuilder.create(nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); }); // Because the filter does not share the same batch dimension, @@ -187,7 +187,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { SmallVector genericIterators = {parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, reshapedOutputType, /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter}, /*outputs=*/ValueRange{reshapedOutput}, @@ -196,11 +196,11 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { Value mul = createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); - nestedBuilder.create(nestedLoc, add); + linalg::YieldOp::create(nestedBuilder, nestedLoc, add); }); Value result = genericOp.getResults().front(); - auto reshapedResult = rewriter.create( + auto reshapedResult = tensor::ExpandShapeOp::create(rewriter, loc, outputType, result, outputReassocIndices); rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); @@ -244,7 +244,7 @@ rewriteInIm2Col(RewriterBase &rewriter, SmallVector targetShape = llvm::to_vector<4>(llvm::map_range( indices, [&](int64_t index) -> int64_t { return inputShape[index]; })); - Value outputTensor = rewriter.create( + Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape, operandTensorType.getElementType()); SmallVector loopAttributeTypes( @@ -255,12 +255,12 @@ rewriteInIm2Col(RewriterBase &rewriter, 
AffineMap::get(nloops, 0, exprs, rewriter.getContext())), AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; - auto transposedOp = rewriter.create( + auto transposedOp = linalg::GenericOp::create(rewriter, loc, outputTensor.getType(), /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps, loopAttributeTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); return transposedOp.getResult(0); @@ -307,15 +307,15 @@ rewriteInIm2Col(RewriterBase &rewriter, AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; - Value colTensor = rewriter.create( + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, inputType.getElementType()); - auto img2ColTensor = rewriter.create( + auto img2ColTensor = linalg::GenericOp::create(rewriter, loc, colTensor.getType(), /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps, loopAttributeTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); SmallVector img2ColTensorReassocIndices = { @@ -331,16 +331,16 @@ rewriteInIm2Col(RewriterBase &rewriter, auto reshapedOutputTensorType = RankedTensorType::get({n * c, oh * ow}, outputType.getElementType()); - Value reshapedImg2ColTensor = rewriter.create( + Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), img2ColTensorReassocIndices); - Value reshapedFilterTensor = rewriter.create( + Value reshapedFilterTensor = tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType, filterT, filterReassociationIndice); - Value reshapedoutputTensor = rewriter.create( + Value reshapedoutputTensor = tensor::CollapseShapeOp::create(rewriter, loc, 
reshapedOutputTensorType, transposedOutputTensor, outputReassociationIndice); - auto batchMatVecResult = rewriter.create( + auto batchMatVecResult = linalg::BatchMatvecOp::create(rewriter, loc, TypeRange{reshapedoutputTensor.getType()}, ValueRange{reshapedImg2ColTensor, reshapedFilterTensor}, ValueRange{reshapedoutputTensor}); @@ -348,7 +348,7 @@ rewriteInIm2Col(RewriterBase &rewriter, SmallVector batchMatVecReassociationIndice = {{0, 1}, {2, 3}}; - auto batchMatVecResultReshaped = rewriter.create( + auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(rewriter, loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0), batchMatVecReassociationIndice); @@ -400,18 +400,18 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { SmallVector filterReassocIndices = {{0}, {1, 2, 3}}; auto reshapedFilterType = RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType()); - Value reshapedFilter = rewriter.create( + Value reshapedFilter = tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterType, filter, filterReassocIndices); SmallVector outputReassocIndices = {{0}, {1}, {2, 3}}; auto reshapedOutputType = RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType()); - Value reshapedOutput = rewriter.create( + Value reshapedOutput = tensor::CollapseShapeOp::create(rewriter, loc, reshapedOutputType, output, outputReassocIndices); // Convert the input to a (BKN) tensor. 
SmallVector colTensorShape = {n, ic * fh * fw, oh * ow}; - Value colTensor = rewriter.create( + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, inputType.getElementType()); auto nloops = colTensorShape.size(); @@ -423,15 +423,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { SmallVector img2colIndexingMaps = { AffineMap::getMultiDimIdentityMap(nloops, context)}; - auto img2ColTensor = rewriter.create( + auto img2ColTensor = linalg::GenericOp::create(rewriter, loc, colTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = nestedBuilder.create(loc, 0); - Value kIndex = nestedBuilder.create(loc, 1); - Value nIndex = nestedBuilder.create(loc, 2); + Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); + Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); // Recover the original iteration indices from the problem/input sizes. 
SmallVector kIndices = unrollIndex( @@ -455,9 +455,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] SmallVector extractionIndices{bIndex, icIndex, hIndex, wIndex}; - Value inputVal = nestedBuilder.create( + Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, extractionIndices); - nestedBuilder.create(nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); }); // Because the filter does not share the same batch dimension, @@ -471,7 +471,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context); SmallVector genericIterators = {parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, reshapedOutputType, /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)}, /*outputs=*/ValueRange{reshapedOutput}, @@ -480,11 +480,11 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { Value mul = createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); - nestedBuilder.create(nestedLoc, add); + linalg::YieldOp::create(nestedBuilder, nestedLoc, add); }); Value result = genericOp.getResults().front(); - auto reshapedResult = rewriter.create( + auto reshapedResult = tensor::ExpandShapeOp::create(rewriter, loc, outputType, result, outputReassocIndices); rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); @@ -535,17 +535,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { SmallVector filterReassocIndices = {{0}, {1, 2, 3}}; auto reshapedFilterType = RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType()); - Value reshapedFilter = rewriter.create( + Value reshapedFilter = tensor::CollapseShapeOp::create(rewriter, loc, 
reshapedFilterType, filter, filterReassocIndices); SmallVector outputReassocIndices = {{0}, {1, 2}, {3}}; RankedTensorType reshapedOutputType = RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); - Value reshapedOutput = rewriter.create( + Value reshapedOutput = tensor::CollapseShapeOp::create(rewriter, loc, reshapedOutputType, output, outputReassocIndices); SmallVector colTensorShape = {n, oh * ow, fh * fw * ic}; - Value colTensor = rewriter.create( + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, inputType.getElementType()); // Convert the input to a (BMK) column tensor. @@ -558,15 +558,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { SmallVector img2colIndexingMaps = { AffineMap::getMultiDimIdentityMap(nloops, context)}; - auto img2ColTensor = rewriter.create( + auto img2ColTensor = linalg::GenericOp::create(rewriter, loc, colTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = nestedBuilder.create(loc, 0); - Value mIndex = nestedBuilder.create(loc, 1); - Value kIndex = nestedBuilder.create(loc, 2); + Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); + Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); // Recover the original iteration indices from the problem/input sizes. 
SmallVector mIndices = unrollIndex( @@ -590,9 +590,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] SmallVector extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = nestedBuilder.create( + Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, extractionIndices); - nestedBuilder.create(nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); }); // Because we didn't transpose the filters we don't actually have a batched @@ -606,7 +606,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { SmallVector genericIterators = {parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, reshapedOutputType, /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter}, /*outputs=*/ValueRange{reshapedOutput}, @@ -615,11 +615,11 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { Value mul = createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); - nestedBuilder.create(nestedLoc, add); + linalg::YieldOp::create(nestedBuilder, nestedLoc, add); }); Value result = genericOp.getResults().front(); - auto reshapedResult = rewriter.create( + auto reshapedResult = tensor::ExpandShapeOp::create(rewriter, loc, outputType, result, outputReassocIndices); rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 39e2aac27e213..f21563c323c1b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -37,7 +37,7 @@ static Value createInserts(RewriterBase &rewriter, Location loc, int dim, if (dim == 
static_cast(shape.size()) - 1) { for (int i = 0; i < shape.back(); ++i) { indices.back() = constants[i]; - destination = rewriter.create(loc, *elementIt, + destination = tensor::InsertOp::create(rewriter, loc, *elementIt, destination, indices); ++elementIt; } @@ -65,7 +65,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, MaterializeInDestination: { // Note: This is the preferred way of memcpy'ing because no layout map // and/or memory space must be specified for the source. - auto materializeOp = b.create( + auto materializeOp = bufferization::MaterializeInDestinationOp::create(b, loc, tensorSource, memrefDest); materializeOp.setWritable(true); } break; @@ -73,19 +73,19 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, // TODO: Support custom memory space on source. // We do not know the layout map of the source yet, so use a fully dynamic // layout for best compatibility. - Value toBuffer = b.create( + Value toBuffer = bufferization::ToBufferOp::create(b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), tensorSource, /*readOnly=*/true); - b.create(loc, toBuffer, memrefDest); + memref::CopyOp::create(b, loc, toBuffer, memrefDest); } break; case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: { // TODO: Support custom memory space on source. // We do not know the layout map of the source yet, so use a fully dynamic // layout for best compatibility. 
- Value toBuffer = b.create( + Value toBuffer = bufferization::ToBufferOp::create(b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), tensorSource, /*readOnly=*/true); - b.create(loc, toBuffer, memrefDest); + linalg::CopyOp::create(b, loc, toBuffer, memrefDest); } break; }; } @@ -120,14 +120,14 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, ->materializeConstant(rewriter, constYieldedValue, yieldedValue.getType(), yieldedValue.getLoc()) ->getResult(0); - auto fillOp = rewriter.create(loc, ValueRange(fillValue), + auto fillOp = linalg::FillOp::create(rewriter, loc, ValueRange(fillValue), ValueRange(dest)); return fillOp; } if (invariantYieldedValue) { // Padding with an invariant value. - auto fillOp = rewriter.create(loc, ValueRange(yieldedValue), + auto fillOp = linalg::FillOp::create(rewriter, loc, ValueRange(yieldedValue), ValueRange(dest)); return fillOp; } @@ -137,7 +137,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, utils::IteratorType::parallel); SmallVector indexingMaps( 1, rewriter.getMultiDimIdentityMap(resultType.getRank())); - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, resultType, /*inputs=*/ValueRange(), /*outputs=*/ValueRange{dest}, /*indexingMaps=*/ indexingMaps, iteratorTypes); @@ -146,7 +146,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, rewriter.setInsertionPointToStart(body); SmallVector bbArgReplacements; for (int64_t i = 0; i < resultType.getRank(); ++i) - bbArgReplacements.push_back(rewriter.create(loc, i)); + bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i)); rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements); // Update terminator. 
@@ -179,8 +179,8 @@ static SmallVector reifyOrComputeDynamicSizes(OpBuilder &b, for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) dynSizes.push_back( - b.create(value.getLoc(), value, - b.create(value.getLoc(), i))); + DimOp::create(b, value.getLoc(), value, + arith::ConstantIndexOp::create(b, value.getLoc(), i))); } return dynSizes; } @@ -201,15 +201,15 @@ createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, Value alloc; if (options.allocOp == linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) { - alloc = rewriter.create(loc, memrefType, dynamicSizes); + alloc = memref::AllocOp::create(rewriter, loc, memrefType, dynamicSizes); if (options.emitDealloc) { // Place deallocation at the end of the block. rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator()); - rewriter.create(loc, alloc); + memref::DeallocOp::create(rewriter, loc, alloc); } } else if (options.allocOp == linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) { - alloc = rewriter.create(loc, memrefType, dynamicSizes); + alloc = memref::AllocaOp::create(rewriter, loc, memrefType, dynamicSizes); // No dealloc is needed. } @@ -243,13 +243,13 @@ Value linalg::bufferizeToAllocation( getMixedSizes(rewriter, loc, padOp.getSource()); SmallVector strides(padOp.getResultType().getRank(), rewriter.getIndexAttr(1)); - Value subview = rewriter.create( + Value subview = memref::SubViewOp::create(rewriter, loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides); createMemcpy(rewriter, loc, padOp.getSource(), subview, options); // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. 
- Value toTensorOp = rewriter.create( + Value toTensorOp = bufferization::ToTensorOp::create(rewriter, loc, padOp.getResult().getType(), alloc, /*restrict=*/true, /*writable=*/true); rewriter.replaceOp(padOp, toTensorOp); @@ -338,7 +338,7 @@ Value linalg::bufferizeToAllocation( // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. - Value toTensorOp = rewriter.create( + Value toTensorOp = bufferization::ToTensorOp::create(rewriter, loc, allocTensorOp.getResult().getType(), alloc, /*restrict=*/true, /*writable=*/true); rewriter.replaceOp(allocTensorOp, toTensorOp); @@ -354,7 +354,7 @@ FailureOr mlir::linalg::rewriteInDestinationPassingStyle( auto shape = tensorType.getShape(); // Create tensor.empty. - auto emptyOp = rewriter.create(loc, tensorType, ValueRange()); + auto emptyOp = EmptyOp::create(rewriter, loc, tensorType, ValueRange()); // Case: tensor. if (shape.empty()) { @@ -369,7 +369,7 @@ FailureOr mlir::linalg::rewriteInDestinationPassingStyle( SmallVector constants; constants.reserve(maxDim); for (int i = 0; i < maxDim; ++i) - constants.push_back(rewriter.create(loc, i)); + constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i)); // Traverse all elements and create tensor.insert ops. auto elementIt = fromElementsOp.getElements().begin(); @@ -395,14 +395,14 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, // Create tensor.empty. auto emptyOp = - rewriter.create(loc, tensorType, generateOp.getDynamicExtents()); + EmptyOp::create(rewriter, loc, tensorType, generateOp.getDynamicExtents()); // Create linalg.generic. 
SmallVector iteratorTypes(tensorType.getRank(), utils::IteratorType::parallel); SmallVector indexingMaps( 1, rewriter.getMultiDimIdentityMap(tensorType.getRank())); - auto genericOp = rewriter.create( + auto genericOp = linalg::GenericOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(), /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/ indexingMaps, iteratorTypes); @@ -411,7 +411,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, rewriter.setInsertionPointToStart(body); SmallVector bbArgReplacements; for (int64_t i = 0; i < tensorType.getRank(); ++i) - bbArgReplacements.push_back(rewriter.create(loc, i)); + bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i)); rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements); // Update terminator. @@ -450,13 +450,13 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) { using bufferization::AllocTensorOp; Value allocated = - rewriter.create(loc, resultType, dynamicSizes); + AllocTensorOp::create(rewriter, loc, resultType, dynamicSizes); auto copyOp = rewriter.replaceOpWithNewOp( padOp, padOp.getSource(), allocated); return copyOp.getOperation(); } - Value empty = rewriter.create(loc, resultType, dynamicSizes); + Value empty = EmptyOp::create(rewriter, loc, resultType, dynamicSizes); // Create linalg.fill or linalg.generic. 
Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty); rewriter.setInsertionPointAfter(fillOp); @@ -567,7 +567,7 @@ Value linalg::bufferizeToAllocation( createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); } rewriter.modifyOpInPlace(op, [&]() { - auto toTensorOp = rewriter.create( + auto toTensorOp = ToTensorOp::create(rewriter, op->getLoc(), operand->get().getType(), alloc); operand->set(toTensorOp); if (options.bufferizeDestinationOnly) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 70574903f7111..eb2c7756dc6ae 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -287,7 +287,7 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, auto empty = linalg::PackOp::createDestinationTensor( b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); - auto packedOperand = b.create( + auto packedOperand = linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, /*padding=*/std::nullopt, outerDimsPerm); return std::make_tuple(packedOperand, indexingMap); @@ -345,7 +345,7 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, indexingMaps.push_back(packedOutIndexingMap); - auto newGenericOp = rewriter.create( + auto newGenericOp = linalg::GenericOp::create(rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), @@ -457,7 +457,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, if (!packOpDest.hasOneUse()) return failure(); if (auto emptyOp = packOpDest.getDefiningOp()) { - packOpDest = rewriter.create( + packOpDest = tensor::EmptyOp::create(rewriter, 
genericOp->getLoc(), emptyOp.getMixedSizes(), emptyOp.getType().getElementType()); } else { @@ -562,7 +562,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { auto empty = linalg::PackOp::createDestinationTensor( rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, outerDimsPerm); - auto sourcePack = rewriter.create( + auto sourcePack = linalg::PackOp::create(rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, /*padding=*/std::nullopt, outerDimsPerm); @@ -579,7 +579,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); - auto newPadOp = rewriter.create( + auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal, padOp.getNofold()); @@ -588,7 +588,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { if (!padOp->hasOneUse()) { auto unpackEmpty = linalg::UnPackOp::createDestinationTensor( rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); - Value unpackedPad = rewriter.create( + Value unpackedPad = linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); } @@ -719,7 +719,7 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, auto emptyOp = linalg::PackOp::createDestinationTensor( rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), projectedInnerDimsPos, newOuterDimsPerm); - auto newPackOp = rewriter.create( + auto newPackOp = linalg::PackOp::create(rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); @@ -735,7 +735,7 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, nextPos += 1; } - auto newCollapseOp = rewriter.create( + auto newCollapseOp = 
tensor::CollapseShapeOp::create(rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); rewriter.replaceOp(packOp, newCollapseOp); @@ -853,12 +853,12 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, Value destTensor = linalg::PackOp::createDestinationTensor( rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector{}); - Value packedVal = rewriter.create( + Value packedVal = linalg::PackOp::create(rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(), /*outerDimsPerm=*/SmallVector{}); - Value newExpandOp = rewriter.create( + Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); rewriter.replaceOp(packOp, newExpandOp); @@ -972,14 +972,14 @@ static LogicalResult pushDownUnPackOpThroughExpandShape( RankedTensorType newExpandType = linalg::PackOp::inferPackedType( expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); - auto newExpandOp = rewriter.create( + auto newExpandOp = tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType, unPackOp.getSource(), newReassocIndices); auto emptyOp = linalg::UnPackOp::createDestinationTensor( rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), projectedInnerDimsPos, newOuterDimsPerm); - auto newUnPackOp = rewriter.create( + auto newUnPackOp = linalg::UnPackOp::create(rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); rewriter.replaceOp(expandOp, newUnPackOp); @@ -1212,16 +1212,16 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); - auto newPadOp = rewriter.create( + auto newPadOp = tensor::PadOp::create(rewriter, loc, 
/*result=*/Type(), unpackOp.getSource(), lowPad, highPad, paddingVal, padOp.getNofold()); // Inject the linalg.unpack right after the packed padOp. - Value outputUnPack = rewriter.create( + Value outputUnPack = tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(), padOp.getResultType().getElementType()); - Value replacement = rewriter.create( + Value replacement = linalg::UnPackOp::create(rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos, unpackOp.getMixedTiles(), outerDimsPerm); rewriter.replaceOp(padOp, replacement); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp index 692bf595267d4..82d386b3b9099 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp @@ -198,9 +198,9 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( transposedShape[i] = inputRTType.getShape()[permutation[i]]; Value emptyTensor = - rewriter.create(loc, transposedShape, elType); + tensor::EmptyOp::create(rewriter, loc, transposedShape, elType); - auto transposeOp = rewriter.create(loc, newInitValues[i], + auto transposeOp = TransposeOp::create(rewriter, loc, newInitValues[i], emptyTensor, permutation); newInitValues[i] = transposeOp->getResult(0); isChanged = true; @@ -209,10 +209,10 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( // Does it require broadcast? 
if (!broadcastedDims.empty()) { assert(broadcastedDims.size() && "should have non size broadcast"); - Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape, inputRTType.getElementType()); - auto broadcastOp = rewriter.create( + auto broadcastOp = linalg::BroadcastOp::create(rewriter, loc, newInitValues[i], emptyTensor, broadcastedDims); newInitValues[i] = broadcastOp->getResult(0); @@ -227,7 +227,7 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( SmallVector operands = op->getOperands(); ValueRange operandsRef(operands); - auto newOp = rewriter.create( + auto newOp = linalg::GenericOp::create(rewriter, /*location=*/op.getLoc(), /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/newInitValues, diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index 1419175304899..bd3fad01aa48b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -133,12 +133,12 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) { assert(elementType.isIntOrIndexOrFloat() && "expected scalar type while computing zero value"); if (isa(elementType)) - return b.create(loc, elementType, 0); + return arith::ConstantIntOp::create(b, loc, elementType, 0); if (elementType.isIndex()) - return b.create(loc, 0); + return arith::ConstantIndexOp::create(b, loc, 0); // Assume float. auto floatType = cast(elementType); - return b.create( + return arith::ConstantFloatOp::create(b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics())); } @@ -189,7 +189,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, // Fall back path, use an `init_tensor` and identity indexing map. 
AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size()); Value emptyTensor = - rewriter.create(loc, domain, scalarOpResult.getType()); + tensor::EmptyOp::create(rewriter, loc, domain, scalarOpResult.getType()); newInitValues.push_back(emptyTensor); newResultTypes.push_back(emptyTensor.getType()); peeledGenericOpIndexingMaps.push_back(indexingMap); @@ -202,7 +202,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, resultTypes.append(newResultTypes.begin(), newResultTypes.end()); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); - return rewriter.create( + return GenericOp::create(rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, [](OpBuilder, Location, ValueRange) {}); @@ -239,7 +239,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand)); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); - return rewriter.create( + return GenericOp::create(rewriter, genericOp->getLoc(), genericOp->getResultTypes(), residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, @@ -324,7 +324,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, yieldedVals.append(llvm::to_vector( llvm::map_range(peeledScalarOperation->getResults(), [](OpResult opr) -> Value { return opr; }))); - rewriter.create(genericOp.getLoc(), yieldedVals); + YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals); } /// In the split operations, replace block arguments uses that refer to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index ef24eb881d68b..b8e73bd5929dd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ 
-34,7 +34,7 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type, // A detensored value is converted back by creating a new tensor from its // element(s). - return builder.create( + return tensor::FromElementsOp::create(builder, loc, RankedTensorType::get({}, inputType), inputs[0]); } @@ -147,7 +147,7 @@ class DetensorizeTypeConverter : public TypeConverter { // A tensor value is detensoried by extracting its element(s). addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, inputs[0], ValueRange{}); + return tensor::ExtractOp::create(builder, loc, inputs[0], ValueRange{}); }); addSourceMaterialization(sourceMaterializationCallback); @@ -481,7 +481,7 @@ struct LinalgDetensorize rewriter.splitBlock(entryBlock, entryBlock->begin()); rewriter.setInsertionPointToStart(entryBlock); auto branch = - rewriter.create(rewriter.getUnknownLoc(), postEntryBlock); + cf::BranchOp::create(rewriter, rewriter.getUnknownLoc(), postEntryBlock); if (aggressiveMode.getValue()) { AggressiveDetensoringModel costModel; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index e0062d15e61ca..fd8c177cc7247 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -118,14 +118,14 @@ struct MoveInitOperandsToInput : public OpRewritePattern { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfterValue(op->get()); auto elemType = cast(op->get().getType()).getElementType(); - auto empty = rewriter.create( + auto empty = tensor::EmptyOp::create(rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType); unsigned start = genericOp.getDpsInits().getBeginOperandIndex(); newOutputOperands[op->getOperandNumber() - start] = empty.getResult(); } - auto newOp = rewriter.create( + auto newOp = GenericOp::create(rewriter, loc, 
genericOp.getResultTypes(), newInputOperands, newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(), /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); @@ -295,7 +295,7 @@ static Value collapseValue( MemRefLayoutAttrInterface layout; auto targetType = MemRefType::get(targetShape, memrefType.getElementType(), layout, memrefType.getMemorySpace()); - return rewriter.create(loc, targetType, operand, + return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand, reassociation); } if (auto tensorType = dyn_cast(operand.getType())) { @@ -314,7 +314,7 @@ static Value collapseValue( "unknown rank reduction strategy"); auto targetType = RankedTensorType::get(targetShape, tensorType.getElementType()); - return rewriter.create(loc, targetType, operand, + return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand, reassociation); } llvm_unreachable("unsupported operand type"); @@ -519,7 +519,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, for (unsigned i : llvm::seq(0, genericOp.getNumResults())) resultTypes.push_back(newOutputs[i].getType()); GenericOp replacementOp = - rewriter.create(loc, resultTypes, newInputs, newOutputs, + GenericOp::create(rewriter, loc, resultTypes, newInputs, newOutputs, newIndexingMaps, newIteratorTypes); rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), replacementOp.getRegion().begin()); @@ -652,7 +652,7 @@ struct DropPadUnitDims : public OpRewritePattern { collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape, reassociationMap, options.rankReductionStrategy); - auto newPadOp = rewriter.create( + auto newPadOp = tensor::PadOp::create(rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad, newHighPad, paddingVal, padOp.getNofold()); @@ -670,7 +670,7 @@ struct DropPadUnitDims : public OpRewritePattern { expandedSizes.push_back(tensor::getMixedSize( rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims)); } - 
dest = rewriter.create( + dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes, padOp.getResultType().getElementType()); } @@ -713,7 +713,7 @@ struct RankReducedExtractSliceOp strides)); Location loc = sliceOp.getLoc(); - Value newSlice = rewriter.create( + Value newSlice = tensor::ExtractSliceOp::create(rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides); rewriter.replaceOpWithNewOp( sliceOp, resultType, newSlice, *reassociation); @@ -747,7 +747,7 @@ struct RankReducedInsertSliceOp : public OpRewritePattern { // parallel case. if (std::is_same::value) rewriter.setInsertionPoint(insertSliceOp->getParentOp()); - reshapedSource = rewriter.create( + reshapedSource = tensor::CollapseShapeOp::create(rewriter, loc, insertSliceOp.getSource(), *reassociation); } rewriter.replaceOpWithNewOp( @@ -898,7 +898,7 @@ struct RankReduceContractionOps : OpRewritePattern { /// Expand result tensor. Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType, int64_t dim) const { - return rewriter.create( + return tensor::ExpandShapeOp::create(rewriter, result.getLoc(), expandedType, result, getReassociationForReshapeAtDim(expandedType.getRank(), dim)); } @@ -934,7 +934,7 @@ struct RankReduceContractionOps : OpRewritePattern { SmallVector collapsedResultTy; if (isa(collapsedInit.getType())) collapsedResultTy.push_back(collapsedInit.getType()); - auto collapsedOp = rewriter.create( + auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs}, ValueRange{collapsedInit}); for (auto attr : contractionOp->getAttrs()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index f97ed3d6d5111..1ae4375bd09eb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -236,11 +236,11 @@ static void 
generateFusedElementwiseOpRegion( fusedIndices.reserve(numFusedOpLoops); llvm::transform(llvm::seq(0, numFusedOpLoops), std::back_inserter(fusedIndices), [&](uint64_t dim) { - return rewriter.create(producer.getLoc(), dim); + return IndexOp::create(rewriter, producer.getLoc(), dim); }); for (IndexOp indexOp : llvm::make_early_inc_range(producerBlock.getOps())) { - Value newIndex = rewriter.create( + Value newIndex = affine::AffineApplyOp::create(rewriter, producer.getLoc(), consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices); mapper.map(indexOp.getResult(), newIndex); @@ -327,7 +327,7 @@ static void generateFusedElementwiseOpRegion( } for (auto consumerYieldVal : consumerYieldOp.getOperands()) fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal)); - rewriter.create(fusedOp.getLoc(), fusedYieldValues); + YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues); // Sanity checks. assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && @@ -416,7 +416,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, } // Generate the fused op. 
- auto fusedOp = rewriter.create( + auto fusedOp = GenericOp::create(rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands, fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps), consumer.getIteratorTypes(), @@ -750,9 +750,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, expandedIndices.reserve(expandedDims.size() - 1); llvm::transform( expandedDims.drop_front(), std::back_inserter(expandedIndices), - [&](int64_t dim) { return rewriter.create(loc, dim); }); + [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); }); OpFoldResult newIndex = - rewriter.create(loc, expandedDims.front()).getResult(); + IndexOp::create(rewriter, loc, expandedDims.front()).getResult(); for (auto [expandedShape, expandedIndex] : llvm::zip(expandedDimsShape, expandedIndices)) { AffineExpr idx, acc, shape; @@ -796,7 +796,7 @@ static Operation *createExpandedTransposeOp(PatternRewriter &rewriter, newPerm.push_back(dim); } } - return rewriter.create(transposeOp.getLoc(), expandedInput, + return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput, output, invertPermutationVector(newPerm)); } @@ -813,7 +813,7 @@ static Operation *createExpandedGenericOp( for (auto j : expansionInfo.getExpandedDims(i)) iteratorTypes[j] = type; - Operation *fused = rewriter.create( + Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs, expandedOpIndexingMaps, iteratorTypes); @@ -933,7 +933,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, reassociation, /*isExpandingReshape=*/true))) return std::nullopt; - expandedOpOperands.push_back(rewriter.create( + expandedOpOperands.push_back(tensor::ExpandShapeOp::create(rewriter, loc, expandedOperandType, opOperand->get(), reassociation, expandedOperandShape)); continue; @@ -961,7 +961,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, reassociation, /*isExpandingReshape=*/true))) return std::nullopt; 
- outputs.push_back(rewriter.create( + outputs.push_back(tensor::ExpandShapeOp::create(rewriter, loc, expandedOutputType, opOperand.get(), reassociation, expandedOutputShape)); } else { @@ -984,7 +984,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, linalgOp.getMatchingIndexingMap( linalgOp.getDpsInitOperand(resultNumber)), expansionInfo); - resultVals.push_back(rewriter.create( + resultVals.push_back(tensor::CollapseShapeOp::create(rewriter, linalgOp.getLoc(), opResult.getType(), fusedOp->getResult(resultNumber), reassociation)); } else { @@ -1086,7 +1086,7 @@ class FoldPadWithProducerReshapeOpByExpansion Location loc = padOp->getLoc(); RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); - auto newPadOp = rewriter.create( + auto newPadOp = tensor::PadOp::create(rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, padOp.getConstantPaddingValue(), padOp.getNofold()); @@ -1603,7 +1603,7 @@ static void generateCollapsedIndexingRegion( enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { ReassociationIndicesRef foldedDimsRef(foldedDims.value()); Value newIndexVal = - rewriter.create(loc, foldedDims.index()); + linalg::IndexOp::create(rewriter, loc, foldedDims.index()); for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { Value loopDim = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]); @@ -1687,7 +1687,7 @@ GenericOp cloneToCollapsedOp(RewriterBase &rewriter, SmallVector iteratorTypes(getCollapsedOpIteratorTypes( origOp.getIteratorTypesArray(), collapsingInfo)); - GenericOp collapsedOp = rewriter.create( + GenericOp collapsedOp = linalg::GenericOp::create(rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps, iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); Block *origOpBlock = &origOp->getRegion(0).front(); @@ -1790,11 +1790,11 @@ FailureOr mlir::linalg::collapseOpIterationDims( if (isa(collapsedOpResult.getType())) { 
MemRefType expandShapeResultType = MemRefType::get( originalResultType.getShape(), originalResultType.getElementType()); - result = rewriter.create( + result = memref::ExpandShapeOp::create(rewriter, loc, expandShapeResultType, collapsedOpResult, reassociation, resultShape); } else { - result = rewriter.create( + result = tensor::ExpandShapeOp::create(rewriter, loc, originalResultType, collapsedOpResult, reassociation, resultShape); } @@ -1978,7 +1978,7 @@ class FoldPadWithProducerReshapeOpByCollapsing RankedTensorType collapsedPaddedType = paddedType.clone(collapsedPaddedShape); - auto newPadOp = rewriter.create( + auto newPadOp = tensor::PadOp::create(rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, padOp.getConstantPaddingValue(), padOp.getNofold()); @@ -2113,10 +2113,10 @@ class FoldScalarOrSplatConstant : public OpRewritePattern { // Create a constant scalar value from the splat constant. Value scalarConstant = - rewriter.create(def->getLoc(), constantAttr); + arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr); SmallVector outputOperands = genericOp.getOutputs(); - auto fusedOp = rewriter.create( + auto fusedOp = GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), /*inputs=*/fusedOperands, /*outputs=*/outputOperands, @@ -2179,7 +2179,7 @@ struct RemoveOutsDependency : public OpRewritePattern { modifiedOutput = true; SmallVector mixedSizes = tensor::getMixedSizes(rewriter, loc, operandVal); - Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, mixedSizes, operandType.getElementType()); op->setOperand(opOperand.getOperandNumber(), emptyTensor); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index c4af09ca01421..d93f8c26a9f0a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ 
b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -64,7 +64,7 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { continue; // Extract static / dynamic shape mix from the first operand. - res.push_back(b.create( + res.push_back(tensor::EmptyOp::create(b, loc, tensor::getMixedSizes(b, loc, operands.front()), cast(t).getElementType())); } @@ -104,7 +104,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index d375878fb2c91..4e95aae86e670 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -259,7 +259,7 @@ mlir::linalg::deduplicateOperandsAndRemoveDeadResults( for (Value v : newOutputOperands) if (isa(v.getType())) newResultTypes.push_back(v.getType()); - auto newOp = rewriter.create( + auto newOp = GenericOp::create(rewriter, loc, newResultTypes, newInputOperands, newOutputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.getIteratorTypes(), genericOp.getDocAttr(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp index 44469bc404a7c..e19e224c25028 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -72,14 +72,14 @@ struct FusePadOp : OpRewritePattern { // Create the tensor of same size as output of the pad op. 
RankedTensorType padResultType = padOp.getResultType(); auto resultSizes = resultShape[0]; - auto emptyTensor = rewriter.create( + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, resultSizes, padResultType.getElementType()); // Fill the tensor with the pad value. // TODO: There is an option to fill only the boundaries. For now just // filling the whole tensor. auto fillTensor = - rewriter.create(loc, padValue, emptyTensor.getResult()); + linalg::FillOp::create(rewriter, loc, padValue, emptyTensor.getResult()); // Construct a slice of the fill result that is to be replaced with the // result of the generic op. The low pad values are the offsets, the size of @@ -93,14 +93,14 @@ struct FusePadOp : OpRewritePattern { llvm::enumerate(cast(source.getType()).getShape())) { if (ShapedType::isDynamic(shape.value())) { sizes.push_back( - rewriter.create(loc, source, shape.index()) + tensor::DimOp::create(rewriter, loc, source, shape.index()) .getResult()); } else { sizes.push_back(rewriter.getIndexAttr(shape.value())); } } SmallVector strides(offsets.size(), rewriter.getIndexAttr(1)); - auto slice = rewriter.create( + auto slice = tensor::ExtractSliceOp::create(rewriter, loc, fillTensor.getResult(0), offsets, sizes, strides); // Clone the generic op. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 9bc7be2623849..41252c68ffda9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -277,7 +277,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, // mismatches. Insert a `tensor.cast` op to propagate the transformation // invariant that types are compatible. 
if (consumerType != def.getType()) - def = b.create(fusedProducer.getLoc(), consumerType, def); + def = tensor::CastOp::create(b, fusedProducer.getLoc(), consumerType, def); consumerOpOperand.set(def); return FusionInfo{cast(producerOpResult.getOwner()), fusedProducer}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 05f2157b77aeb..7994af8834990 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -61,7 +61,7 @@ FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, // All named ops have a region attached that can be inlined. assert(linalgOp->getNumRegions() == 1 && "expect named op to have one region attached"); - GenericOp genericOp = rewriter.create( + GenericOp genericOp = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators); rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 94ed46442180c..86f24578376c1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -591,7 +591,7 @@ static FailureOr buildPackingLoopNestImpl( // Create a packing loop that takes `hoistedPackedTensor` as iteration // argument. 
- auto clonedForOp = rewriter.create( + auto clonedForOp = scf::ForOp::create(rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()), bvm.lookupOrDefault(forOp.getUpperBound()), bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor); @@ -640,10 +640,10 @@ static FailureOr buildPackingLoopNestImpl( TransposeOp maybeTransposeOp; Value paddedTensor = bvm.lookup(opToHoist.getResult()); if (!transposeVector.empty()) { - Value outputTensor = rewriter.create( + Value outputTensor = tensor::ExtractSliceOp::create(rewriter, loc, transposedTensorType, hoistedPackedTensor, offsets, sizes, strides); - maybeTransposeOp = rewriter.create( + maybeTransposeOp = linalg::TransposeOp::create(rewriter, loc, paddedTensor, outputTensor, transposeVector); paddedTensor = maybeTransposeOp.getResult()[0]; } @@ -652,7 +652,7 @@ static FailureOr buildPackingLoopNestImpl( if (nPackedLoops > 0) { // Step 4. Create InsertSliceOp at the innermost loop level, inserting an // optionally transposed padded slice into the packed tensor. - Value inserted = rewriter.create( + Value inserted = tensor::InsertSliceOp::create(rewriter, loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides); // Step 5. Iteratively pop the stack and propagate the yield. 
@@ -660,7 +660,7 @@ static FailureOr buildPackingLoopNestImpl( for (Value iv : llvm::reverse(clonedLoopIvs)) { auto forOp = scf::getForInductionVarOwner(iv); rewriter.setInsertionPointToEnd(&forOp.getRegion().front()); - rewriter.create(loc, valueToYield); + scf::YieldOp::create(rewriter, loc, valueToYield); valueToYield = forOp.getResult(0); } } @@ -712,7 +712,7 @@ static FailureOr buildPackingLoopNestImpl( rewriter.setInsertionPoint(outerLoop); SmallVector dynamicTensorSizes = analysis.getHoistedPackedTensorSizes(rewriter, loc); - auto emptyOp = rewriter.create( + auto emptyOp = tensor::EmptyOp::create(rewriter, loc, hoistedPackedTensorType.getShape(), hoistedPackedTensorType.getElementType(), dynamicTensorSizes); @@ -840,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(forOp); - extracted = rewriter.create( + extracted = tensor::ExtractSliceOp::create(rewriter, hoistedPackedTensor.getLoc(), hoistedPackedTensor, outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(), outerSliceOp.getMixedStrides()); @@ -934,7 +934,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter, // offsets = [maybe_leading_ivs, 0 .. 0]. // sizes = [1 .. 1, transposedShape] (defined above). // strides = [1 .. 1] (defined above) - return rewriter.create( + return tensor::ExtractSliceOp::create(rewriter, loc, transposedTensorType, hoistedPackedTensor, offsets, packingResult.sizes, packingResult.strides); } @@ -982,9 +982,9 @@ FailureOr mlir::linalg::hoistPaddingOnTensors( OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newResult.getDefiningOp()); // Transpose the packed tensor back to the original storage order. 
- Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, paddedTensorType.getShape(), paddedTensorType.getElementType()); - TransposeOp unTransposeOp = rewriter.create( + TransposeOp unTransposeOp = linalg::TransposeOp::create(rewriter, loc, newResult, emptyTensor, transposeVector); newResult = unTransposeOp.getResult()[0]; transposeOps.push_back(unTransposeOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index f2e51c29f3241..5397fc572e558 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -53,7 +53,7 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, assert(index < inits.size()); inits[index] = newInitOperand; - scf::ForOp newLoop = rewriter.create( + scf::ForOp newLoop = scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp index 1f3336d2bfbb9..a900215ccf014 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -60,7 +60,7 @@ struct InlineScalarOperands : public OpRewritePattern { Location loc = genericOp->getLoc(); SmallVector outputOperands = genericOp.getOutputs(); - auto newOp = rewriter.create( + auto newOp = GenericOp::create(rewriter, loc, genericOp->getResultTypes(), newOperands, outputOperands, newIndexingMaps, genericOp.getIteratorTypesArray()); rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(), @@ -77,11 +77,11 @@ struct InlineScalarOperands : public OpRewritePattern { SmallVector indicesValues; for (auto idx : indices) indicesValues.emplace_back( - rewriter.create(loc, idx)); + 
arith::ConstantIndexOp::create(rewriter, loc, idx)); Value scalarValue = opOperand->get(); if (isa(scalarValue.getType())) { scalarValue = - rewriter.create(loc, scalarValue, indicesValues); + tensor::ExtractOp::create(rewriter, loc, scalarValue, indicesValues); } body->getArgument(idx).replaceAllUsesWith(scalarValue); body->eraseArgument(idx); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index a92a0c83e0316..2372313e29c58 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -88,7 +88,7 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, allIndices.reserve(genericOp.getNumLoops()); llvm::transform(llvm::seq(0, genericOp.getNumLoops()), std::back_inserter(allIndices), [&](uint64_t dim) { - return rewriter.create(indexOp->getLoc(), dim); + return IndexOp::create(rewriter, indexOp->getLoc(), dim); }); rewriter.replaceOpWithNewOp( indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 488041a43a2ef..0c08c1d8b997d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -49,7 +49,7 @@ static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e); SmallVector operands(vals); affine::canonicalizeMapAndOperands(&exprMap, &operands); - res.push_back(b.create(loc, exprMap, operands)); + res.push_back(affine::AffineApplyOp::create(b, loc, exprMap, operands)); } return res; } @@ -70,7 +70,7 @@ static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, Operation *terminator = block.getTerminator(); for (OpOperand &operand : terminator->getOpOperands()) { Value toStore = map.lookupOrDefault(operand.get()); - b.create(loc, toStore, 
outputBuffers[operand.getOperandNumber()], + StoreOpTy::create(b, loc, toStore, outputBuffers[operand.getOperandNumber()], indexing[operand.getOperandNumber()]); } } @@ -145,7 +145,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, auto indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims); indexedValues.push_back( - b.create(loc, inputOperand->get(), indexing)); + LoadOpTy::create(b, loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) { @@ -153,7 +153,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, b, loc, linalgOp.getMatchingIndexingMap(&outputOperand), allIvsPlusDims); indexedValues.push_back( - b.create(loc, outputOperand.get(), indexing)); + LoadOpTy::create(b, loc, outputOperand.get(), indexing)); } // TODO: When a region inliner exists, use it. diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp index ee1957aaa6a53..28b246714df53 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp @@ -131,17 +131,17 @@ static Value createDestinationPassingStyleInitOperand( ImplicitLocOpBuilder &builder) { Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( meshOp.getSymName(), reductionMeshAxes, builder); - Value zero = builder.create(0); - Value isLeadProcess = builder.create( + Value zero = arith::ConstantIndexOp::create(builder, 0); + Value isLeadProcess = arith::CmpIOp::create(builder, builder.getI1Type(), arith::CmpIPredicate::eq, processLinearIndexInReductionGroup, zero); - scf::IfOp ifOp = builder.create(spmdizedOperand.getType(), + scf::IfOp ifOp = scf::IfOp::create(builder, spmdizedOperand.getType(), isLeadProcess, true, true); // Then block. 
{ OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); - builder.create(spmdizedOperand); + scf::YieldOp::create(builder, spmdizedOperand); } // Else block. @@ -157,14 +157,14 @@ static Value createDestinationPassingStyleInitOperand( std::optional neutralEl = arith::getNeutralElement(combinerOps[0]); - Value init = builder.create(op.getLoc(), shape, + Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape, neutralEl.value().getType()); Value constant = - builder.create(op.getLoc(), neutralEl.value()); - Value fill = builder.create(op.getLoc(), constant, init) + arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value()); + Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init) .getResult(0); - builder.create(fill); + scf::YieldOp::create(builder, fill); } return ifOp.getResult(0); } @@ -202,7 +202,7 @@ static void createAllReduceForResultWithoutPartialSharding( } Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult); - Value reducedValue = builder.create( + Value reducedValue = mesh::AllReduceOp::create(builder, spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes, reductionKind); spmdizationMap.map(unshardedLinalgOpResult, reducedValue); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp index bb1e974391878..019c845b4e0e7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -59,7 +59,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, auto newKernelTy = RankedTensorType::get( {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, kernelTy.getElementType()); - auto collapsedKernel = rewriter.create( + auto collapsedKernel = tensor::CollapseShapeOp::create(rewriter, loc, newKernelTy, kernel, collapsedKernelDims); // 
Collapse init dims. @@ -70,7 +70,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), initTy.getDimSize(2), initTy.getDimSize(3)}, initTy.getElementType()); - auto collapsedInit = rewriter.create( + auto collapsedInit = tensor::CollapseShapeOp::create(rewriter, loc, newInitTy, init, collapsedInitDims); SmallVector preservedAttrs; @@ -78,13 +78,13 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, TypeSwitch(operation) .Case([&](auto op) { preservedAttrs = getPrunedAttributeList(op); - return rewriter.create( + return DepthwiseConv2DNhwcHwcOp::create(rewriter, loc, newInitTy, ValueRange{input, collapsedKernel}, ValueRange{collapsedInit}, stride, dilation); }) .Case([&](auto op) { preservedAttrs = getPrunedAttributeList(op); - return rewriter.create( + return DepthwiseConv2DNhwcHwcQOp::create(rewriter, loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp}, ValueRange{collapsedInit}, stride, dilation); }) diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 2afa2f9b71c2a..1368bb7e77db0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -143,7 +143,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { Type newOperandType, ArrayAttr reassociation) const { if (operand.getType() == newOperandType) return operand; - return rewriter.create(loc, newOperandType, + return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType, operand, reassociation); } @@ -265,7 +265,7 @@ struct FoldUnpackWithExtractSliceOp // Create a new empty output tensor. 
Type elementType = unpackOp.getDestType().getElementType(); - Value output = rewriter.create( + Value output = tensor::EmptyOp::create(rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); rewriter.replaceOpWithNewOp( sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), @@ -529,7 +529,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp auto elemType = cast(unPackOp->getResultTypes()[0]).getElementType(); - Value output = rewriter.create( + Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(), unpackOpResultDims[0], elemType); rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 5eb3761f7aca1..fb43c1f92e5b3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -192,10 +192,10 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, if (auto complexTy = dyn_cast(getElementTypeOrSelf(v.getType()))) { auto complexAttr = cast(paddingValueAttr); - paddingValue = rewriter.create(opToPad.getLoc(), + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), complexTy, complexAttr); } else { - paddingValue = rewriter.create( + paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), cast(paddingValueAttr)); } @@ -323,7 +323,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, int64_t rank = cast(paddedResult.getType()).getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector strides(rank, rewriter.getIndexAttr(1)); - paddedSubtensorResults.push_back(rewriter.create( + paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber], strides)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index 
dc9e11eccac4d..af614e88fb3d5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -219,10 +219,10 @@ static FailureOr padOperandToSmallestStaticBoundingBox( if (auto complexTy = dyn_cast( getElementTypeOrSelf(opOperand->get().getType()))) { auto complexAttr = cast(paddingAttr); - paddingValue = rewriter.create(opToPad.getLoc(), + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), complexTy, complexAttr); } else { - paddingValue = rewriter.create( + paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), cast(paddingAttr)); } @@ -313,7 +313,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, int64_t rank = cast(paddedResult.getType()).getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector strides(rank, rewriter.getIndexAttr(1)); - paddedSubtensorResults.push_back(rewriter.create( + paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber], strides)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index b274502e16903..4042a43b04091 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -63,10 +63,10 @@ static Value allocBuffer(ImplicitLocOpBuilder &b, staticBufferType = MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr); if (options.useAlloca) { - return b.create(staticBufferType, ValueRange{}, + return memref::AllocaOp::create(b, staticBufferType, ValueRange{}, alignmentAttr); } - return b.create(staticBufferType, ValueRange{}, + return memref::AllocOp::create(b, staticBufferType, ValueRange{}, alignmentAttr); } @@ -76,10 +76,10 @@ static Value allocBuffer(ImplicitLocOpBuilder &b, dynamicBufferType = MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr); Value mul = b.createOrFold( - 
b.create(width), allocSize); + arith::ConstantIndexOp::create(b, width), allocSize); if (options.useAlloca) - return b.create(dynamicBufferType, mul, alignmentAttr); - return b.create(dynamicBufferType, mul, alignmentAttr); + return memref::AllocaOp::create(b, dynamicBufferType, mul, alignmentAttr); + return memref::AllocOp::create(b, dynamicBufferType, mul, alignmentAttr); } /// Default allocation callback function. This allocates a promoted buffer when @@ -92,8 +92,8 @@ static std::optional defaultAllocBufferCallBack( std::optional alignment, DataLayout &layout) { ShapedType viewType = subView.getType(); ImplicitLocOpBuilder b(subView.getLoc(), builder); - auto zero = b.create(0); - auto one = b.create(1); + auto zero = arith::ConstantIndexOp::create(b, 0); + auto one = arith::ConstantIndexOp::create(b, 1); Attribute memorySpaceAttr; if (options.memorySpace.has_value()) @@ -123,7 +123,7 @@ defaultDeallocBufferCallBack(const LinalgPromotionOptions &options, OpBuilder &b, Value fullLocalView) { if (!options.useAlloca) { auto viewOp = cast(fullLocalView.getDefiningOp()); - b.create(viewOp.getSource().getLoc(), + memref::DeallocOp::create(b, viewOp.getSource().getLoc(), viewOp.getSource()); } return success(); @@ -211,7 +211,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( Location loc = linalgOp.getLoc(); auto defaultCopyCallBack = [loc](OpBuilder &b, Value src, Value dst) -> LogicalResult { - b.create(loc, src, dst); + linalg::CopyOp::create(b, loc, src, dst); return success(); }; copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack); @@ -265,7 +265,7 @@ FailureOr mlir::linalg::promoteSubviewAsNewBuffer( /*stopCondition=*/nullptr, /*closedUB=*/true); size = failed(upperBound) ? 
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) - : b.create(loc, *upperBound); + : arith::ConstantIndexOp::create(b, loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); fullSizes.push_back(size); @@ -310,23 +310,23 @@ promoteSubViews(ImplicitLocOpBuilder &b, Value fillVal = llvm::TypeSwitch(subviewEltType) .Case([&](FloatType t) { - return b.create(FloatAttr::get(t, 0.0)); + return arith::ConstantOp::create(b, FloatAttr::get(t, 0.0)); }) .Case([&](IntegerType t) { - return b.create(IntegerAttr::get(t, 0)); + return arith::ConstantOp::create(b, IntegerAttr::get(t, 0)); }) .Case([&](ComplexType t) { Value tmp; if (auto et = dyn_cast(t.getElementType())) - tmp = b.create(FloatAttr::get(et, 0.0)); + tmp = arith::ConstantOp::create(b, FloatAttr::get(et, 0.0)); else if (auto et = cast(t.getElementType())) - tmp = b.create(IntegerAttr::get(et, 0)); - return b.create(t, tmp, tmp); + tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0)); + return complex::CreateOp::create(b, t, tmp, tmp); }) .Default([](auto) { return Value(); }); if (!fillVal) return failure(); - b.create(fillVal, promotionInfo->fullLocalView); + linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); } // Copy data into the promoted buffers. Use callback if provided. 
@@ -459,9 +459,9 @@ static std::optional allocateSubviewGPUMemoryInAddressSpace( gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace)); Value buffer; if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) { - buffer = builder.create(funcOp.getLoc(), type); + buffer = memref::AllocOp::create(builder, funcOp.getLoc(), type); } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) { - buffer = builder.create(funcOp.getLoc(), type); + buffer = memref::AllocaOp::create(builder, funcOp.getLoc(), type); } else { return std::nullopt; } @@ -487,9 +487,9 @@ LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &, /// the copy operation to ensure data integrity. LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) { - b.create(src.getLoc()); - Operation *copyOp = b.create(src.getLoc(), src, dst); - b.create(copyOp->getLoc()); + gpu::BarrierOp::create(b, src.getLoc()); + Operation *copyOp = memref::CopyOp::create(b, src.getLoc(), src, dst); + gpu::BarrierOp::create(b, copyOp->getLoc()); return success(); } @@ -504,7 +504,7 @@ std::optional mlir::linalg::allocateGPUPrivateMemory( /// Normal copy to between src and dst. 
LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst) { - b.create(src.getLoc(), src, dst); + memref::CopyOp::create(b, src.getLoc(), src, dst); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp index b30182dc84079..eac0e47b18a7d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -38,8 +38,8 @@ struct StructuredOpInterface SmallVector loopRanges = linalgOp.createLoopRanges(builder, loc); auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges); - auto zero = builder.create(loc, 0); - auto one = builder.create(loc, 1); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); + auto one = arith::ConstantIndexOp::create(builder, loc, 1); // Subtract one from the loop ends before composing with the indexing map transform(ends, ends.begin(), [&](OpFoldResult end) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index 671dea8bb415f..1fac80142dddf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -52,7 +52,7 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op, return nullptr; SmallVector resultStrides(resultOffsets.size(), b.getIndexAttr(1)); - Value inserted = b.create( + Value inserted = tensor::InsertSliceOp::create(b, loc, result, resultOperands[index], resultOffsets, resultSizes, resultStrides); results.push_back(inserted); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 5bfdbc6d0bb59..1b0bfbbec8624 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -115,7 +115,7 @@ FailureOr mlir::linalg::splitReduction( newShape, 
cast(operand->get().getType()).getElementType()); - Value newInput = b.create( + Value newInput = tensor::ExpandShapeOp::create(b, loc, newType, operand->get(), reassociation); newInputs.push_back(newInput); } @@ -140,18 +140,18 @@ FailureOr mlir::linalg::splitReduction( } Value emptyOrAllocTensor; if (useAlloc) { - emptyOrAllocTensor = b.create( + emptyOrAllocTensor = bufferization::AllocTensorOp::create(b, loc, RankedTensorType::get(newOutputShape, op.getRegionOutputArgs()[0].getType()), ValueRange{}); } else { - emptyOrAllocTensor = b.create( + emptyOrAllocTensor = tensor::EmptyOp::create(b, loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); } - Value constantOp = b.create(loc, *identity); + Value constantOp = arith::ConstantOp::create(b, loc, *identity); Value identityTensor = - b.create(op->getLoc(), constantOp, emptyOrAllocTensor) + linalg::FillOp::create(b, op->getLoc(), constantOp, emptyOrAllocTensor) .getResult(0); newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, @@ -168,7 +168,7 @@ FailureOr mlir::linalg::splitReduction( } // Create the new op matching the original op with an extra parallel // dimension. 
- GenericOp genericOp = b.create( + GenericOp genericOp = GenericOp::create(b, loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs, ValueRange({identityTensor}), newMaps, newIteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), @@ -191,14 +191,14 @@ FailureOr mlir::linalg::splitReduction( AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); SmallVector reductionMaps = {inputMap, outputMap}; - auto reduction = b.create( + auto reduction = GenericOp::create(b, loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), op.getDpsInits(), reductionMaps, reductionIteratorTypes, [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { Operation *clonedReductionOp = b.clone(*reductionOp); clonedReductionOp->setOperand(0, inputs[0]); clonedReductionOp->setOperand(1, inputs[1]); - b.create(loc, clonedReductionOp->getResult(0)); + linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0)); }); b.replaceOp(op, reduction.getResults()); @@ -318,14 +318,14 @@ FailureOr mlir::linalg::splitReductionByScaling( Value emptyOrAllocTensor; if (useAlloc) { emptyOrAllocTensor = - b.create(loc, newT, dims); + bufferization::AllocTensorOp::create(b, loc, newT, dims); } else { - emptyOrAllocTensor = b.create(loc, newT.getShape(), + emptyOrAllocTensor = tensor::EmptyOp::create(b, loc, newT.getShape(), t.getElementType(), dims); } - Value constantOp = b.create(loc, std::get<1>(it)); + Value constantOp = arith::ConstantOp::create(b, loc, std::get<1>(it)); fillOps.push_back( - b.create(op->getLoc(), constantOp, emptyOrAllocTensor)); + linalg::FillOp::create(b, op->getLoc(), constantOp, emptyOrAllocTensor)); newOutputs.push_back(fillOps.back().getResult(0)); emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp()); } @@ -354,7 +354,7 @@ FailureOr mlir::linalg::splitReductionByScaling( SmallVector newInputs = op.getDpsInputs(); // Add a single shape-only tensor to carry the dimensions without resorting to // more 
complex inversions. - newInputs.push_back(b.create( + newInputs.push_back(tensor::EmptyOp::create(b, loc, ArrayRef{reductionDimSize / splitFactor, splitFactor}, b.getIntegerType(1))); // Output tensors are already good to go. @@ -365,7 +365,7 @@ FailureOr mlir::linalg::splitReductionByScaling( iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, utils::IteratorType::parallel); GenericOp genericOp = - b.create(loc, ValueRange(newOutputs).getTypes(), newInputs, + GenericOp::create(b, loc, ValueRange(newOutputs).getTypes(), newInputs, newOutputs, newMaps, iteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); @@ -396,7 +396,7 @@ FailureOr mlir::linalg::splitReductionByScaling( utils::IteratorType::reduction; // clang-format off - auto reductionOp = b.create( + auto reductionOp = GenericOp::create(b, loc, originalOutputType, reindexedOutput, @@ -407,7 +407,7 @@ FailureOr mlir::linalg::splitReductionByScaling( Operation *clonedReductionOp = b.clone(*combinerOp); clonedReductionOp->setOperand(0, bbArgs[0]); clonedReductionOp->setOperand(1, bbArgs[1]); - b.create(loc, clonedReductionOp->getResult(0)); + linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0)); }); // clang-format on diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp index d35aad514e884..854d3855d654d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp @@ -29,7 +29,7 @@ struct SwapExtractSliceOfFill final if (!fillOp || !fillOp->hasOneUse()) return failure(); - auto newExtractOp = rewriter.create( + auto newExtractOp = tensor::ExtractSliceOp::create(rewriter, extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], extractOp.getMixedOffsets(), extractOp.getMixedSizes(), extractOp.getMixedStrides()); diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 4741afe8a417d..c2803423fafd7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -94,10 +94,10 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b, return; } - Value zero = b.create(0); - Value condition = b.create(arith::CmpIPredicate::sgt, + Value zero = arith::ConstantIndexOp::create(b, 0); + Value condition = arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, cast(value), zero); - b.create( + cf::AssertOp::create(b, condition, b.getStringAttr("expected strictly positive tile size and divisor")); } @@ -317,9 +317,9 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op, Value coveredSize = apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount, spec.highTileSize, spec.highTripCount}); - Value equals = b.create(arith::CmpIPredicate::eq, + Value equals = arith::CmpIOp::create(b, arith::CmpIPredicate::eq, coveredSize, tripCount); - b.create( + cf::AssertOp::create(b, equals, builder.getStringAttr( "could not compute dynamic multi-size tile shapes")); } @@ -656,7 +656,7 @@ FailureOr linalg::tileReductionUsingForall( getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); // 2. Create the ForallOp with an empty region. - scf::ForallOp forallOp = b.create( + scf::ForallOp forallOp = scf::ForallOp::create(b, loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors, mapping); @@ -689,7 +689,7 @@ FailureOr linalg::tileReductionUsingForall( sizes[reductionDim] = b.getIndexAttr(1); outOffsets[reductionDim] = forallOp.getInductionVars()[0]; // TODO: use SubsetExtractOpInterface once it is available. 
- tiledDpsInitOperands.push_back(b.create( + tiledDpsInitOperands.push_back(tensor::ExtractSliceOp::create(b, loc, cast(initOperand.getType()), destBbArgs[destNum], outOffsets, sizes, strides)); } @@ -768,7 +768,7 @@ FailureOr linalg::tileReductionUsingForall( // 6.b. Parallel insertions are inserted at the end of the combining // terminator. b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); - b.create( + tensor::ParallelInsertSliceOp::create(b, loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 513cecef29b61..94449fae260ee 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -44,7 +44,7 @@ static SmallVector getIndicesForAccess(OpBuilder &b, Location loc, for (auto result : indexingMap.getResults()) { AffineMap m = AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(), result); - Value v = b.create(loc, m, ivs); + Value v = affine::AffineApplyOp::create(b, loc, m, ivs); indices.push_back(v); } return indices; @@ -72,7 +72,7 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index()); auto indices = getIndicesForAccess( b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); - b.create( + memref::StoreOp::create(b, loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(), indices); } @@ -351,7 +351,7 @@ struct LinalgOpTilingInterface SmallVector indices = getIndicesForAccess( builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); Value load = - builder.create(linalgOpLoc, operand.get(), indices); + memref::LoadOp::create(builder, linalgOpLoc, operand.get(), indices); indexedValues.push_back(load); } @@ -519,10 +519,10 @@ struct LinalgOpPartialReductionInterface Type elType = 
getElementTypeOrSelf(result.getType()); Value emptyTensor = - b.create(loc, partialResultShape, elType); - Value constantOp = b.create(loc, *identity); + tensor::EmptyOp::create(b, loc, partialResultShape, elType); + Value constantOp = arith::ConstantOp::create(b, loc, *identity); auto identityTensor = - b.create(loc, constantOp, emptyTensor); + linalg::FillOp::create(b, loc, constantOp, emptyTensor); inits.push_back(identityTensor.getResult(0)); } @@ -574,7 +574,7 @@ struct LinalgOpPartialReductionInterface RankedTensorType sliceResultType = RankedTensorType::get( sliceInfo.resultShape, valueToTileType.getElementType(), valueToTileType.getEncoding()); - auto sliceOp = b.create( + auto sliceOp = tensor::ExtractSliceOp::create(b, loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes, sliceInfo.strides); tiledInits.push_back(sliceOp.getResult()); @@ -603,7 +603,7 @@ struct LinalgOpPartialReductionInterface auto resultTypes = ValueRange(tiledInits).getTypes(); if (tilingStrategy == ReductionTilingStrategy::PartialReductionOuterReduction) { - auto genericOp = b.create( + auto genericOp = GenericOp::create(b, loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes); IRMapping mapping; op->getRegion(0).cloneInto(&genericOp.getRegion(), @@ -648,7 +648,7 @@ struct LinalgOpPartialReductionInterface } } - auto reduction = b.create( + auto reduction = linalg::ReduceOp::create(b, loc, partialResult, init, partialReductionDims, [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) { // Get the combiner op. @@ -659,7 +659,7 @@ struct LinalgOpPartialReductionInterface // Combine the input at idx and output at numInits + idx. 
clonedReductionOp->setOperand(0, inputs[0]); clonedReductionOp->setOperand(1, inputs[1]); - b.create(loc, clonedReductionOp->getResult(0)); + linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0)); }); mergeOperations.push_back(reduction); @@ -790,7 +790,7 @@ struct PackOpTiling SmallVector strides(inputRank, oneAttr); SmallVector tiledOperands; - auto sourceSlice = b.create( + auto sourceSlice = tensor::ExtractSliceOp::create(b, loc, packOp.getSource(), inputIndices, inputSizes, strides); tiledOperands.push_back(sourceSlice); @@ -800,7 +800,7 @@ struct PackOpTiling return {}; strides.append(packOp.getDestRank() - inputRank, oneAttr); - auto outSlice = b.create( + auto outSlice = tensor::ExtractSliceOp::create(b, loc, packOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(outSlice); @@ -809,7 +809,7 @@ struct PackOpTiling for (auto tile : packOp.getInnerTiles()) tiledOperands.push_back(tile); - Operation *tiledPackOp = b.create( + Operation *tiledPackOp = PackOp::create(b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); return TilingResult{ @@ -968,7 +968,7 @@ struct PackOpTiling SmallVector strides(inputRank, oneAttr); SmallVector tiledOperands; - auto sourceSlice = b.create( + auto sourceSlice = tensor::ExtractSliceOp::create(b, loc, packOp.getSource(), offsets, sizes, strides); tiledOperands.push_back(sourceSlice); @@ -984,7 +984,7 @@ struct PackOpTiling return failure(); strides.append(packOp.getDestRank() - inputRank, oneAttr); - auto outSlice = b.create( + auto outSlice = tensor::ExtractSliceOp::create(b, loc, packOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(outSlice); @@ -992,7 +992,7 @@ struct PackOpTiling for (auto tile : packOp.getInnerTiles()) tiledOperands.push_back(tile); - Operation *tiledPackOp = b.create( + Operation *tiledPackOp = PackOp::create(b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); return TilingResult{ @@ -1172,7 +1172,7 @@ struct 
UnPackOpTiling sliceSrcSizes.append(unpackOp.getMixedTiles()); sliceSrcStrides.append(numInnerTiles, oneAttr); SmallVector generatedSlices; - tensor::ExtractSliceOp sliceSource = b.create( + tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create(b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes, sliceSrcStrides); generatedSlices.push_back(sliceSource); @@ -1180,12 +1180,12 @@ struct UnPackOpTiling SmallVector destStrides(destRank, oneAttr); Value sliceDest; if (isPerfectTilingCase) { - auto destSliceOp = b.create( + auto destSliceOp = tensor::ExtractSliceOp::create(b, loc, unpackOp.getDest(), offsets, sizes, destStrides); sliceDest = destSliceOp; generatedSlices.push_back(destSliceOp); } else { - sliceDest = b.create( + sliceDest = tensor::EmptyOp::create(b, loc, destExpandedSizes, unpackOp.getDestType().getElementType()); } @@ -1193,7 +1193,7 @@ struct UnPackOpTiling for (auto tile : unpackOp.getInnerTiles()) tiledOperands.push_back(tile); - Operation *tiledUnpackOp = b.create( + Operation *tiledUnpackOp = UnPackOp::create(b, loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); if (isPerfectTilingCase) @@ -1201,7 +1201,7 @@ struct UnPackOpTiling SmallVector(tiledUnpackOp->getResults()), generatedSlices}; - auto extractSlice = b.create( + auto extractSlice = tensor::ExtractSliceOp::create(b, loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); return TilingResult{ @@ -1337,13 +1337,13 @@ struct UnPackOpTiling SmallVector tiledOperands; // Create slice of the dest operand. - auto extractDestSlice = b.create( + auto extractDestSlice = tensor::ExtractSliceOp::create(b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(extractDestSlice); strides.append(unPackOp.getSourceRank() - outputRank, oneAttr); // Create slice of the source operand. 
- auto extractSourceSlice = b.create( + auto extractSourceSlice = tensor::ExtractSliceOp::create(b, loc, unPackOp.getSource(), offsets, sizes, strides); tiledOperands.insert(tiledOperands.begin(), extractSourceSlice); for (auto tile : unPackOp.getInnerTiles()) @@ -1351,7 +1351,7 @@ struct UnPackOpTiling // Create tiled unpack op. Operation *tiledUnPackOp = - b.create(loc, TypeRange{extractDestSlice.getType()}, + UnPackOp::create(b, loc, TypeRange{extractDestSlice.getType()}, tiledOperands, op->getAttrs()); return TilingResult{{tiledUnPackOp}, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index eab74dab4eb75..3526f62f770dc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -269,11 +269,11 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, packingMetadata.reassociations); Value paddingValue = packOp.getPaddingValue(); if (!paddingValue) { - paddingValue = rewriter.create( + paddingValue = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); } auto padOp = - rewriter.create(loc, collapsed, packOp.getSource(), lows, + tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows, highs, paddingValue, /*nofold=*/false); LLVM_DEBUG( @@ -313,7 +313,7 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, SmallVector sizes = tensor::getMixedSizes(rewriter, loc, packOp.getDest()); - auto insertSliceOp = rewriter.create( + auto insertSliceOp = tensor::InsertSliceOp::create(rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(), /*offsets=*/zeros, sizes, /*strides=*/ones); @@ -329,14 +329,14 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, // 5. Expand from the padded result to the stripMinedShape. 
auto expandShapeResultType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); - auto reshapeOp = rewriter.create( + auto reshapeOp = tensor::ExpandShapeOp::create(rewriter, loc, expandShapeResultType, padOp.getResult(), packingMetadata.reassociations); // 6. Transpose stripMinedShape to packedShape. SmallVector transpPerm = invertPermutationVector(packedToStripMinedShapePerm); - auto transposeOp = rewriter.create( + auto transposeOp = linalg::TransposeOp::create(rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); @@ -371,7 +371,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, SmallVector sizes(packedRank - destShape.size(), one); sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest())); - auto extractSliceOp = rewriter.create( + auto extractSliceOp = tensor::ExtractSliceOp::create(rewriter, loc, destTensorType, unPackOp.getSource(), SmallVector(packedRank, zero), sizes, SmallVector(packedRank, one)); @@ -404,9 +404,9 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, SmallVector dims = tensor::getMixedSizes(rewriter, loc, unPackOp.getSource()); applyPermutationToVector(dims, packedToStripMinedShapePerm); - auto emptyOp = rewriter.create( + auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims, stripMinedTensorType.getElementType()); - auto transposeOp = rewriter.create( + auto transposeOp = linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); LLVM_DEBUG( @@ -426,20 +426,20 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); // 4. Collapse from the stripMinedShape to the padded result. - auto reshapeOp = rewriter.create( + auto reshapeOp = tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, transposeOp->getResult(0), packingMetadata.reassociations); // 5. ExtractSlice. 
int64_t destRank = destTensorType.getRank(); - auto extractSliceOp = rewriter.create( + auto extractSliceOp = tensor::ExtractSliceOp::create(rewriter, loc, destTensorType, reshapeOp->getResult(0), SmallVector(destRank, zero), tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()), SmallVector(destRank, one)); // 6. Inject a copy to preserve DPS. - auto copyOp = rewriter.create( + auto copyOp = linalg::CopyOp::create(rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest()); // 7. Replace unPackOp by copyOp. @@ -554,15 +554,15 @@ FailureOr linalg::pack(RewriterBase &rewriter, operandType.getShape(), innerPos, cast(dest.getType()).getShape(), {}, innerPackSizes)) { - packOps.push_back(rewriter.create( + packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest, innerPos, innerPackSizes)); } else { // TODO: value of the padding attribute should be determined by // consumers. auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); - Value zero = rewriter.create(loc, zeroAttr); - packOps.push_back(rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); + packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest, innerPos, innerPackSizes, zero)); } inputsAndInits.push_back(packOps.back()); @@ -574,7 +574,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); ValueRange inits = ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); - auto packedLinalgOp = rewriter.create( + auto packedLinalgOp = linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, iteratorTypes); packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); @@ -589,7 +589,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, continue; } // Build the symmetrical UnPackOp to the existing PackOp. 
- unPackOps.push_back(rewriter.create( + unPackOps.push_back(linalg::UnPackOp::create(rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); results.push_back(unPackOps.back()); @@ -655,7 +655,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace( operands[opOperand.getOperandNumber()] = transposedValue; ValueRange operandsRef(operands); - auto transposedGenericOp = rewriter.create( + auto transposedGenericOp = linalg::GenericOp::create(rewriter, /*location=*/linalgOp->getLoc(), /*resultTensorTypes=*/ operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(), @@ -904,7 +904,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { b.setInsertionPointToStart( &op->getParentOfType().getBody().front()); return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { - Value v = b.create(op->getLoc(), s); + Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s); return v; })); }; @@ -926,11 +926,11 @@ Value DecomposePadOpPattern::createFillOrGenerateOp( // Move the padding value defined inside the PadOp block to outside. if (padValue.getParentBlock() == &padOp.getRegion().front()) rewriter.moveOpBefore(padValue.getDefiningOp(), padOp); - return rewriter.create(padOp.getLoc(), padValue, dest).result(); + return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result(); } // Fill could not be optimized: Lower to tensor::GenerateOp with region. - auto generateOp = rewriter.create( + auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(), padOp.getResultType(), dynSizes); // Copy region to new op. IRMapping bvm; @@ -970,7 +970,7 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp, } // Init tensor and fill it with padding. 
- Value emptyTensor = rewriter.create( + Value emptyTensor = tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); @@ -1222,11 +1222,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( applyPermutationToVector(transShapeForEmptyOp, srcPermForTranspose); - Value empty = rewriter.create( + Value empty = tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, packOp.getSourceType().getElementType()); // 2.2 Create linalg.transpose - auto transposedOp = rewriter.create(loc, input, empty, + auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); // 3. Insert the inner tile to the destination: @@ -1246,7 +1246,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } // 4. Replace tensor.packOp with tensor.insert_slice created above - auto insert = rewriter.create( + auto insert = tensor::InsertSliceOp::create(rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, writeSizes, writeStrides); rewriter.replaceOp(packOp, insert.getResult()); @@ -1313,7 +1313,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // outer-untiled-dims if (ShapedType::isDynamic(srcShape[i])) { OpFoldResult dynamicDim = - rewriter.create(loc, source, i).getResult(); + tensor::DimOp::create(rewriter, loc, source, i).getResult(); extractSliceSizes.push_back(dynamicDim); shapeForEmptyOp.push_back(dynamicDim); } else { @@ -1340,7 +1340,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( readShapeForExtractSlice.append(tileShape.begin(), tileShape.end()); Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); - Value innerTile = rewriter.create( + Value innerTile = tensor::ExtractSliceOp::create(rewriter, loc, readType, unpackOp.getSource(), 
extractSliceOffsets, extractSliceSizes, extractSliceStrides); @@ -1352,9 +1352,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( applyPermutationToVector(shapeForEmptyOp, perm); Value empty = - rewriter.create(loc, shapeForEmptyOp, elemType); + tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType); auto transposedOp = - rewriter.create(loc, innerTile, empty, perm); + linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm); // 3. Handle in-complete tiles if needed. It truncates trailing data from the // transposed tile. @@ -1369,7 +1369,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i)); } - auto partialTile = rewriter.create( + auto partialTile = tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); // 4. Insert the result to the destination tensor. @@ -1382,7 +1382,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( else writeSizes.push_back(oneIdxAttr); } - auto insert = rewriter.create( + auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, writeStrides); rewriter.replaceOp(unpackOp, insert.getResult()); @@ -1491,7 +1491,7 @@ FailureOr DownscaleSizeOneWindowed2DConvolution:: dilations.erase(dilations.begin() + (removeH ? 0 : 1)); auto dilationsAttr = rewriter.getI64VectorAttr(dilations); - auto conv1DOp = rewriter.create( + auto conv1DOp = Conv1DOp::create(rewriter, loc, newOutputType, ValueRange{newInput, newKernel}, ValueRange{newOutput}, stridesAttr, dilationsAttr); @@ -1578,7 +1578,7 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( dilations.erase(dilations.begin() + (removeH ? 
0 : 1)); auto dilationsAttr = rewriter.getI64VectorAttr(dilations); - auto conv1DOp = rewriter.create( + auto conv1DOp = DepthwiseConv1DNwcWcOp::create(rewriter, loc, newOutputType, ValueRange{newInput, newKernel}, ValueRange{newOutput}, stridesAttr, dilationsAttr); @@ -1635,7 +1635,7 @@ DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( rewriter, loc, output, newOutputType); - auto conv1DOp = rewriter.create(loc, newOutputType, + auto conv1DOp = Conv1DOp::create(rewriter, loc, newOutputType, ValueRange{newInput, newKernel}, ValueRange{newOutput}); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp index 092aecceef6b3..b933a2d143925 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp @@ -67,7 +67,7 @@ FailureOr transposeConv2DHelper(RewriterBase &rewriter, Value input; if (isTensorOp) { - input = rewriter.create(loc, newFilterShape, elementTy) + input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy) .getResult(); } else { input = rewriter @@ -78,7 +78,7 @@ FailureOr transposeConv2DHelper(RewriterBase &rewriter, // We can then construct the transposition on our filter. 
auto transpose = - rewriter.create(loc, filter, input, filterPerm); + linalg::TransposeOp::create(rewriter, loc, filter, input, filterPerm); Value newFilter; if (isTensorOp) { @@ -98,7 +98,7 @@ FailureOr transposeConv2DHelper(RewriterBase &rewriter, resultTy.push_back(op->getResult(0).getType()); } auto newConv = - rewriter.create(loc, resultTy, newInputs, op.getOutputs(), + HWCFConvOp::create(rewriter, loc, resultTy, newInputs, op.getOutputs(), op.getStrides(), op.getDilations()); rewriter.replaceOp(op, newConv); return newConv.getOperation(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index 934781d1cab75..6b1e89c1594d1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -47,24 +47,24 @@ FailureOr mlir::linalg::transposeMatmul(RewriterBase &rewriter, SmallVector dynamicDims; if (type.isDynamicDim(1)) - dynamicDims.push_back(rewriter.create(loc, input, 1)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1)); if (type.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); ArrayRef shape = type.getShape(); - Value empty = rewriter.create( + Value empty = tensor::EmptyOp::create(rewriter, loc, ArrayRef{shape[1], shape[0]}, type.getElementType(), dynamicDims); - auto transposeOp = rewriter.create( + auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty, ArrayRef{1, 0}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = rewriter.create( + newMatmulOp = linalg::MatmulTransposeAOp::create(rewriter, loc, matmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]}, matmulOp.getOutputs()); } else { - newMatmulOp = rewriter.create( + newMatmulOp = linalg::MatmulTransposeBOp::create(rewriter, loc, matmulOp.getResultTypes(), 
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)}, matmulOp.getOutputs()); @@ -102,26 +102,26 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, SmallVector dynamicDims; if (type.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); if (type.isDynamicDim(2)) - dynamicDims.push_back(rewriter.create(loc, input, 2)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2)); if (type.isDynamicDim(1)) - dynamicDims.push_back(rewriter.create(loc, input, 1)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1)); ArrayRef shape = type.getShape(); - Value empty = rewriter.create( + Value empty = tensor::EmptyOp::create(rewriter, loc, ArrayRef{shape[0], shape[2], shape[1]}, type.getElementType(), dynamicDims); - auto transposeOp = rewriter.create( + auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty, ArrayRef{0, 2, 1}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = rewriter.create( + newMatmulOp = linalg::BatchMatmulTransposeAOp::create(rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]}, batchMatmulOp.getOutputs()); } else { - newMatmulOp = rewriter.create( + newMatmulOp = linalg::BatchMatmulTransposeBOp::create(rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)}, batchMatmulOp.getOutputs()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 5a8c5eab3f444..2fff28ffc7474 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -119,7 +119,7 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, SmallVector strides = {1}; for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += 
wSizeStep) { - result.push_back(rewriter.create( + result.push_back(vector::ExtractStridedSliceOp::create(rewriter, loc, input, /*offsets=*/ArrayRef{w + kw}, sizes, strides)); } } @@ -130,7 +130,7 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, SmallVector strides = {1, 1, 1}; for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create( + result.push_back(vector::ExtractStridedSliceOp::create(rewriter, loc, input, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, sizes, strides)); @@ -149,7 +149,7 @@ static SmallVector extractConvFilterSlices(RewriterBase &rewriter, // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for // non-chanelled convolution] @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - result.push_back(rewriter.create( + result.push_back(vector::ExtractOp::create(rewriter, loc, filter, /*offsets=*/ArrayRef{kw})); } return result; @@ -167,7 +167,7 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, SmallVector sizes = {wSizeStep}; SmallVector strides = {1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create( + result.push_back(vector::ExtractStridedSliceOp::create(rewriter, loc, res, /*offsets=*/ArrayRef{w}, sizes, strides)); } } else { @@ -176,7 +176,7 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, SmallVector sizes = {nSize, wSizeStep, fSize}; SmallVector strides = {1, 1, 1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create( + result.push_back(vector::ExtractStridedSliceOp::create(rewriter, loc, res, /*offsets=*/ArrayRef{0, w, 0}, sizes, strides)); } } @@ -194,7 +194,7 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, // This does not depend on kw. 
SmallVector strides = {1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = rewriter.create( + res = vector::InsertStridedSliceOp::create(rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef{w}, strides); } } else { @@ -202,7 +202,7 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, // convolution. This does not depend on kw. SmallVector strides = {1, 1, 1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = rewriter.create( + res = vector::InsertStridedSliceOp::create(rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef{0, w, 0}, strides); } @@ -337,7 +337,7 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) { // Create constant index op for static dimensions. - iterSpaceValueSizes.push_back(rewriter.create( + iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); continue; } @@ -351,9 +351,9 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, return failure(); Value dynamicDim = linalgOp.hasPureTensorSemantics() - ? (Value)rewriter.create( + ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand, operandDimPos) - : (Value)rewriter.create( + : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand, operandDimPos); iterSpaceValueSizes.push_back(dynamicDim); } @@ -474,7 +474,7 @@ Value VectorizationState::getOrCreateMaskFor( "Masked 0-d vectors are not supported yet"); // Create the mask based on the dimension values. 
- Value mask = rewriter.create(linalgOp.getLoc(), + Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(), maskType, upperBounds); LDBG("Creating new mask: " << mask << "\n"); activeMaskCache[maskingMap] = mask; @@ -643,7 +643,7 @@ static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, ArrayRef dimsToMask) { auto maybeKind = getCombinerOpKind(reduceOp); assert(maybeKind && "Failed precondition: could not get reduction kind"); - return b.create( + return vector::MultiDimReductionOp::create(b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); } @@ -689,17 +689,17 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); SmallVector indices(linalgOp.getRank(outputOperand), - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType); assert(value.getType() == vectorType && "Incorrect type"); - write = rewriter.create( + write = vector::TransferWriteOp::create(rewriter, loc, value, outputOperand->get(), indices, writeMap); } else { // 0-d case is still special: do not invert the reindexing writeMap. 
if (!isa(value.getType())) - value = rewriter.create(loc, vectorType, value); + value = vector::BroadcastOp::create(rewriter, loc, vectorType, value); assert(value.getType() == vectorType && "Incorrect type"); - write = rewriter.create( + write = vector::TransferWriteOp::create(rewriter, loc, value, outputOperand->get(), ValueRange{}); } @@ -778,7 +778,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, auto indexVectorType = VectorType::get({targetShape[dim]}, rewriter.getIndexType(), state.getScalableVecDims()[dim]); - auto indexSteps = rewriter.create(loc, indexVectorType); + auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType); // Return the one-dimensional index vector if it lives in the trailing // dimension of the iteration space since the vectorization algorithm in this // case can handle the broadcast. @@ -793,14 +793,14 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, auto permMap = AffineMap::getPermutationMap(permPattern, linalgOp.getContext()); - auto broadCastOp = rewriter.create( + auto broadCastOp = vector::BroadcastOp::create(rewriter, loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps); SmallVector transposition = llvm::to_vector<16>(llvm::seq(0, linalgOp.getNumLoops())); std::swap(transposition.back(), transposition[dim]); auto transposeOp = - rewriter.create(loc, broadCastOp, transposition); + vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition); return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp}; } @@ -853,19 +853,19 @@ static Value calculateGatherOffset(RewriterBase &rewriter, const size_t numIndices = extractOp.getIndices().size(); for (size_t i = 1; i < numIndices; i++) { - Value dimIdx = rewriter.create(loc, i); + Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i); auto dimSize = broadcastIfNeeded( rewriter, - rewriter.create(loc, extractOp.getTensor(), dimIdx), + 
tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx), indexVecType); - offset = rewriter.create(loc, offset, dimSize); + offset = arith::MulIOp::create(rewriter, loc, offset, dimSize); auto extractOpIndex = broadcastIfNeeded( rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType); - offset = rewriter.create(loc, extractOpIndex, offset); + offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset); } return offset; @@ -1110,18 +1110,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, // Compute the static loop sizes of the extract op. auto resultType = state.getCanonicalVecType(extractOp.getResult().getType()); - auto maskConstantOp = rewriter.create( + auto maskConstantOp = arith::ConstantOp::create(rewriter, loc, DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()), /*value=*/true)); auto passThruConstantOp = - rewriter.create(loc, rewriter.getZeroAttr(resultType)); + arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(resultType)); // Base indices are currently set to 0. We will need to re-visit if more // generic scenarios are to be supported. 
SmallVector baseIndices( extractOp.getIndices().size(), - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); VectorMemoryAccessKind memAccessKind = getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType); @@ -1131,7 +1131,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); // Generate the gather load - Operation *gatherOp = rewriter.create( + Operation *gatherOp = vector::GatherOp::create(rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset, maskConstantOp, passThruConstantOp); gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); @@ -1166,13 +1166,13 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, continue; } - auto indexAs1dVector = rewriter.create( + auto indexAs1dVector = vector::ShapeCastOp::create(rewriter, loc, VectorType::get(resultType.getShape().back(), rewriter.getIndexType(), resultType.getScalableDims().back()), idx); transferReadIdxs.push_back( - rewriter.create(loc, indexAs1dVector, 0)); + vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0)); } // `tensor.extract_element` is always in-bounds, hence the following holds. @@ -1186,7 +1186,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, SmallVector exprs(dstRank, getAffineConstantExpr(0, ctx)); auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx); - auto transferReadOp = rewriter.create( + auto transferReadOp = vector::TransferReadOp::create(rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs, /*padding=*/std::nullopt, permutationMap, inBounds); @@ -1195,7 +1195,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, // valid here). 
SmallVector readMaskShape = {1}; auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type()); - auto allTrue = rewriter.create( + auto allTrue = vector::ConstantMaskOp::create(rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue); auto *maskedReadOp = mlir::vector::maskOperation(rewriter, transferReadOp, allTrue); @@ -1223,7 +1223,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, rankDiff--; } - auto transferReadOp = rewriter.create( + auto transferReadOp = vector::TransferReadOp::create(rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs, /*padding=*/std::nullopt, permutationMap, inBounds); @@ -1405,7 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); if (linalgOp.isScalar(opOperand)) { @@ -1435,7 +1435,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, SmallVector indices(linalgOp.getShape(opOperand).size(), zero); - Operation *read = rewriter.create( + Operation *read = vector::TransferReadOp::create(rewriter, loc, readType, opOperand->get(), indices, /*padding=*/std::nullopt, readMap); read = state.maskOperation(rewriter, read, linalgOp, indexingMap); @@ -1452,7 +1452,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, // 3.c. Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. 
if (readType.getRank() == 0) - readValue = rewriter.create(loc, readValue, + readValue = vector::ExtractOp::create(rewriter, loc, readValue, ArrayRef()); LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue @@ -1660,13 +1660,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, writeIndices.size() == static_cast(destRank)) && "Invalid number of write indices!"); if (writeIndices.empty()) { - auto zero = builder.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); writeIndices.assign(destRank, zero); } // Generate the xfer_write Op Operation *write = - builder.create(loc, + vector::TransferWriteOp::create(builder, loc, /*vector=*/vecToStore, /*source=*/dest, /*indices=*/writeIndices, @@ -1742,7 +1742,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, Location loc = packOp.getLoc(); auto padValue = packOp.getPaddingValue(); if (!padValue) { - padValue = rewriter.create( + padValue = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType())); } ReifiedRankedShapedTypeDims reifiedReturnShapes; @@ -1782,16 +1782,16 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), packOp.getDestType().getElementType()); auto shapeCastOp = - rewriter.create(loc, tiledPackType, maskedRead); + vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead); // Create TransposeOp. auto destPermutation = invertPermutationVector(getPackInverseDestPerm(packOp)); - auto transposeOp = rewriter.create( + auto transposeOp = vector::TransposeOp::create(rewriter, loc, shapeCastOp.getResult(), destPermutation); // Create TransferWriteOp. 
- Value dest = rewriter.create( + Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], transposeOp.getResult().getType().getElementType()); Operation *write = @@ -1892,7 +1892,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, } Location loc = unpackOp->getLoc(); - auto padValue = rewriter.create( + auto padValue = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); // Read result, mask if necessary. If transferReadOp shape is not equal @@ -1911,7 +1911,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, RankedTensorType stripMineTensorType = RankedTensorType::get(stripMineShape, stripMineElemType); // Transpose the appropriate rows to match output. - vector::TransposeOp transposeOp = rewriter.create( + vector::TransposeOp transposeOp = vector::TransposeOp::create(rewriter, loc, readResult, lastDimToInsertPosPerm); // Collapse the vector to the size required by result. @@ -1919,7 +1919,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, stripMineTensorType, packMetadata.reassociations); mlir::VectorType vecCollapsedType = VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); - vector::ShapeCastOp shapeCastOp = rewriter.create( + vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(rewriter, loc, vecCollapsedType, transposeOp->getResult(0)); // writeVectorSizes had to match the shapecast shape for dynamic sizes, @@ -1928,7 +1928,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, unpackOp.getDestType().hasStaticShape() ? 
vectorSizes : shapeCastOp.getResultVectorType().getShape()); - Value dest = rewriter.create( + Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedRetShapes[0], shapeCastOp.getResult().getType().getElementType()); Operation *write = createWriteOrMaskedWrite( @@ -1963,7 +1963,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, /*useInBoundsInsteadOfMasking=*/false); // Create Xfer write Op - Value dest = rewriter.create( + Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], padOp.getResultType().getElementType()); Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest); newResults.push_back(write->getResult(0)); @@ -2634,19 +2634,19 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, auto writeType = VectorType::get(dstType.getShape(), dstElementType); Location loc = copyOp->getLoc(); - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector indices(srcType.getRank(), zero); - Value readValue = rewriter.create( + Value readValue = vector::TransferReadOp::create(rewriter, loc, readType, copyOp.getSource(), indices, /*padding=*/std::nullopt, rewriter.getMultiDimIdentityMap(srcType.getRank())); if (cast(readValue.getType()).getRank() == 0) { readValue = - rewriter.create(loc, readValue, ArrayRef()); - readValue = rewriter.create(loc, writeType, readValue); + vector::ExtractOp::create(rewriter, loc, readValue, ArrayRef()); + readValue = vector::BroadcastOp::create(rewriter, loc, writeType, readValue); } - Operation *writeValue = rewriter.create( + Operation *writeValue = vector::TransferWriteOp::create(rewriter, loc, readValue, copyOp.getTarget(), indices, rewriter.getMultiDimIdentityMap(srcType.getRank())); rewriter.replaceOp(copyOp, writeValue->getResults()); @@ -2957,7 +2957,7 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, if (!padValue) { auto elemType = sourceType.getElementType(); - 
padValue = rewriter.create( + padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType)); } @@ -2989,7 +2989,7 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, // Create read SmallVector readIndices( - vecType.getRank(), rewriter.create(loc, 0)); + vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( rewriter, loc, source, vecType.getShape(), padValue, /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); @@ -3076,8 +3076,8 @@ struct PadOpVectorizationWithInsertSlicePattern // Generate TransferReadOp: Read entire source tensor and add high // padding. SmallVector readIndices( - vecRank, rewriter.create(padOp.getLoc(), 0)); - auto read = rewriter.create( + vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0)); + auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue); // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at @@ -3212,7 +3212,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( // When forwarding to vector.transfer_read, the attribute must be reset // conservatively. auto vectorType = xferOp.getVectorType(); - Value res = rewriter.create( + Value res = vector::TransferReadOp::create(rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), rewriter.getBoolArrayAttr( @@ -3271,7 +3271,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( // When forwarding to vector.transfer_write, the attribute must be reset // conservatively. 
auto vector = xferOp.getVector(); - rewriter.create( + vector::TransferWriteOp::create(rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getMask(), rewriter.getBoolArrayAttr(SmallVector( @@ -3467,7 +3467,7 @@ struct Conv1DGenerator } vector::TransferWriteOp write; - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -3486,16 +3486,16 @@ struct Conv1DGenerator SmallVector resPadding(resShape.size(), zero); // Read the whole lhs, rhs and res in one shot (with zero padding). - Value lhs = rewriter.create( + Value lhs = vector::TransferReadOp::create(rewriter, loc, lhsType, lhsShaped, lhsPadding, /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); // This is needed only for Conv. Value rhs = nullptr; if (oper == ConvOperationKind::Conv) - rhs = rewriter.create( + rhs = vector::TransferReadOp::create(rewriter, loc, rhsType, rhsShaped, rhsPadding, /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); - Value res = rewriter.create( + Value res = vector::TransferReadOp::create(rewriter, loc, resType, resShaped, resPadding, /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); @@ -3511,16 +3511,16 @@ struct Conv1DGenerator // To match base vectorization case, we pre-transpose current case. // ncw -> nwc static constexpr std::array permLhs = {0, 2, 1}; - lhs = rewriter.create(loc, lhs, permLhs); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs); // fcw -> wcf static constexpr std::array permRhs = {2, 1, 0}; // This is needed only for Conv. 
if (oper == ConvOperationKind::Conv) - rhs = rewriter.create(loc, rhs, permRhs); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs); // nfw -> nwf static constexpr std::array permRes = {0, 2, 1}; - res = rewriter.create(loc, res, permRes); + res = vector::TransposeOp::create(rewriter, loc, res, permRes); break; } } @@ -3585,7 +3585,7 @@ struct Conv1DGenerator case Conv1DOpOrder::Ncw: { // nwf -> nfw static constexpr std::array perm = {0, 2, 1}; - res = rewriter.create(loc, res, perm); + res = vector::TransposeOp::create(rewriter, loc, res, perm); break; } } @@ -3609,16 +3609,16 @@ struct Conv1DGenerator cast(val.getType()).cloneWith(std::nullopt, dstElementType); if (isa(srcElementType) && isa(dstElementType)) { - return rewriter.create(loc, dstType, val); + return arith::SIToFPOp::create(rewriter, loc, dstType, val); } if (isa(srcElementType) && isa(dstElementType) && srcWidth < dstWidth) - return rewriter.create(loc, dstType, val); + return arith::ExtFOp::create(rewriter, loc, dstType, val); if (isa(srcElementType) && isa(dstElementType) && srcWidth < dstWidth) - return rewriter.create(loc, dstType, val); + return arith::ExtSIOp::create(rewriter, loc, dstType, val); assert(false && "unhandled promotion case"); return nullptr; @@ -3633,7 +3633,7 @@ struct Conv1DGenerator bindDims(ctx, n, w, f, c); lhs = promote(rewriter, loc, lhs, res.getType()); rhs = promote(rewriter, loc, rhs, res.getType()); - auto contrationOp = rewriter.create( + auto contrationOp = vector::ContractionOp::create(rewriter, loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef{par, par, par, red}); @@ -3645,7 +3645,7 @@ struct Conv1DGenerator // convolution. 
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, Value lhs, Value rhs, Value res) { - return rewriter.create( + return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); } @@ -3693,7 +3693,7 @@ struct Conv1DGenerator bindShapeDims(resShapedType, nSize, wSize); vector::TransferWriteOp write; - Value zero = rewriter.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -3736,28 +3736,28 @@ struct Conv1DGenerator cast(op).hasPureTensorSemantics(), opToMask, rewriter); Value maskOp = - rewriter.create(loc, maskType, mixedDims); + vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims); return mlir::vector::maskOperation(rewriter, opToMask, maskOp); }; // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, // 0]. - Value lhs = rewriter.create( + Value lhs = vector::TransferReadOp::create(rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); auto maybeMaskedLhs = maybeMaskXferOp( lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); // Read rhs slice of size {kw, c} @ [0, 0]. - Value rhs = rewriter.create( + Value rhs = vector::TransferReadOp::create(rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); auto maybeMaskedRhs = maybeMaskXferOp( rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); // Read res slice of size {n, w, c} @ [0, 0, 0]. 
- Value res = rewriter.create( + Value res = vector::TransferReadOp::create(rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); auto maybeMaskedRes = maybeMaskXferOp( @@ -3775,7 +3775,7 @@ struct Conv1DGenerator // @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(rewriter.create( + lhsVals.push_back(vector::ExtractStridedSliceOp::create(rewriter, loc, maybeMaskedLhs->getResult(0), /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, inOutSliceSizes, inOutStrides)); @@ -3783,13 +3783,13 @@ struct Conv1DGenerator } // Extract rhs slice of size {c} @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(rewriter.create( + rhsVals.push_back(vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0), /*offsets=*/ArrayRef{kw})); } // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(rewriter.create( + resVals.push_back(vector::ExtractStridedSliceOp::create(rewriter, loc, maybeMaskedRes->getResult(0), /*offsets=*/ArrayRef{0, w, 0}, inOutSliceSizes, inOutStrides)); @@ -3815,16 +3815,16 @@ struct Conv1DGenerator if (flatten) { // Flatten the input and output vectors (collapse the channel // dimension) - lhsVal = rewriter.create( + lhsVal = vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]); - resVal = rewriter.create( + resVal = vector::ShapeCastOp::create(rewriter, loc, resTypeAfterFlattening, resVals[w]); } resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal, rhsVals[kw], resVal, flatten); if (flatten) { // Un-flatten the output vector (restore the channel dimension) - resVals[w] = rewriter.create( + resVals[w] = vector::ShapeCastOp::create(rewriter, loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]); } } @@ -3843,7 +3843,7 @@ struct 
Conv1DGenerator // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. // This does not depend on kw. for (int64_t w = 0; w < wSize; w += wSizeStep) { - maybeMaskedRes = rewriter.create( + maybeMaskedRes = vector::InsertStridedSliceOp::create(rewriter, loc, resVals[w], maybeMaskedRes->getResult(0), /*offsets=*/ArrayRef{0, w, 0}, /*strides=*/ArrayRef{1, 1, 1}); @@ -3853,7 +3853,7 @@ struct Conv1DGenerator //===------------------------------------------------------------------===// // Write back res slice of size {n, w, c} @ [0, 0, 0]. - Operation *resOut = rewriter.create( + Operation *resOut = vector::TransferWriteOp::create(rewriter, loc, maybeMaskedRes->getResult(0), resShaped, ValueRange{zero, zero, zero}); return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(), @@ -3891,10 +3891,10 @@ struct Conv1DGenerator indices.push_back(j); } - rhs = rewriter.create(loc, rhs, rhs, indices); + rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices); } // Broadcast the filter to match the output vector - rhs = rewriter.create( + rhs = vector::BroadcastOp::create(rewriter, loc, resTy.clone(rhsTy.getElementType()), rhs); rhs = promote(rewriter, loc, rhs, resTy); @@ -3903,10 +3903,10 @@ struct Conv1DGenerator return nullptr; if (isa(resTy.getElementType())) - return rewriter.create(loc, lhs, rhs, res); + return vector::FMAOp::create(rewriter, loc, lhs, rhs, res); - auto mul = rewriter.create(loc, lhs, rhs); - return rewriter.create(loc, mul, res); + auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs); + return arith::AddIOp::create(rewriter, loc, mul, res); } /// Entry point for non-channeled convolution: diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index 9fd084487e3fd..f9c6aae9374a8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -201,7 +201,7 @@ Value create2DTransformMatrix(OpBuilder 
&builder, Location loc, TransformMatrix transform, Type type) { ArrayRef constVec(transform.table, transform.rows * transform.cols); - return builder.create( + return arith::ConstantOp::create(builder, loc, DenseFPElementsAttr::get( RankedTensorType::get( SmallVector{transform.rows, transform.cols}, type), @@ -233,7 +233,7 @@ Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source, auto extractFilterType = RankedTensorType::get({extractHeight, extractWidth}, elementType); - auto extractFilterOp = builder.create( + auto extractFilterOp = tensor::ExtractSliceOp::create(builder, loc, extractFilterType, source, offsets, sizes, strides); return extractFilterOp; @@ -267,7 +267,7 @@ Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source, SmallVector strides(srcSize, oneIndex); auto extractFilterType = RankedTensorType::get({height, width}, elementType); - auto extractFilterOp = builder.create( + auto extractFilterOp = tensor::ExtractSliceOp::create(builder, loc, extractFilterType, source, offsets, sizes, strides); return extractFilterOp; @@ -293,7 +293,7 @@ Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, retSizes[widthIdx] = builder.getIndexAttr(width); SmallVector strides(destSize, oneIndex); - auto insertSliceOp = builder.create( + auto insertSliceOp = tensor::InsertSliceOp::create(builder, loc, source, dest, retOffsets, retSizes, strides); return insertSliceOp; @@ -321,7 +321,7 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, retSizes[widthIdx] = builder.getIndexAttr(width); SmallVector strides(destSize, oneIndex); - auto insertSliceOp = builder.create( + auto insertSliceOp = tensor::InsertSliceOp::create(builder, loc, source, dest, retOffsets, retSizes, strides); return insertSliceOp; @@ -372,7 +372,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, if (filterW != r && filterW != 1) return Value(); - Value zeroIdx = rewriter.create(loc, 0); + Value zeroIdx = 
arith::ConstantIndexOp::create(rewriter, loc, 0); auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) -> scf::ValueVector { Value FIter = ivs[0]; @@ -386,7 +386,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, int64_t retRows = 1; Value matmulRetValue = extractFilter; - Value zero = builder.create( + Value zero = arith::ConstantOp::create(builder, loc, rewriter.getZeroAttr(elementType)); if (leftTransform) { // Get constant transform matrix G. @@ -401,11 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, builder .create(loc, matmulType.getShape(), elementType) .getResult(); - auto init = builder.create(loc, zero, empty).getResult(0); + auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType); // Multiply G x g. - auto matmulOp = builder.create( + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -423,11 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, builder .create(loc, matmulType.getShape(), elementType) .getResult(); - auto init = builder.create(loc, zero, empty).getResult(0); + auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType); // Multiply u = (G x g) x GT. 
- auto matmulOp = builder.create( + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -445,9 +445,9 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, return {insertSliceOp}; }; - auto fUpperBound = rewriter.create(loc, filterF); - auto cUpperBound = rewriter.create(loc, filterC); - auto oneStep = rewriter.create(loc, 1); + auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterF); + auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterC); + auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1); scf::LoopNest loops = scf::buildLoopNest( rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound}, {oneStep, oneStep}, {retValue}, buildBody); @@ -516,9 +516,9 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, auto identityAffineMap = rewriter.getMultiDimIdentityMap(1); auto affineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); - Value heightOffset = builder.create( + Value heightOffset = affine::AffineApplyOp::create(builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter); - Value widthOffset = builder.create( + Value widthOffset = affine::AffineApplyOp::create(builder, loc, rightTransform ? affineMap : identityAffineMap, tileWIter); // Extract (H, W) from (N, H, W, C). @@ -530,7 +530,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, int64_t retRows = 1; int64_t retCols = 1; Value matmulRetValue = extractInput; - Value zero = builder.create( + Value zero = arith::ConstantOp::create(builder, loc, rewriter.getZeroAttr(elementType)); if (leftTransform) { // Get constant transform matrix BT. 
@@ -545,12 +545,12 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, builder .create(loc, matmulType.getShape(), elementType) .getResult(); - auto init = builder.create(loc, zero, empty).getResult(0); + auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value BT = create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); // Multiply BT x d. - auto matmulOp = builder.create( + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -568,11 +568,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, builder .create(loc, matmulType.getShape(), elementType) .getResult(); - auto init = builder.create(loc, zero, empty).getResult(0); + auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value B = create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); // Multiply v = (BT x d) x B. 
- auto matmulOp = builder.create( + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -586,12 +586,12 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, return {combinedVal}; }; - auto zeroIdx = rewriter.create(loc, 0); - auto tileHBound = rewriter.create(loc, tileH); - auto tileWBound = rewriter.create(loc, tileW); - auto nUpperBound = rewriter.create(loc, inputN); - auto cUpperBound = rewriter.create(loc, inputC); - auto oneStep = rewriter.create(loc, 1); + auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH); + auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW); + auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputN); + auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputC); + auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1); scf::LoopNest loops = scf::buildLoopNest( rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, {tileHBound, tileWBound, nUpperBound, cUpperBound}, @@ -629,7 +629,7 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]}, filterElementType); SmallVector filterReassoc = {{0, 1}, {2}, {3}}; - Value collapseFilter = rewriter.create( + Value collapseFilter = tensor::CollapseShapeOp::create(rewriter, loc, filterReassocType, transformedFilter, filterReassoc); // Convert (alphaH, alphaW, tileH, tileW, N, C) to @@ -643,7 +643,7 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]}, inputElementType); SmallVector inputReassoc = {{0, 1}, {2, 3, 4}, {5}}; - Value collapseInput = rewriter.create( + Value collapseInput = tensor::CollapseShapeOp::create(rewriter, loc, inputReassocType, transformedInput, 
inputReassoc); // Batched matrix multiply. @@ -655,11 +655,11 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, .create(loc, matmulType.getShape(), outputElementType) .getResult(); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(outputElementType)); - Value init = rewriter.create(loc, zero, empty).getResult(0); + Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0); - auto matmulOp = rewriter.create( + auto matmulOp = linalg::BatchMatmulOp::create(rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}), ValueRange{init}); @@ -670,7 +670,7 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], inputShape[3], inputShape[4], filterShape[3]}, outputElementType); - auto expandOutput = rewriter.create( + auto expandOutput = tensor::ExpandShapeOp::create(rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc); return expandOutput; } @@ -750,15 +750,15 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, int64_t retRows = leftTransform ? ATMatrix.rows : 1; Value matmulRetValue = extractValue; - Value zero = builder.create( + Value zero = arith::ConstantOp::create(builder, loc, rewriter.getZeroAttr(elementType)); auto identityAffineMap = rewriter.getMultiDimIdentityMap(1); auto affineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); - Value heightOffset = builder.create( + Value heightOffset = affine::AffineApplyOp::create(builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter); - Value widthOffset = builder.create( + Value widthOffset = affine::AffineApplyOp::create(builder, loc, rightTransform ? 
affineMap : identityAffineMap, tileWIter); Value outInitVal = @@ -775,12 +775,12 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, .create(loc, matmulType.getShape(), elementType) .getResult(); - init = builder.create(loc, zero, empty).getResult(0); + init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); } Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType); // Multiply AT x m. - auto matmulOp = builder.create( + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -794,19 +794,19 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, .create(loc, matmulType.getShape(), elementType) .getResult(); - init = builder.create(loc, zero, empty).getResult(0); + init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); } Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType); // Multiply y = (AT x m) x A. - auto matmulOp = builder.create( + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } if (scalarFactor != 1) { // Multiply by scalar factor and add outInitVal. 
- Value scalarFactorValue = builder.create( + Value scalarFactorValue = arith::ConstantOp::create(builder, loc, FloatAttr::get(elementType, scalarFactor)); auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); @@ -824,11 +824,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, utils::IteratorType::parallel}, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - auto mulf = nestedBuilder.create( + auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc, args[0], args[1]); - auto addf = nestedBuilder.create( + auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc, mulf.getResult(), args[2]); - nestedBuilder.create(nestedLoc, + linalg::YieldOp::create(nestedBuilder, nestedLoc, addf.getResult()); }) .getResult(0); @@ -847,12 +847,12 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, int64_t tilwH = valueShape[2]; int64_t tileW = valueShape[3]; - auto zeroIdx = rewriter.create(loc, 0); - auto tileHBound = rewriter.create(loc, tilwH); - auto tileWBound = rewriter.create(loc, tileW); - auto nUpperBound = rewriter.create(loc, valueN); - auto fUpperBound = rewriter.create(loc, valueF); - auto oneStep = rewriter.create(loc, 1); + auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tilwH); + auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW); + auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueN); + auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueF); + auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1); scf::LoopNest loops = scf::buildLoopNest( rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, {tileHBound, tileWBound, nUpperBound, fUpperBound}, @@ -867,7 +867,7 @@ static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, auto valueType = 
cast(value.getType()); Type elementType = valueType.getElementType(); auto alignedType = RankedTensorType::get(alignedShape, elementType); - Value padValue = rewriter.create( + Value padValue = arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value, @@ -887,7 +887,7 @@ static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, SmallVector sizes = getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape)); - return rewriter.create(loc, extractedType, value, + return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value, offsets, sizes, strides); } @@ -979,9 +979,9 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, int64_t tileW = llvm::divideCeilSigned(outputW, widthM); auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF}, filterElementType); - Value retValue = rewriter.create(loc, retType.getShape(), + Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(), filterElementType); - auto transformedFilter = rewriter.create( + auto transformedFilter = linalg::WinogradFilterTransformOp::create(rewriter, loc, retType, filter, retValue, fmr); // --- Create operation for input transform --- @@ -998,9 +998,9 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, retType = RankedTensorType::get( {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType); - retValue = rewriter.create(loc, retType.getShape(), + retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(), inputElementType); - auto transformedInput = rewriter.create( + auto transformedInput = linalg::WinogradInputTransformOp::create(rewriter, loc, retType, input, retValue, fmr); Type outputElementType = outputType.getElementType(); @@ -1023,7 +1023,7 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, outputType = alignedOutputType; } - 
Value transformedOutput = rewriter.create( + Value transformedOutput = linalg::WinogradOutputTransformOp::create(rewriter, loc, outputType, matmulRet, output, fmr); // When output size is not aligned with output tile size, extract the diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index e89e80b07d20d..d964aa5cb4e81 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -320,14 +320,14 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); SmallVector iteratorTypes(memrefTypeTo.getRank(), utils::IteratorType::parallel); - return b.create( + return linalg::GenericOp::create(b, loc, /*inputs=*/from, /*outputs=*/to, /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args.front()); + linalg::YieldOp::create(b, loc, args.front()); }); } @@ -483,7 +483,7 @@ static void generateParallelLoopNest( case DistributionMethod::None: { // Generate a single parallel loop-nest operation for all outermost // parallel loops and recurse. - b.create( + scf::ParallelOp::create(b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), steps.take_front(numProcessed), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) { @@ -499,7 +499,7 @@ static void generateParallelLoopNest( case DistributionMethod::Cyclic: { // Generate a single parallel loop-nest operation for all outermost // parallel loops and recurse. 
- b.create( + scf::ParallelOp::create(b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), steps.take_front(numProcessed), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) { @@ -519,13 +519,13 @@ static void generateParallelLoopNest( for (unsigned i = 1; i < numProcessed; ++i) cond = ab._and(cond, ab.slt(lbs[i], ubs[i])); ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed)); - b.create(loc, cond, [&](OpBuilder &b, Location loc) { + scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) { generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed), steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed), remainderProcInfo, bodyBuilderFn, ivStorage); - b.create(loc, ValueRange{}); + scf::YieldOp::create(b, loc, ValueRange{}); }); return; } @@ -595,12 +595,12 @@ static Operation *materializeTiledShape(OpBuilder &builder, Location loc, auto shapedType = dyn_cast(valueToTile.getType()); auto *sliceOp = TypeSwitch(shapedType) .Case([&](MemRefType) { - return builder.create( + return memref::SubViewOp::create(builder, loc, valueToTile, sliceParams.offsets, sliceParams.sizes, sliceParams.strides); }) .Case([&](RankedTensorType) { - return builder.create( + return tensor::ExtractSliceOp::create(builder, loc, valueToTile, sliceParams.offsets, sliceParams.sizes, sliceParams.strides); }) @@ -793,7 +793,7 @@ SmallVector insertSlicesBack(OpBuilder &builder, Location loc, // `tiledOperands`. 
Value outputTensor = operands[opOperand.getOperandNumber()]; if (auto sliceOp = outputTensor.getDefiningOp()) { - Value inserted = builder.create( + Value inserted = tensor::InsertSliceOp::create(builder, loc, sliceOp.getSource().getType(), results[resultIdx], sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp index 37ddca101f64b..877f8e72caa31 100644 --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/FunctionImplementation.h" using namespace mlir; diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp index ff6af63eee531..9dadd9bea2f67 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp @@ -135,7 +135,7 @@ struct GlobalStoreOpInterface auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); auto loc = globalStoreOp.getLoc(); - auto targetMemref = rewriter.create( + auto targetMemref = memref::GetGlobalOp::create(rewriter, loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); auto sourceMemref = diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index 7940ff60a48e7..da2317cd4eb97 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; @@ -60,8 +61,8 @@ struct 
FoldRank final : public mlir::OpRewritePattern { if (!isa(dltiAttr.value())) return op->emitError() << "Expected an integer attribute for MPI:comm_world_rank"; - Value res = b.create( - op.getLoc(), cast(dltiAttr.value()).getInt()); + Value res = arith::ConstantIndexOp::create( + b, op.getLoc(), cast(dltiAttr.value()).getInt()); if (Value retVal = op.getRetval()) b.replaceOp(op, {retVal, res}); else diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 26441a9d78658..b09a67818fff6 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include using namespace mlir; @@ -746,7 +747,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 13e2a4b5541b2..770f32841c0f6 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Maybe broadcasts scalar value into vector type compatible with `op`. auto bcast = [&](Value value) -> Value { if (auto vec = dyn_cast(op.getType())) - return rewriter.create(op.getLoc(), vec, value); + return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value); return value; }; @@ -84,14 +84,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Replace `pow(x, 3.0)` with `x * x * x`. 
if (isExponentValue(3.0)) { Value square = - rewriter.create(op.getLoc(), ValueRange({x, x})); + arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x})); rewriter.replaceOpWithNewOp(op, ValueRange({x, square})); return success(); } // Replace `pow(x, -1.0)` with `1.0 / x`. if (isExponentValue(-1.0)) { - Value one = rewriter.create( + Value one = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); rewriter.replaceOpWithNewOp(op, ValueRange({bcast(one), x})); return success(); @@ -111,8 +111,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`. if (isExponentValue(0.75)) { - Value powHalf = rewriter.create(op.getLoc(), x); - Value powQuarter = rewriter.create(op.getLoc(), powHalf); + Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x); + Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf); rewriter.replaceOpWithNewOp(op, ValueRange{powHalf, powQuarter}); return success(); @@ -168,17 +168,17 @@ PowIStrengthReduction::matchAndRewrite( // Maybe broadcasts scalar value into vector type compatible with `op`. auto bcast = [&loc, &op, &rewriter](Value value) -> Value { if (auto vec = dyn_cast(op.getType())) - return rewriter.create(loc, vec, value); + return vector::BroadcastOp::create(rewriter, loc, vec, value); return value; }; Value one; Type opType = getElementTypeOrSelf(op.getType()); if constexpr (std::is_same_v) - one = rewriter.create( + one = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(opType, 1.0)); else - one = rewriter.create( + one = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(opType, 1)); // Replace `[fi]powi(x, 0)` with `1`. @@ -208,12 +208,12 @@ PowIStrengthReduction::matchAndRewrite( // with: // (1 / x) * (1 / x) * (1 / x) * ... 
for (unsigned i = 1; i < exponentValue; ++i) - result = rewriter.create(loc, result, base); + result = MulOpTy::create(rewriter, loc, result, base); // Inverse the base for negative exponent, i.e. for // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. if (exponentIsNegative) - result = rewriter.create(loc, bcast(one), result); + result = DivOpTy::create(rewriter, loc, bcast(one), result); rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 7b5350ca26b60..419ba0d2e500b 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -33,11 +33,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value, APFloat::rmNearestTiesToEven, &losesInfo); auto attr = b.getFloatAttr(eltType, value); if (auto shapedTy = dyn_cast(type)) { - return b.create(loc, + return arith::ConstantOp::create(b, loc, DenseElementsAttr::get(shapedTy, attr)); } - return b.create(loc, attr); + return arith::ConstantOp::create(b, loc, attr); } static Value createFloatConst(Location loc, Type type, double value, @@ -50,11 +50,11 @@ static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b) { auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { - return b.create(loc, + return arith::ConstantOp::create(b, loc, DenseElementsAttr::get(shapedTy, attr)); } - return b.create(loc, attr); + return arith::ConstantOp::create(b, loc, attr); } static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { @@ -62,11 +62,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { Type i64Ty = b.getI64Type(); if (auto shapedTy = dyn_cast(opType)) i64Ty = shapedTy.clone(i64Ty); - Value fixedConvert = b.create(i64Ty, operand); - Value fpFixedConvert = b.create(opType, fixedConvert); + Value fixedConvert = 
arith::FPToSIOp::create(b, i64Ty, operand); + Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert); // The truncation does not preserve the sign when the truncated // value is -0. So here the sign is copied again. - return b.create(fpFixedConvert, operand); + return math::CopySignOp::create(b, fpFixedConvert, operand); } // sinhf(float x) -> (exp(x) - exp(-x)) / 2 @@ -75,12 +75,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { Value operand = op.getOperand(); Type opType = operand.getType(); - Value exp = b.create(operand); - Value neg = b.create(operand); - Value nexp = b.create(neg); - Value sub = b.create(exp, nexp); + Value exp = math::ExpOp::create(b, operand); + Value neg = arith::NegFOp::create(b, operand); + Value nexp = math::ExpOp::create(b, neg); + Value sub = arith::SubFOp::create(b, exp, nexp); Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value res = b.create(sub, half); + Value res = arith::MulFOp::create(b, sub, half); rewriter.replaceOp(op, res); return success(); } @@ -91,12 +91,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { Value operand = op.getOperand(); Type opType = operand.getType(); - Value exp = b.create(operand); - Value neg = b.create(operand); - Value nexp = b.create(neg); - Value add = b.create(exp, nexp); + Value exp = math::ExpOp::create(b, operand); + Value neg = arith::NegFOp::create(b, operand); + Value nexp = math::ExpOp::create(b, neg); + Value add = arith::AddFOp::create(b, exp, nexp); Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value res = b.create(add, half); + Value res = arith::MulFOp::create(b, add, half); rewriter.replaceOp(op, res); return success(); } @@ -117,23 +117,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter); // Compute sign(x) = cast(x < 0) * (-2) + 1 - Value 
isNegative = rewriter.create( + Value isNegative = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); Value isNegativeFloat = - rewriter.create(loc, floatType, isNegative); + arith::UIToFPOp::create(rewriter, loc, floatType, isNegative); Value isNegativeTimesNegTwo = - rewriter.create(loc, isNegativeFloat, negTwo); - Value sign = rewriter.create(loc, isNegativeTimesNegTwo, one); + arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo); + Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one); // Normalize input to positive value: y = sign(x) * x - Value positiveX = rewriter.create(loc, sign, op.getOperand()); + Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand()); // Decompose on normalized input - Value negDoubledX = rewriter.create(loc, negTwo, positiveX); - Value exp2x = rewriter.create(loc, negDoubledX); - Value dividend = rewriter.create(loc, one, exp2x); - Value divisor = rewriter.create(loc, one, exp2x); - Value positiveRes = rewriter.create(loc, dividend, divisor); + Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX); + Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX); + Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x); + Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x); + Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor); // Multiply result by sign(x) to retain signs from negative inputs rewriter.replaceOpWithNewOp(op, sign, positiveRes); @@ -146,9 +146,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type type = operand.getType(); - Value sin = b.create(type, operand); - Value cos = b.create(type, operand); - Value div = b.create(type, sin, cos); + Value sin = math::SinOp::create(b, type, operand); + Value cos = math::CosOp::create(b, type, operand); + 
Value div = arith::DivFOp::create(b, type, sin, cos); rewriter.replaceOp(op, div); return success(); } @@ -161,10 +161,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op, Type opType = operand.getType(); Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); - Value fma = b.create(operand, operand, one); - Value sqrt = b.create(fma); - Value add = b.create(operand, sqrt); - Value res = b.create(add); + Value fma = math::FmaOp::create(b, operand, operand, one); + Value sqrt = math::SqrtOp::create(b, fma); + Value add = arith::AddFOp::create(b, operand, sqrt); + Value res = math::LogOp::create(b, add); rewriter.replaceOp(op, res); return success(); } @@ -177,10 +177,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op, Type opType = operand.getType(); Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter); - Value fma = b.create(operand, operand, negOne); - Value sqrt = b.create(fma); - Value add = b.create(operand, sqrt); - Value res = b.create(add); + Value fma = math::FmaOp::create(b, operand, operand, negOne); + Value sqrt = math::SqrtOp::create(b, fma); + Value add = arith::AddFOp::create(b, operand, sqrt); + Value res = math::LogOp::create(b, add); rewriter.replaceOp(op, res); return success(); } @@ -193,13 +193,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op, Type opType = operand.getType(); Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); - Value add = b.create(operand, one); - Value neg = b.create(operand); - Value sub = b.create(neg, one); - Value div = b.create(add, sub); - Value log = b.create(div); + Value add = arith::AddFOp::create(b, operand, one); + Value neg = arith::NegFOp::create(b, operand); + Value sub = arith::AddFOp::create(b, neg, one); + Value div = arith::DivFOp::create(b, add, sub); + Value log = math::LogOp::create(b, div); Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value res = b.create(log, half); + Value res = arith::MulFOp::create(b, log, 
half); rewriter.replaceOp(op, res); return success(); } @@ -210,8 +210,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { Value operandB = op.getOperand(1); Value operandC = op.getOperand(2); Type type = op.getType(); - Value mult = b.create(type, operandA, operandB); - Value add = b.create(type, mult, operandC); + Value mult = arith::MulFOp::create(b, type, operandA, operandB); + Value add = arith::AddFOp::create(b, type, mult, operandC); rewriter.replaceOp(op, add); return success(); } @@ -236,11 +236,11 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); - Value gtCheck = b.create(arith::CmpFPredicate::OGT, operand, + Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand, fpFixedConvert); - Value incrValue = b.create(op->getLoc(), gtCheck, one, zero); + Value incrValue = arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero); - Value ret = b.create(opType, fpFixedConvert, incrValue); + Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue); rewriter.replaceOp(op, ret); return success(); } @@ -258,8 +258,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, auto convertFPowItoPowf = [&]() -> LogicalResult { Value castPowerToFp = - rewriter.create(op.getLoc(), baseType, power); - Value res = rewriter.create(op.getLoc(), baseType, base, + arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power); + Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base, castPowerToFp); rewriter.replaceOp(op, res); return success(); @@ -281,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, while (absPower > 0) { if (absPower & 1) - res = b.create(baseType, base, res); + res = arith::MulFOp::create(b, baseType, base, res); absPower >>= 1; - base = b.create(baseType, base, base); + base = 
arith::MulFOp::create(b, baseType, base, base); } // Make sure not to introduce UB in case of negative power. @@ -303,13 +303,13 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, createFloatConst(op->getLoc(), baseType, APFloat::getInf(sem, /*Negative=*/true), rewriter); Value zeroEqCheck = - b.create(arith::CmpFPredicate::OEQ, res, zero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero); Value negZeroEqCheck = - b.create(arith::CmpFPredicate::OEQ, res, negZero); - res = b.create(baseType, one, res); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero); + res = arith::DivFOp::create(b, baseType, one, res); res = - b.create(op->getLoc(), zeroEqCheck, posInfinity, res); - res = b.create(op->getLoc(), negZeroEqCheck, negInfinity, + arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res); + res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity, res); } @@ -331,7 +331,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { cast(getElementTypeOrSelf(typeB)).getFloatSemantics(); APFloat valueB(sem); auto mulf = [&](Value x, Value y) -> Value { - return b.create(x, y); + return arith::MulFOp::create(b, x, y); }; if (matchPattern(operandB, m_ConstantFloat(&valueB))) { if (valueB.isZero()) { @@ -348,19 +348,19 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { if (valueB.isExactlyValue(-1.0)) { // a^(-1) -> 1 / a Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); - Value div = b.create(one, operandA); + Value div = arith::DivFOp::create(b, one, operandA); rewriter.replaceOp(op, div); return success(); } if (valueB.isExactlyValue(0.5)) { // a^(1/2) -> sqrt(a) - Value sqrt = b.create(operandA); + Value sqrt = math::SqrtOp::create(b, operandA); rewriter.replaceOp(op, sqrt); return success(); } if (valueB.isExactlyValue(-0.5)) { // a^(-1/2) -> 1 / sqrt(a) - Value rsqrt = b.create(operandA); + Value rsqrt = 
math::RsqrtOp::create(b, operandA); rewriter.replaceOp(op, rsqrt); return success(); } @@ -373,7 +373,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { // a^(-2) -> 1 / (a * a) Value one = createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - Value div = b.create(one, mulf(operandA, operandA)); + Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA)); rewriter.replaceOp(op, div); return success(); } @@ -383,9 +383,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { } } - Value logA = b.create(operandA); - Value mult = b.create(operandB, logA); - Value expResult = b.create(mult); + Value logA = math::LogOp::create(b, operandA); + Value mult = arith::MulFOp::create(b, operandB, logA); + Value expResult = math::ExpOp::create(b, mult); rewriter.replaceOp(op, expResult); return success(); } @@ -400,8 +400,8 @@ static LogicalResult convertExp2fOp(math::Exp2Op op, Value operand = op.getOperand(); Type opType = operand.getType(); Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); - Value mult = b.create(opType, operand, ln2); - Value exp = b.create(op->getLoc(), mult); + Value mult = arith::MulFOp::create(b, opType, operand, ln2); + Value exp = math::ExpOp::create(b, op->getLoc(), mult); rewriter.replaceOp(op, exp); return success(); } @@ -427,8 +427,8 @@ static LogicalResult convertRoundOp(math::RoundOp op, Value c127 = createIntConst(loc, i32Ty, 127, b); Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); - Value incrValue = b.create(half, operand); - Value add = b.create(opType, operand, incrValue); + Value incrValue = math::CopySignOp::create(b, half, operand); + Value add = arith::AddFOp::create(b, opType, operand, incrValue); Value fpFixedConvert = createTruncatedFPValue(add, b); // There are three cases where adding 0.5 to the value and truncating by @@ -451,14 +451,14 @@ static LogicalResult convertRoundOp(math::RoundOp op, // i64 
leading to wrong outputs. // // All three cases satisfy the property `biasedExp >= 23`. - Value operandBitcast = b.create(i32Ty, operand); - Value operandExp = b.create( - b.create(operandBitcast, c23), expMask); - Value operandBiasedExp = b.create(operandExp, c127); + Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand); + Value operandExp = arith::AndIOp::create(b, + arith::ShRUIOp::create(b, operandBitcast, c23), expMask); + Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127); Value isSpecialValOrLargeVal = - b.create(arith::CmpIPredicate::sge, operandBiasedExp, c23); + arith::CmpIOp::create(b, arith::CmpIPredicate::sge, operandBiasedExp, c23); - Value result = b.create(isSpecialValOrLargeVal, operand, + Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, fpFixedConvert); rewriter.replaceOp(op, result); return success(); @@ -490,20 +490,20 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); Value pred = - rewriter.create(loc, arith::CmpIPredicate::ule, x, mask); - Value add = rewriter.create(loc, count, bits); - Value shift = rewriter.create(loc, x, bits); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule, x, mask); + Value add = arith::AddIOp::create(rewriter, loc, count, bits); + Value shift = arith::ShLIOp::create(rewriter, loc, x, bits); - x = rewriter.create(loc, pred, shift, x); - count = rewriter.create(loc, pred, add, count); + x = arith::SelectOp::create(rewriter, loc, pred, shift, x); + count = arith::SelectOp::create(rewriter, loc, pred, add, count); } Value zero = createIntConst(loc, operandTy, 0, rewriter); - Value pred = rewriter.create(loc, arith::CmpIPredicate::eq, + Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, operand, zero); Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); - Value sel = rewriter.create(loc, pred, bwval, count); + Value sel = 
arith::SelectOp::create(rewriter, loc, pred, bwval, count); rewriter.replaceOp(op, sel); return success(); } @@ -550,29 +550,29 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b); Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b); - Value operandBitcast = b.create(iTy, operand); - Value round = b.create(operand); - Value roundBitcast = b.create(iTy, round); + Value operandBitcast = arith::BitcastOp::create(b, iTy, operand); + Value round = math::RoundOp::create(b, operand); + Value roundBitcast = arith::BitcastOp::create(b, iTy, round); // Get biased exponents for operand and round(operand) - Value operandExp = b.create( - b.create(operandBitcast, c23), expMask); - Value operandBiasedExp = b.create(operandExp, c127); - Value roundExp = b.create( - b.create(roundBitcast, c23), expMask); - Value roundBiasedExp = b.create(roundExp, c127); + Value operandExp = arith::AndIOp::create(b, + arith::ShRUIOp::create(b, operandBitcast, c23), expMask); + Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127); + Value roundExp = arith::AndIOp::create(b, + arith::ShRUIOp::create(b, roundBitcast, c23), expMask); + Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127); auto safeShiftRight = [&](Value x, Value shift) -> Value { // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior - Value clampedShift = b.create(shift, c0); - clampedShift = b.create(clampedShift, c31); - return b.create(x, clampedShift); + Value clampedShift = arith::MaxSIOp::create(b, shift, c0); + clampedShift = arith::MinSIOp::create(b, clampedShift, c31); + return arith::ShRUIOp::create(b, x, clampedShift); }; auto maskMantissa = [&](Value mantissa, Value mantissaMaskRightShift) -> Value { Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); - return b.create(mantissa, shiftedMantissaMask); + return arith::AndIOp::create(b, mantissa, 
shiftedMantissaMask); }; // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring @@ -590,13 +590,13 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, // `biasedExp > 23`, so they get treated as large numbers with no room for // decimals, which are always even. Value roundBiasedExpEq0 = - b.create(arith::CmpIPredicate::eq, roundBiasedExp, c0); - Value roundBiasedExpMinus1 = b.create(roundBiasedExp, c1); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0); + Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1); Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); - Value roundIsNotEvenOrSpecialVal = b.create( + Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0); roundIsNotEvenOrSpecialVal = - b.create(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); + arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive // integers if the bit at index `biasedExp` starting from the left in the @@ -605,37 +605,37 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, // so these are handled separately. In particular, if `biasedExp == -1`, the // value is halfway if the entire mantissa is zero. 
- Value operandBiasedExpEqNeg1 = b.create( + Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); - Value expectedOperandMaskedMantissa = b.create( + Value expectedOperandMaskedMantissa = arith::SelectOp::create(b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); Value operandIsHalfway = - b.create(arith::CmpIPredicate::eq, operandMaskedMantissa, + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa, expectedOperandMaskedMantissa); // Ensure `biasedExp` is in the valid range for half values. - Value operandBiasedExpGeNeg1 = b.create( + Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); Value operandBiasedExpLt23 = - b.create(arith::CmpIPredicate::slt, operandBiasedExp, c23); + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, operandBiasedExp, c23); operandIsHalfway = - b.create(operandIsHalfway, operandBiasedExpLt23); + arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23); operandIsHalfway = - b.create(operandIsHalfway, operandBiasedExpGeNeg1); + arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1); // Adjust rounded operand with `round(operand) - sign(operand)` to correct the // case where `round` rounded in the opposite direction of `roundeven`. - Value sign = b.create(c1Float, operand); - Value roundShifted = b.create(round, sign); + Value sign = math::CopySignOp::create(b, c1Float, operand); + Value roundShifted = arith::SubFOp::create(b, round, sign); // If the rounded value is even or a special value, we default to the behavior // of `math.round`. 
Value needsShift = - b.create(roundIsNotEvenOrSpecialVal, operandIsHalfway); - Value result = b.create(needsShift, roundShifted, round); + arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway); + Value result = arith::SelectOp::create(b, needsShift, roundShifted, round); // The `x - sign` adjustment does not preserve the sign when we are adjusting // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is // rounded to -0.0. - result = b.create(result, operand); + result = math::CopySignOp::create(b, result, operand); rewriter.replaceOp(op, result); return success(); } @@ -657,7 +657,7 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, Location loc = op->getLoc(); auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); - auto sqrtOp = rewriter.create(loc, operand); + auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand); rewriter.replaceOpWithNewOp(op, constOneFloat, sqrtOp); return success(); } diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp index a570ed5118ef0..9d6ad613fc945 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp @@ -73,7 +73,7 @@ void mlir::math::populateExtendToSupportedTypesTypeConverter( }); typeConverter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); + auto extFOp = arith::ExtFOp::create(b, loc, target, input); extFOp.setFastmath(arith::FastMathFlags::contract); return extFOp; }); @@ -104,7 +104,7 @@ LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite( for (auto [result, newType, origType] : llvm::zip_equal( results, (*legalized)->getResultTypes(), op->getResultTypes())) { if (newType != origType) { - auto truncFOp = rewriter.create(loc, origType, result); + auto truncFOp = arith::TruncFOp::create(rewriter, 
loc, origType, result); truncFOp.setFastmath(arith::FastMathFlags::contract); result = truncFOp.getResult(); } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index a26e380232a91..2e27aba30575b 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -73,7 +73,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value, std::optional shape) { assert(!isa(value.getType()) && "must be scalar value"); auto type = broadcast(value.getType(), shape); - return shape ? builder.create(type, value) : value; + return shape ? BroadcastOp::create(builder, type, value) : value; } //----------------------------------------------------------------------------// @@ -131,7 +131,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, auto eltType = cast(operand.getType()).getElementType(); auto expandedType = VectorType::get(expandedShape, eltType); expandedOperands[i] = - builder.create(expandedType, operand); + vector::ShapeCastOp::create(builder, expandedType, operand); } } @@ -149,7 +149,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, SmallVector extracted(expandedOperands.size()); for (const auto &tuple : llvm::enumerate(expandedOperands)) extracted[tuple.index()] = - builder.create(tuple.value(), offsets); + vector::ExtractOp::create(builder, tuple.value(), offsets); results[i] = compute(extracted); } @@ -157,15 +157,15 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Stitch results together into one large vector. 
Type resultEltType = cast(results[0].getType()).getElementType(); Type resultExpandedType = VectorType::get(expandedShape, resultEltType); - Value result = builder.create( + Value result = arith::ConstantOp::create(builder, resultExpandedType, builder.getZeroAttr(resultExpandedType)); for (int64_t i = 0; i < maxIndex; ++i) - result = builder.create(results[i], result, + result = vector::InsertOp::create(builder, results[i], result, delinearize(i, strides)); // Reshape back to the original vector shape. - return builder.create( + return vector::ShapeCastOp::create(builder, VectorType::get(inputShape, resultEltType), result); } @@ -174,28 +174,28 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, //----------------------------------------------------------------------------// static Value boolCst(ImplicitLocOpBuilder &builder, bool value) { - return builder.create(builder.getBoolAttr(value)); + return arith::ConstantOp::create(builder, builder.getBoolAttr(value)); } static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType) { assert((elementType.isF16() || elementType.isF32()) && "x must be f16 or f32 type."); - return builder.create( + return arith::ConstantOp::create(builder, builder.getFloatAttr(elementType, value)); } static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { - return builder.create(builder.getF32FloatAttr(value)); + return arith::ConstantOp::create(builder, builder.getF32FloatAttr(value)); } static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { - return builder.create(builder.getI32IntegerAttr(value)); + return arith::ConstantOp::create(builder, builder.getI32IntegerAttr(value)); } static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { Value i32Value = i32Cst(builder, static_cast(bits)); - return builder.create(builder.getF32Type(), i32Value); + return arith::BitcastOp::create(builder, builder.getF32Type(), i32Value); } 
//----------------------------------------------------------------------------// @@ -204,15 +204,15 @@ static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { // Return the minimum of the two values or NaN if value is NaN static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { - return builder.create( - builder.create(arith::CmpFPredicate::ULT, value, bound), + return arith::SelectOp::create(builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound), value, bound); } // Return the maximum of the two values or NaN if value is NaN static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { - return builder.create( - builder.create(arith::CmpFPredicate::UGT, value, bound), + return arith::SelectOp::create(builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound), value, bound); } @@ -242,24 +242,24 @@ static std::pair frexp(ImplicitLocOpBuilder &builder, Value arg, Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); // Bitcast to i32 for bitwise operations. - Value i32Half = builder.create(i32, cstHalf); - Value i32InvMantMask = builder.create(i32, cstInvMantMask); - Value i32Arg = builder.create(i32Vec, arg); + Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf); + Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask); + Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg); // Compute normalized fraction. - Value tmp0 = builder.create(i32Arg, bcast(i32InvMantMask)); - Value tmp1 = builder.create(tmp0, bcast(i32Half)); - Value normalizedFraction = builder.create(f32Vec, tmp1); + Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask)); + Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half)); + Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1); // Compute exponent. - Value arg0 = isPositive ? 
arg : builder.create(arg); - Value biasedExponentBits = builder.create( - builder.create(i32Vec, arg0), + Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg); + Value biasedExponentBits = arith::ShRUIOp::create(builder, + arith::BitcastOp::create(builder, i32Vec, arg0), bcast(i32Cst(builder, 23))); Value biasedExponent = - builder.create(f32Vec, biasedExponentBits); + arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits); Value exponent = - builder.create(biasedExponent, bcast(cst126f)); + arith::SubFOp::create(builder, biasedExponent, bcast(cst126f)); return {normalizedFraction, exponent}; } @@ -279,10 +279,10 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { // Set the exponent bias to zero. auto bias = bcast(i32Cst(builder, 127)); - Value biasedArg = builder.create(arg, bias); + Value biasedArg = arith::AddIOp::create(builder, arg, bias); Value exp2ValueInt = - builder.create(biasedArg, exponetBitLocation); - Value exp2ValueF32 = builder.create(f32Vec, exp2ValueInt); + arith::ShLIOp::create(builder, biasedArg, exponetBitLocation); + Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt); return exp2ValueF32; } @@ -301,10 +301,10 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, if (coeffs.size() == 1) return coeffs[0]; - Value res = builder.create(x, coeffs[coeffs.size() - 1], + Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1], coeffs[coeffs.size() - 2]); for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { - res = builder.create(x, res, coeffs[i]); + res = math::FmaOp::create(builder, x, res, coeffs[i]); } return res; } @@ -344,9 +344,9 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); SmallVector operands; for (auto operand : op->getOperands()) - operands.push_back(rewriter.create(loc, newType, operand)); + operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand)); auto result = - 
rewriter.create(loc, TypeRange{newType}, operands, op->getAttrs()); + T::create(rewriter, loc, TypeRange{newType}, operands, op->getAttrs()); rewriter.replaceOpWithNewOp(op, origType, result); return success(); } @@ -394,18 +394,18 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, std::optional shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - Value abs = builder.create(operand); + Value abs = math::AbsFOp::create(builder, operand); auto one = broadcast(builder, f32Cst(builder, 1.0), shape); // When 0.66 < x <= 2.41 we do (x-1) / (x+1): auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape); Value cmp2 = - builder.create(arith::CmpFPredicate::OGT, abs, twoThirds); - Value addone = builder.create(abs, one); - Value subone = builder.create(abs, one); - Value xnum = builder.create(cmp2, subone, abs); - Value xden = builder.create(cmp2, addone, one); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds); + Value addone = arith::AddFOp::create(builder, abs, one); + Value subone = arith::SubFOp::create(builder, abs, one); + Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs); + Value xden = arith::SelectOp::create(builder, cmp2, addone, one); auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); @@ -414,12 +414,12 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, // Break into the <= 0.66 or > 2.41 we do x or 1/x: auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880)); Value cmp1 = - builder.create(arith::CmpFPredicate::OGT, abs, tan3pio8); - xnum = builder.create(cmp1, one, xnum); - xden = builder.create(cmp1, abs, xden); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8); + xnum = arith::SelectOp::create(builder, cmp1, one, xnum); + xden = arith::SelectOp::create(builder, cmp1, abs, xden); - Value x = builder.create(xnum, xden); - Value xx = builder.create(x, x); + Value x = arith::DivFOp::create(builder, xnum, 
xden); + Value xx = arith::MulFOp::create(builder, x, x); // Perform the Taylor series approximation for atan over the range // [0.0, 0.66]. @@ -436,31 +436,31 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, // Apply the polynomial approximation for the numerator: Value n = p0; - n = builder.create(xx, n, p1); - n = builder.create(xx, n, p2); - n = builder.create(xx, n, p3); - n = builder.create(xx, n, p4); - n = builder.create(n, xx); + n = math::FmaOp::create(builder, xx, n, p1); + n = math::FmaOp::create(builder, xx, n, p2); + n = math::FmaOp::create(builder, xx, n, p3); + n = math::FmaOp::create(builder, xx, n, p4); + n = arith::MulFOp::create(builder, n, xx); // Apply the polynomial approximation for the denominator: Value d = q0; - d = builder.create(xx, d, q1); - d = builder.create(xx, d, q2); - d = builder.create(xx, d, q3); - d = builder.create(xx, d, q4); + d = math::FmaOp::create(builder, xx, d, q1); + d = math::FmaOp::create(builder, xx, d, q2); + d = math::FmaOp::create(builder, xx, d, q3); + d = math::FmaOp::create(builder, xx, d, q4); // Compute approximation of theta: - Value ans0 = builder.create(n, d); - ans0 = builder.create(ans0, x, x); + Value ans0 = arith::DivFOp::create(builder, n, d); + ans0 = math::FmaOp::create(builder, ans0, x, x); // Correct for the input mapping's angles: Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4)); - Value ans2 = builder.create(mpi4, ans0); - Value ans = builder.create(cmp2, ans2, ans0); + Value ans2 = arith::AddFOp::create(builder, mpi4, ans0); + Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0); Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2)); - Value ans1 = builder.create(mpi2, ans0); - ans = builder.create(cmp1, ans1, ans); + Value ans1 = arith::SubFOp::create(builder, mpi2, ans0); + ans = arith::SelectOp::create(builder, cmp1, ans1, ans); // Correct for signing of the input. 
rewriter.replaceOpWithNewOp(op, ans, operand); @@ -493,44 +493,44 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op, std::optional shape = vectorShape(op.getResult()); // Compute atan in the valid range. - auto div = builder.create(y, x); - auto atan = builder.create(div); + auto div = arith::DivFOp::create(builder, y, x); + auto atan = math::AtanOp::create(builder, div); // Determine what the atan would be for a 180 degree rotation. auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape); auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape); - auto addPi = builder.create(atan, pi); - auto subPi = builder.create(atan, pi); + auto addPi = arith::AddFOp::create(builder, atan, pi); + auto subPi = arith::SubFOp::create(builder, atan, pi); auto atanGt = - builder.create(arith::CmpFPredicate::OGT, atan, zero); - auto flippedAtan = builder.create(atanGt, subPi, addPi); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero); + auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi); // Determine whether to directly use atan or use the 180 degree flip - auto xGt = builder.create(arith::CmpFPredicate::OGT, x, zero); - Value result = builder.create(xGt, atan, flippedAtan); + auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero); + Value result = arith::SelectOp::create(builder, xGt, atan, flippedAtan); // Handle x = 0, y > 0 Value xZero = - builder.create(arith::CmpFPredicate::OEQ, x, zero); - Value yGt = builder.create(arith::CmpFPredicate::OGT, y, zero); - Value isHalfPi = builder.create(xZero, yGt); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero); + Value yGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero); + Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt); auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); - result = builder.create(isHalfPi, halfPi, result); + result = arith::SelectOp::create(builder, isHalfPi, 
halfPi, result); // Handle x = 0, y < 0 - Value yLt = builder.create(arith::CmpFPredicate::OLT, y, zero); - Value isNegativeHalfPiPi = builder.create(xZero, yLt); + Value yLt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero); + Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt); auto negativeHalfPiPi = broadcast(builder, f32Cst(builder, -1.57079632679f), shape); - result = builder.create(isNegativeHalfPiPi, negativeHalfPiPi, + result = arith::SelectOp::create(builder, isNegativeHalfPiPi, negativeHalfPiPi, result); // Handle x = 0, y = 0; Value yZero = - builder.create(arith::CmpFPredicate::OEQ, y, zero); - Value isNan = builder.create(xZero, yZero); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero); + Value isNan = arith::AndIOp::create(builder, xZero, yZero); Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape); - result = builder.create(isNan, cstNan, result); + result = arith::SelectOp::create(builder, isNan, cstNan, result); rewriter.replaceOp(op, result); return success(); @@ -570,8 +570,8 @@ TanhApproximation::matchAndRewrite(math::TanhOp op, // Mask for tiny values that are approximated with `operand`. Value tiny = bcast(f32Cst(builder, 0.0004f)); - Value tinyMask = builder.create( - arith::CmpFPredicate::OLT, builder.create(op.getOperand()), + Value tinyMask = arith::CmpFOp::create(builder, + arith::CmpFPredicate::OLT, math::AbsFOp::create(builder, op.getOperand()), tiny); // The monomial coefficients of the numerator polynomial (odd). @@ -590,25 +590,25 @@ TanhApproximation::matchAndRewrite(math::TanhOp op, Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); // Since the polynomials are odd/even, we need x^2. - Value x2 = builder.create(x, x); + Value x2 = arith::MulFOp::create(builder, x, x); // Evaluate the numerator polynomial p. 
- Value p = builder.create(x2, alpha13, alpha11); - p = builder.create(x2, p, alpha9); - p = builder.create(x2, p, alpha7); - p = builder.create(x2, p, alpha5); - p = builder.create(x2, p, alpha3); - p = builder.create(x2, p, alpha1); - p = builder.create(x, p); + Value p = math::FmaOp::create(builder, x2, alpha13, alpha11); + p = math::FmaOp::create(builder, x2, p, alpha9); + p = math::FmaOp::create(builder, x2, p, alpha7); + p = math::FmaOp::create(builder, x2, p, alpha5); + p = math::FmaOp::create(builder, x2, p, alpha3); + p = math::FmaOp::create(builder, x2, p, alpha1); + p = arith::MulFOp::create(builder, x, p); // Evaluate the denominator polynomial q. - Value q = builder.create(x2, beta6, beta4); - q = builder.create(x2, q, beta2); - q = builder.create(x2, q, beta0); + Value q = math::FmaOp::create(builder, x2, beta6, beta4); + q = math::FmaOp::create(builder, x2, q, beta2); + q = math::FmaOp::create(builder, x2, q, beta0); // Divide the numerator by the denominator. - Value res = builder.create( - tinyMask, x, builder.create(p, q)); + Value res = arith::SelectOp::create(builder, + tinyMask, x, arith::DivFOp::create(builder, p, q)); rewriter.replaceOp(op, res); @@ -691,57 +691,57 @@ LogApproximationBase::logMatchAndRewrite(Op op, PatternRewriter &rewriter, // e -= 1; // x = x + x - 1.0; // } else { x = x - 1.0; } - Value mask = builder.create(arith::CmpFPredicate::OLT, x, + Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, cstCephesSQRTHF); - Value tmp = builder.create(mask, x, cstZero); + Value tmp = arith::SelectOp::create(builder, mask, x, cstZero); - x = builder.create(x, cstOne); - e = builder.create( - e, builder.create(mask, cstOne, cstZero)); - x = builder.create(x, tmp); + x = arith::SubFOp::create(builder, x, cstOne); + e = arith::SubFOp::create(builder, + e, arith::SelectOp::create(builder, mask, cstOne, cstZero)); + x = arith::AddFOp::create(builder, x, tmp); - Value x2 = builder.create(x, x); - Value x3 = 
builder.create(x2, x); + Value x2 = arith::MulFOp::create(builder, x, x); + Value x3 = arith::MulFOp::create(builder, x2, x); // Evaluate the polynomial approximant of degree 8 in three parts. Value y0, y1, y2; - y0 = builder.create(cstCephesLogP0, x, cstCephesLogP1); - y1 = builder.create(cstCephesLogP3, x, cstCephesLogP4); - y2 = builder.create(cstCephesLogP6, x, cstCephesLogP7); - y0 = builder.create(y0, x, cstCephesLogP2); - y1 = builder.create(y1, x, cstCephesLogP5); - y2 = builder.create(y2, x, cstCephesLogP8); - y0 = builder.create(y0, x3, y1); - y0 = builder.create(y0, x3, y2); - y0 = builder.create(y0, x3); - - y0 = builder.create(cstNegHalf, x2, y0); - x = builder.create(x, y0); + y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1); + y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4); + y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7); + y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2); + y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5); + y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8); + y0 = math::FmaOp::create(builder, y0, x3, y1); + y0 = math::FmaOp::create(builder, y0, x3, y2); + y0 = arith::MulFOp::create(builder, y0, x3); + + y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0); + x = arith::AddFOp::create(builder, x, y0); if (base2) { Value cstLog2e = bcast(f32Cst(builder, static_cast(LOG2E_VALUE))); - x = builder.create(x, cstLog2e, e); + x = math::FmaOp::create(builder, x, cstLog2e, e); } else { Value cstLn2 = bcast(f32Cst(builder, static_cast(LN2_VALUE))); - x = builder.create(e, cstLn2, x); + x = math::FmaOp::create(builder, e, cstLn2, x); } - Value invalidMask = builder.create(arith::CmpFPredicate::ULT, + Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, op.getOperand(), cstZero); - Value zeroMask = builder.create(arith::CmpFPredicate::OEQ, + Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, op.getOperand(), 
cstZero); - Value posInfMask = builder.create(arith::CmpFPredicate::OEQ, + Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, op.getOperand(), cstPosInf); // Filter out invalid values: // • x == 0 -> -INF // • x < 0 -> NAN // • x == +INF -> +INF - Value aproximation = builder.create( + Value aproximation = arith::SelectOp::create(builder, zeroMask, cstMinusInf, - builder.create( + arith::SelectOp::create(builder, invalidMask, cstNan, - builder.create(posInfMask, cstPosInf, x))); + arith::SelectOp::create(builder, posInfMask, cstPosInf, x))); rewriter.replaceOp(op, aproximation); @@ -806,17 +806,17 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op, // "logLarge" below. Value cstOne = bcast(f32Cst(builder, 1.0f)); Value x = op.getOperand(); - Value u = builder.create(x, cstOne); + Value u = arith::AddFOp::create(builder, x, cstOne); Value uSmall = - builder.create(arith::CmpFPredicate::OEQ, u, cstOne); - Value logU = builder.create(u); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne); + Value logU = math::LogOp::create(builder, u); Value uInf = - builder.create(arith::CmpFPredicate::OEQ, u, logU); - Value logLarge = builder.create( - x, builder.create( - logU, builder.create(u, cstOne))); - Value approximation = builder.create( - builder.create(uSmall, uInf), x, logLarge); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU); + Value logLarge = arith::MulFOp::create(builder, + x, arith::DivFOp::create(builder, + logU, arith::SubFOp::create(builder, u, cstOne))); + Value approximation = arith::SelectOp::create(builder, + arith::OrIOp::create(builder, uSmall, uInf), x, logLarge); rewriter.replaceOp(op, approximation); return success(); } @@ -854,27 +854,27 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op, }; auto fma = [&](Value a, Value b, Value c) -> Value { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto mul = [&](Value a, Value b) -> Value 
{ - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; auto sub = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::SubFOp::create(builder, a, b); }; - auto abs = [&](Value a) -> Value { return builder.create(a); }; + auto abs = [&](Value a) -> Value { return math::AbsFOp::create(builder, a); }; - auto sqrt = [&](Value a) -> Value { return builder.create(a); }; + auto sqrt = [&](Value a) -> Value { return math::SqrtOp::create(builder, a); }; auto scopy = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return math::CopySignOp::create(builder, a, b); }; auto sel = [&](Value a, Value b, Value c) -> Value { - return builder.create(a, b, c); + return arith::SelectOp::create(builder, a, b, c); }; Value abso = abs(operand); @@ -882,7 +882,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op, Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa)); Value gt = - builder.create(arith::CmpFPredicate::OGT, aa, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa, bcast(floatCst(builder, 0.5, elementType))); Value x = sel(gt, opp, abso); @@ -949,51 +949,51 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op, }; auto fma = [&](Value a, Value b, Value c) -> Value { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; - Value negOperand = builder.create(operand); + Value negOperand = arith::NegFOp::create(builder, operand); Value zero = bcast(floatCst(builder, 0.0, elementType)); Value half = bcast(floatCst(builder, 0.5, elementType)); Value negOne = bcast(floatCst(builder, -1.0, elementType)); Value selR = - builder.create(arith::CmpFPredicate::OGT, operand, zero); - Value r = builder.create(selR, negOperand, operand); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero); + Value r 
= arith::SelectOp::create(builder, selR, negOperand, operand); Value chkConst = bcast(floatCst(builder, -0.5625, elementType)); Value firstPred = - builder.create(arith::CmpFPredicate::OGT, r, chkConst); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst); Value trueVal = fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)), bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), - builder.create(r)); + math::AsinOp::create(builder, r)); - Value falseVal = builder.create(fma(half, r, half)); - falseVal = builder.create(falseVal); + Value falseVal = math::SqrtOp::create(builder, fma(half, r, half)); + falseVal = math::AsinOp::create(builder, falseVal); falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal); - r = builder.create(firstPred, trueVal, falseVal); + r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal); // Check whether the operand lies in between [-1.0, 0.0). Value greaterThanNegOne = - builder.create(arith::CmpFPredicate::OGE, operand, negOne); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGE, operand, negOne); Value lessThanZero = - builder.create(arith::CmpFPredicate::OLT, operand, zero); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero); Value betweenNegOneZero = - builder.create(greaterThanNegOne, lessThanZero); + arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero); trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)), bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), - builder.create(r)); + arith::NegFOp::create(builder, r)); Value finalVal = - builder.create(betweenNegOneZero, trueVal, r); + arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r); rewriter.replaceOp(op, finalVal); return success(); @@ -1076,9 +1076,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, bounds[2] = bcast(floatCst(builder, 3.75f, elementType)); Value isNegativeArg = - 
builder.create(arith::CmpFPredicate::OLT, operand, zero); - Value negArg = builder.create(operand); - Value x = builder.create(isNegativeArg, negArg, operand); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero); + Value negArg = arith::NegFOp::create(builder, operand); + Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand); Value offset = offsets[0]; Value p[polyDegree + 1]; @@ -1092,30 +1092,30 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, Value isLessThanBound[intervalsCount]; for (int j = 0; j < intervalsCount - 1; ++j) { isLessThanBound[j] = - builder.create(arith::CmpFPredicate::OLT, x, bounds[j]); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[j]); for (int i = 0; i <= polyDegree; ++i) { - p[i] = builder.create(isLessThanBound[j], p[i], + p[i] = arith::SelectOp::create(builder, isLessThanBound[j], p[i], pp[j + 1][i]); - q[i] = builder.create(isLessThanBound[j], q[i], + q[i] = arith::SelectOp::create(builder, isLessThanBound[j], q[i], qq[j + 1][i]); } - offset = builder.create(isLessThanBound[j], offset, + offset = arith::SelectOp::create(builder, isLessThanBound[j], offset, offsets[j + 1]); } - isLessThanBound[intervalsCount - 1] = builder.create( + isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); Value pPoly = makePolynomialCalculation(builder, p, x); Value qPoly = makePolynomialCalculation(builder, q, x); - Value rationalPoly = builder.create(pPoly, qPoly); - Value formula = builder.create(offset, rationalPoly); - formula = builder.create(isLessThanBound[intervalsCount - 1], + Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly); + Value formula = arith::AddFOp::create(builder, offset, rationalPoly); + formula = arith::SelectOp::create(builder, isLessThanBound[intervalsCount - 1], formula, one); // erf is odd function: erf(x) = -erf(-x). 
- Value negFormula = builder.create(formula); + Value negFormula = arith::NegFOp::create(builder, formula); Value res = - builder.create(isNegativeArg, negFormula, formula); + arith::SelectOp::create(builder, isNegativeArg, negFormula, formula); rewriter.replaceOp(op, res); @@ -1156,65 +1156,65 @@ ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op, Value posInf = bcast(floatCst(builder, INFINITY, et)); Value clampVal = bcast(floatCst(builder, 10.0546875f, et)); - Value a = builder.create(x); - Value p = builder.create(a, pos2); - Value r = builder.create(one, p); - Value q = builder.create(neg4, r, one); - Value t = builder.create(builder.create(q, one), + Value a = math::AbsFOp::create(builder, x); + Value p = arith::AddFOp::create(builder, a, pos2); + Value r = arith::DivFOp::create(builder, one, p); + Value q = math::FmaOp::create(builder, neg4, r, one); + Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one), neg2, a); - Value e = builder.create(builder.create(a), q, t); - q = builder.create(r, e, q); + Value e = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t); + q = math::FmaOp::create(builder, r, e, q); p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4 Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3 - p = builder.create(p, q, c1); + p = math::FmaOp::create(builder, p, q, c1); Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3 - p = builder.create(p, q, c2); + p = math::FmaOp::create(builder, p, q, c2); Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3 - p = builder.create(p, q, c3); + p = math::FmaOp::create(builder, p, q, c3); Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3 - p = builder.create(p, q, c4); + p = math::FmaOp::create(builder, p, q, c4); Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2 - p = builder.create(p, q, c5); + p = 
math::FmaOp::create(builder, p, q, c5); Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1 - p = builder.create(p, q, c6); + p = math::FmaOp::create(builder, p, q, c6); Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1 - p = builder.create(p, q, c7); + p = math::FmaOp::create(builder, p, q, c7); Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2 - p = builder.create(p, q, c8); + p = math::FmaOp::create(builder, p, q, c8); Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1 - p = builder.create(p, q, c9); - - Value d = builder.create(pos2, a, one); - r = builder.create(one, d); - q = builder.create(p, r, r); - Value negfa = builder.create(a); - Value fmaqah = builder.create(q, negfa, onehalf); - Value psubq = builder.create(p, q); - e = builder.create(fmaqah, pos2, psubq); - r = builder.create(e, r, q); - - Value s = builder.create(a, a); - e = builder.create(builder.create(s)); - - t = builder.create(builder.create(a), a, s); - r = builder.create( + p = math::FmaOp::create(builder, p, q, c9); + + Value d = math::FmaOp::create(builder, pos2, a, one); + r = arith::DivFOp::create(builder, one, d); + q = math::FmaOp::create(builder, p, r, r); + Value negfa = arith::NegFOp::create(builder, a); + Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf); + Value psubq = arith::SubFOp::create(builder, p, q); + e = math::FmaOp::create(builder, fmaqah, pos2, psubq); + r = math::FmaOp::create(builder, e, r, q); + + Value s = arith::MulFOp::create(builder, a, a); + e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s)); + + t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s); + r = math::FmaOp::create(builder, r, e, - builder.create(builder.create(r, e), t)); + arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t)); - Value isNotLessThanInf = builder.create( - builder.create(arith::CmpFPredicate::OLT, a, posInf), + Value 
isNotLessThanInf = arith::XOrIOp::create(builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf), trueValue); - r = builder.create(isNotLessThanInf, - builder.create(x, x), r); + r = arith::SelectOp::create(builder, isNotLessThanInf, + arith::AddFOp::create(builder, x, x), r); Value isGreaterThanClamp = - builder.create(arith::CmpFPredicate::OGT, a, clampVal); - r = builder.create(isGreaterThanClamp, zero, r); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal); + r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r); Value isNegative = - builder.create(arith::CmpFPredicate::OLT, x, zero); - r = builder.create( - isNegative, builder.create(pos2, r), r); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero); + r = arith::SelectOp::create(builder, + isNegative, arith::SubFOp::create(builder, pos2, r), r); rewriter.replaceOp(op, r); return success(); @@ -1236,8 +1236,8 @@ Value clampWithNormals(ImplicitLocOpBuilder &builder, }; auto selectCmp = [&builder](auto pred, Value value, Value bound) { - return builder.create( - builder.create(pred, value, bound), value, bound); + return arith::SelectOp::create(builder, + arith::CmpFOp::create(builder, pred, value, bound), value, bound); }; // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs. 
@@ -1269,17 +1269,17 @@ ExpApproximation::matchAndRewrite(math::ExpOp op, ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto add = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::AddFOp::create(builder, a, b); }; auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); }; - auto floor = [&](Value a) { return builder.create(a); }; + auto floor = [&](Value a) { return math::FloorOp::create(builder, a); }; auto fmla = [&](Value a, Value b, Value c) { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; // Polynomial approximation from Cephes. @@ -1383,7 +1383,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op, // Convert n' to an i32. This is safe because we clamped it above. auto i32Vec = broadcast(builder.getI32Type(), shape); - Value nI32 = builder.create(i32Vec, n); + Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n); // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127. 
Value pow2 = exp2I32(builder, nI32); @@ -1431,26 +1431,26 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, Value cstOne = bcast(f32Cst(builder, 1.0f)); Value cstNegOne = bcast(f32Cst(builder, -1.0f)); Value x = op.getOperand(); - Value u = builder.create(x); + Value u = math::ExpOp::create(builder, x); Value uEqOneOrNaN = - builder.create(arith::CmpFPredicate::UEQ, u, cstOne); - Value uMinusOne = builder.create(u, cstOne); - Value uMinusOneEqNegOne = builder.create( + arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne); + Value uMinusOne = arith::SubFOp::create(builder, u, cstOne); + Value uMinusOneEqNegOne = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); // logU = log(u) ~= x - Value logU = builder.create(u); + Value logU = math::LogOp::create(builder, u); // Detect exp(x) = +inf; written this way to avoid having to form +inf. Value isInf = - builder.create(arith::CmpFPredicate::OEQ, logU, u); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u); // (u - 1) * (x / ~x) - Value expm1 = builder.create( - uMinusOne, builder.create(x, logU)); - expm1 = builder.create(isInf, u, expm1); - Value approximation = builder.create( + Value expm1 = arith::MulFOp::create(builder, + uMinusOne, arith::DivFOp::create(builder, x, logU)); + expm1 = arith::SelectOp::create(builder, isInf, u, expm1); + Value approximation = arith::SelectOp::create(builder, uEqOneOrNaN, x, - builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); + arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1)); rewriter.replaceOp(op, approximation); return success(); } @@ -1495,40 +1495,40 @@ LogicalResult SinAndCosApproximation::matchAndRewrite( return broadcast(builder, value, shape); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; auto sub = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return 
arith::SubFOp::create(builder, a, b); }; - auto floor = [&](Value a) { return builder.create(a); }; + auto floor = [&](Value a) { return math::FloorOp::create(builder, a); }; auto i32Vec = broadcast(builder.getI32Type(), shape); auto fPToSingedInteger = [&](Value a) -> Value { - return builder.create(i32Vec, a); + return arith::FPToSIOp::create(builder, i32Vec, a); }; auto modulo4 = [&](Value a) -> Value { - return builder.create(a, bcast(i32Cst(builder, 3))); + return arith::AndIOp::create(builder, a, bcast(i32Cst(builder, 3))); }; auto isEqualTo = [&](Value a, Value b) -> Value { - return builder.create(arith::CmpIPredicate::eq, a, b); + return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a, b); }; auto isGreaterThan = [&](Value a, Value b) -> Value { - return builder.create(arith::CmpIPredicate::sgt, a, b); + return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a, b); }; auto select = [&](Value cond, Value t, Value f) -> Value { - return builder.create(cond, t, f); + return arith::SelectOp::create(builder, cond, t, f); }; auto fmla = [&](Value a, Value b, Value c) { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto bitwiseOr = [&](Value a, Value b) { - return builder.create(a, b); + return arith::OrIOp::create(builder, a, b); }; Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI)); @@ -1625,7 +1625,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op, intTy = broadcast(intTy, shape); auto bconst = [&](TypedAttr attr) -> Value { - Value value = b.create(attr); + Value value = arith::ConstantOp::create(b, attr); return broadcast(b, value, shape); }; @@ -1642,44 +1642,44 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op, // union {int ix; float x;}; // x = x0; // ix = ix/4 + ix/16; - Value absValue = b.create(operand); - Value intValue = b.create(intTy, absValue); - Value divideBy4 = b.create(intValue, intTwo); - Value divideBy16 = b.create(intValue, intFour); - intValue = 
b.create(divideBy4, divideBy16); + Value absValue = math::AbsFOp::create(b, operand); + Value intValue = arith::BitcastOp::create(b, intTy, absValue); + Value divideBy4 = arith::ShRSIOp::create(b, intValue, intTwo); + Value divideBy16 = arith::ShRSIOp::create(b, intValue, intFour); + intValue = arith::AddIOp::create(b, divideBy4, divideBy16); // ix = ix + ix/16; - divideBy16 = b.create(intValue, intFour); - intValue = b.create(intValue, divideBy16); + divideBy16 = arith::ShRSIOp::create(b, intValue, intFour); + intValue = arith::AddIOp::create(b, intValue, divideBy16); // ix = ix + ix/256; - Value divideBy256 = b.create(intValue, intEight); - intValue = b.create(intValue, divideBy256); + Value divideBy256 = arith::ShRSIOp::create(b, intValue, intEight); + intValue = arith::AddIOp::create(b, intValue, divideBy256); // ix = 0x2a5137a0 + ix; - intValue = b.create(intValue, intMagic); + intValue = arith::AddIOp::create(b, intValue, intMagic); // Perform one newtons step: // x = 0.33333333f*(2.0f*x + x0/(x*x)); - Value floatValue = b.create(floatTy, intValue); - Value squared = b.create(floatValue, floatValue); - Value mulTwo = b.create(floatValue, fpTwo); - Value divSquared = b.create(absValue, squared); - floatValue = b.create(mulTwo, divSquared); - floatValue = b.create(floatValue, fpThird); + Value floatValue = arith::BitcastOp::create(b, floatTy, intValue); + Value squared = arith::MulFOp::create(b, floatValue, floatValue); + Value mulTwo = arith::MulFOp::create(b, floatValue, fpTwo); + Value divSquared = arith::DivFOp::create(b, absValue, squared); + floatValue = arith::AddFOp::create(b, mulTwo, divSquared); + floatValue = arith::MulFOp::create(b, floatValue, fpThird); // x = 0.33333333f*(2.0f*x + x0/(x*x)); - squared = b.create(floatValue, floatValue); - mulTwo = b.create(floatValue, fpTwo); - divSquared = b.create(absValue, squared); - floatValue = b.create(mulTwo, divSquared); - floatValue = b.create(floatValue, fpThird); + squared = arith::MulFOp::create(b, 
floatValue, floatValue); + mulTwo = arith::MulFOp::create(b, floatValue, fpTwo); + divSquared = arith::DivFOp::create(b, absValue, squared); + floatValue = arith::AddFOp::create(b, mulTwo, divSquared); + floatValue = arith::MulFOp::create(b, floatValue, fpThird); // Check for zero and restore sign. Value isZero = - b.create(arith::CmpFPredicate::OEQ, absValue, fpZero); - floatValue = b.create(isZero, fpZero, floatValue); - floatValue = b.create(floatValue, operand); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absValue, fpZero); + floatValue = arith::SelectOp::create(b, isZero, fpZero, floatValue); + floatValue = math::CopySignOp::create(b, floatValue, operand); rewriter.replaceOp(op, floatValue); return success(); @@ -1720,29 +1720,29 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); - Value negHalf = builder.create(op.getOperand(), cstNegHalf); + Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf); // Select only the inverse sqrt of positive normals (denormals are // flushed to zero). - Value ltMinMask = builder.create( + Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); - Value infMask = builder.create(arith::CmpFPredicate::OEQ, + Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, op.getOperand(), cstPosInf); - Value notNormalFiniteMask = builder.create(ltMinMask, infMask); + Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask); // Compute an approximate result. Value yApprox = handleMultidimensionalVectors( builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { - return builder.create(operands); + return x86vector::RsqrtOp::create(builder, operands); }); // Do a single step of Newton-Raphson iteration to improve the approximation. 
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). // It is essential to evaluate the inner term like this because forming // y_n^2 may over- or underflow. - Value inner = builder.create(negHalf, yApprox); - Value fma = builder.create(yApprox, inner, cstOnePointFive); - Value yNewton = builder.create(yApprox, fma); + Value inner = arith::MulFOp::create(builder, negHalf, yApprox); + Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive); + Value yNewton = arith::MulFOp::create(builder, yApprox, fma); // Select the result of the Newton-Raphson step for positive normal arguments. // For other arguments, choose the output of the intrinsic. This will @@ -1750,7 +1750,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, // x is zero or a positive denormalized float (equivalent to flushing positive // denormalized inputs to zero). Value res = - builder.create(notNormalFiniteMask, yApprox, yNewton); + arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton); rewriter.replaceOp(op, res); return success(); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index f630c48cdcaa1..ddb698e1bf501 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -87,10 +87,10 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, // TODO: support more types. 
return TypeSwitch(slot.elemType) .Case([&](MemRefType t) { - return builder.create(getLoc(), t); + return memref::AllocaOp::create(builder, getLoc(), t); }) .Default([&](Type t) { - return builder.create(getLoc(), t, + return arith::ConstantOp::create(builder, getLoc(), t, builder.getZeroAttr(t)); }); } @@ -137,7 +137,7 @@ DenseMap memref::AllocaOp::destructure( for (Attribute usedIndex : usedIndices) { Type elemType = memrefType.getTypeAtIndex(usedIndex); MemRefType elemPtr = MemRefType::get({}, elemType); - auto subAlloca = builder.create(getLoc(), elemPtr); + auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr); newAllocators.push_back(subAlloca); slotMap.try_emplace(usedIndex, {subAlloca.getResult(), elemType}); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index d1a9920aa66c5..cb5fe132494b3 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -213,9 +214,9 @@ struct SimplifyAllocConst : public OpRewritePattern { assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. - auto newAlloc = rewriter.create( - alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(), - alloc.getAlignmentAttr()); + auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType, + dynamicSizes, alloc.getSymbolOperands(), + alloc.getAlignmentAttr()); // Insert a cast so we have the same type as the old alloc. 
rewriter.replaceOpWithNewOp(alloc, alloc.getType(), newAlloc); return success(); @@ -836,7 +837,7 @@ void DimOp::getAsmResultNames(function_ref setNameFn) { void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { auto loc = result.location; - Value indexValue = builder.create(loc, index); + Value indexValue = arith::ConstantIndexOp::create(builder, loc, index); build(builder, result, source, indexValue); } @@ -1083,9 +1084,9 @@ struct DimOfMemRefReshape : public OpRewritePattern { rewriter.setInsertionPointAfter(reshape); Location loc = dim.getLoc(); Value load = - rewriter.create(loc, reshape.getShape(), dim.getIndex()); + LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex()); if (load.getType() != dim.getType()) - load = rewriter.create(loc, dim.getType(), load); + load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load); rewriter.replaceOp(dim, load); return success(); } @@ -1358,8 +1359,9 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, assert(isa(maybeConstant) && "The constified value should be either unchanged (i.e., == result) " "or a constant"); - Value constantVal = rewriter.create( - loc, llvm::cast(cast(maybeConstant)).getInt()); + Value constantVal = arith::ConstantIndexOp::create( + rewriter, loc, + llvm::cast(cast(maybeConstant)).getInt()); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { // modifyOpInPlace: lambda cannot capture structured bindings in C++17 // yet. 
@@ -2587,8 +2589,9 @@ struct CollapseShapeOpMemRefCastFolder rewriter.modifyOpInPlace( op, [&]() { op.getSrcMutable().assign(cast.getSource()); }); } else { - Value newOp = rewriter.create( - op->getLoc(), cast.getSource(), op.getReassociationIndices()); + Value newOp = + CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(), + op.getReassociationIndices()); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); } return success(); @@ -3045,15 +3048,15 @@ SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, Value offset = op.isDynamicOffset(idx) ? op.getDynamicOffset(idx) - : b.create(loc, op.getStaticOffset(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx)); Value size = op.isDynamicSize(idx) ? op.getDynamicSize(idx) - : b.create(loc, op.getStaticSize(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx)); Value stride = op.isDynamicStride(idx) ? op.getDynamicStride(idx) - : b.create(loc, op.getStaticStride(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx)); res.emplace_back(Range{offset, size, stride}); } return res; @@ -3212,8 +3215,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern { if (!resultType) return failure(); - Value newSubView = rewriter.create( - subViewOp.getLoc(), resultType, castOp.getSource(), + Value newSubView = SubViewOp::create( + rewriter, subViewOp.getLoc(), resultType, castOp.getSource(), subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(), subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(), subViewOp.getStaticStrides()); @@ -3534,9 +3537,9 @@ struct ViewOpShapeFolder : public OpRewritePattern { return failure(); // Create new ViewOp. 
- auto newViewOp = rewriter.create( - viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), - viewOp.getByteShift(), newOperands); + auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType, + viewOp.getOperand(0), viewOp.getByteShift(), + newOperands); // Insert a cast so we have the same type as the old memref type. rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), newViewOp); return success(); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 89640ac323b68..525ae1e15bca6 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" @@ -156,9 +157,10 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, Type resultType = alloca.getResult().getType(); OpBuilder builder(rewriter.getContext()); // TODO: Add a better builder for this. 
- globalOp = builder.create( - loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), - TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + globalOp = memref::GlobalOp::create( + builder, loc, StringAttr::get(ctx, "alloca"), + StringAttr::get(ctx, "private"), TypeAttr::get(resultType), + Attribute{}, UnitAttr{}, IntegerAttr{}); symbolTable.insert(globalOp); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp index c433415944323..a2b689ed13cbf 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp @@ -22,11 +22,11 @@ struct DefaultAllocationInterface DefaultAllocationInterface, memref::AllocOp> { static std::optional buildDealloc(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc) + return memref::DeallocOp::create(builder, alloc.getLoc(), alloc) .getOperation(); } static std::optional buildClone(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc) + return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc) .getResult(); } static ::mlir::HoistingKind getHoistingKind() { @@ -35,7 +35,7 @@ struct DefaultAllocationInterface static ::std::optional<::mlir::Operation *> buildPromotedAlloc(OpBuilder &builder, Value alloc) { Operation *definingOp = alloc.getDefiningOp(); - return builder.create( + return memref::AllocaOp::create(builder, definingOp->getLoc(), cast(definingOp->getResultTypes()[0]), definingOp->getOperands(), definingOp->getAttrs()); } @@ -52,7 +52,7 @@ struct DefaultReallocationInterface DefaultAllocationInterface, memref::ReallocOp> { static std::optional buildDealloc(OpBuilder &builder, Value realloc) { - return builder.create(realloc.getLoc(), realloc) + return memref::DeallocOp::create(builder, realloc.getLoc(), realloc) .getOperation(); } }; diff --git 
a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index d25ddb41aa4eb..4fcceeaa33deb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -125,7 +125,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { } AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr); - Value result = rewriter.create( + Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map, affineApplyOperands); offsets.push_back(result); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index d2a032688fb6d..51970f019d27f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -99,7 +99,7 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx}); Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal); IntegerType dstType = builder.getIntegerType(targetBits); - return builder.create(loc, dstType, bitOffset); + return arith::IndexCastOp::create(builder, loc, dstType, bitOffset); } /// When writing a subbyte size, masked bitwise operations are used to only @@ -112,14 +112,14 @@ static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices, auto dstIntegerType = builder.getIntegerType(dstBits); auto maskRightAlignedAttr = builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1); - Value maskRightAligned = builder.create( + Value maskRightAligned = arith::ConstantOp::create(builder, loc, dstIntegerType, maskRightAlignedAttr); Value writeMaskInverse = - builder.create(loc, maskRightAligned, bitwidthOffset); + arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset); auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1); Value 
flipVal = - builder.create(loc, dstIntegerType, flipValAttr); - return builder.create(loc, writeMaskInverse, flipVal); + arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr); + return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal); } /// Returns the scaled linearized index based on the `srcBits` and `dstBits` @@ -141,7 +141,7 @@ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, const SmallVector &indices, Value memref) { auto stridedMetadata = - builder.create(loc, memref); + memref::ExtractStridedMetadataOp::create(builder, loc, memref); OpFoldResult linearizedIndices; std::tie(std::ignore, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( @@ -298,7 +298,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { // Special case 0-rank memref loads. Value bitsLoad; if (convertedType.getRank() == 0) { - bitsLoad = rewriter.create(loc, adaptor.getMemref(), + bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(), ValueRange{}); } else { // Linearize the indices of the original load instruction. Do not account @@ -306,7 +306,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { OpFoldResult linearizedIndices = getLinearizedSrcIndices( rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef()); - Value newLoad = rewriter.create( + Value newLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(), getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits, dstBits)); @@ -315,7 +315,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { // Note, currently only the big-endian is supported. Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits, dstBits, rewriter); - bitsLoad = rewriter.create(loc, newLoad, bitwidthOffset); + bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset); } // Get the corresponding bits. 
If the arith computation bitwidth equals @@ -326,13 +326,13 @@ struct ConvertMemRefLoad final : OpConversionPattern { Operation *result; auto resultTy = getTypeConverter()->convertType(oldElementType); if (resultTy == convertedElementType) { - auto mask = rewriter.create( + auto mask = arith::ConstantOp::create(rewriter, loc, convertedElementType, rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1)); - result = rewriter.create(loc, bitsLoad, mask); + result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask); } else { - result = rewriter.create(loc, resultTy, bitsLoad); + result = arith::TruncIOp::create(rewriter, loc, resultTy, bitsLoad); } rewriter.replaceOp(op, result->getResult(0)); @@ -415,12 +415,12 @@ struct ConvertMemrefStore final : OpConversionPattern { } Location loc = op.getLoc(); - Value extendedInput = rewriter.create(loc, dstIntegerType, + Value extendedInput = arith::ExtUIOp::create(rewriter, loc, dstIntegerType, adaptor.getValue()); // Special case 0-rank memref stores. No need for masking. 
if (convertedType.getRank() == 0) { - rewriter.create(loc, arith::AtomicRMWKind::assign, + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign, extendedInput, adaptor.getMemref(), ValueRange{}); rewriter.eraseOp(op); @@ -437,14 +437,14 @@ struct ConvertMemrefStore final : OpConversionPattern { dstBits, bitwidthOffset, rewriter); // Align the value to write with the destination bits Value alignedVal = - rewriter.create(loc, extendedInput, bitwidthOffset); + arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset); // Clear destination bits - rewriter.create(loc, arith::AtomicRMWKind::andi, + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), storeIndices); // Write srcs bits to destination - rewriter.create(loc, arith::AtomicRMWKind::ori, + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), storeIndices); rewriter.eraseOp(op); @@ -506,7 +506,7 @@ struct ConvertMemRefSubview final : OpConversionPattern { } // Transform the offsets, sizes and strides according to the emulation. - auto stridedMetadata = rewriter.create( + auto stridedMetadata = memref::ExtractStridedMetadataOp::create(rewriter, loc, subViewOp.getViewSource()); OpFoldResult linearizedIndices; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index a617029ce470f..64e117fb77bcf 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -51,15 +51,15 @@ struct MemRefReshapeOpConverter : public OpRewritePattern { Value size; // Load dynamic sizes from the shape input, use constants for static dims. 
if (op.getType().isDynamicDim(i)) { - Value index = rewriter.create(loc, i); - size = rewriter.create(loc, op.getShape(), index); + Value index = arith::ConstantIndexOp::create(rewriter, loc, i); + size = memref::LoadOp::create(rewriter, loc, op.getShape(), index); if (!isa(size.getType())) - size = rewriter.create( + size = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), size); sizes[i] = size; } else { auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i)); - size = rewriter.create(loc, sizeAttr); + size = arith::ConstantOp::create(rewriter, loc, sizeAttr); sizes[i] = sizeAttr; } if (stride) @@ -69,10 +69,10 @@ struct MemRefReshapeOpConverter : public OpRewritePattern { if (i > 0) { if (stride) { - stride = rewriter.create(loc, stride, size); + stride = arith::MulIOp::create(rewriter, loc, stride, size); } else if (op.getType().isDynamicDim(i)) { - stride = rewriter.create( - loc, rewriter.create(loc, staticStride), + stride = arith::MulIOp::create(rewriter, + loc, arith::ConstantIndexOp::create(rewriter, loc, staticStride), size); } else { staticStride *= op.getType().getDimSize(i); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp index 7475d442b7b9a..b54fa61714323 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp @@ -73,7 +73,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern { if (ShapedType::isDynamic(inputSize)) { Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc, rewriter.getIndexAttr(0)); - currSize = rewriter.create(loc, op.getSource(), dimZero) + currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero) .getResult(); } @@ -88,9 +88,9 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // the old buffer is smaller than the requested size. 
Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize); Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize); - Value cond = rewriter.create(loc, arith::CmpIPredicate::ult, + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, lhs, rhs); - auto ifOp = rewriter.create( + auto ifOp = scf::IfOp::create(rewriter, loc, cond, [&](OpBuilder &builder, Location loc) { // Allocate the new buffer. If it is a dynamic memref we need to pass @@ -100,25 +100,25 @@ struct ExpandReallocOpPattern : public OpRewritePattern { if (op.getDynamicResultSize()) dynamicSizeOperands.push_back(op.getDynamicResultSize()); - Value newAlloc = builder.create( + Value newAlloc = memref::AllocOp::create(builder, loc, op.getResult().getType(), dynamicSizeOperands, op.getAlignmentAttr()); // Take a subview of the new (bigger) buffer such that we can copy the // old values over (the copy operation requires both operands to have // the same shape). - Value subview = builder.create( + Value subview = memref::SubViewOp::create(builder, loc, newAlloc, ArrayRef{rewriter.getIndexAttr(0)}, ArrayRef{currSize}, ArrayRef{rewriter.getIndexAttr(1)}); - builder.create(loc, op.getSource(), subview); + memref::CopyOp::create(builder, loc, op.getSource(), subview); // Insert the deallocation of the old buffer only if requested // (enabled by default). if (emitDeallocs) - builder.create(loc, op.getSource()); + memref::DeallocOp::create(builder, loc, op.getSource()); - builder.create(loc, newAlloc); + scf::YieldOp::create(builder, loc, newAlloc); }, [&](OpBuilder &builder, Location loc) { // We need to reinterpret-cast here because either the input or output @@ -126,11 +126,11 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // dynamic or vice-versa. If both are static and the original buffer // is already bigger than the requested size, the cast represents a // subview operation. 
- Value casted = builder.create( + Value casted = memref::ReinterpretCastOp::create(builder, loc, cast(op.getResult().getType()), op.getSource(), rewriter.getIndexAttr(0), ArrayRef{targetSize}, ArrayRef{rewriter.getIndexAttr(1)}); - builder.create(loc, casted); + scf::YieldOp::create(builder, loc, casted); }); rewriter.replaceOp(op, ifOp.getResult(0)); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 2ba798f48ac7c..f581b6490fadb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -66,7 +66,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = - rewriter.create(origLoc, source); + memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source); auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); #ifndef NDEBUG @@ -577,7 +577,7 @@ static FailureOr resolveReshapeStridedMetadata( unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = - rewriter.create(origLoc, source); + memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source); // Collect statically known information. auto [strides, offset] = sourceType.getStridesAndOffset(); @@ -828,14 +828,14 @@ struct ExtractStridedMetadataOpAllocFolder if (allocLikeOp.getType() == baseBufferType) results.push_back(allocLikeOp); else - results.push_back(rewriter.create( + results.push_back(memref::ReinterpretCastOp::create(rewriter, loc, baseBufferType, allocLikeOp, offset, /*sizes=*/ArrayRef(), /*strides=*/ArrayRef())); } // Offset. 
- results.push_back(rewriter.create(loc, offset)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset)); for (OpFoldResult size : sizes) results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size)); @@ -900,19 +900,19 @@ struct ExtractStridedMetadataOpGetGlobalFolder if (getGlobalOp.getType() == baseBufferType) results.push_back(getGlobalOp); else - results.push_back(rewriter.create( + results.push_back(memref::ReinterpretCastOp::create(rewriter, loc, baseBufferType, getGlobalOp, offset, /*sizes=*/ArrayRef(), /*strides=*/ArrayRef())); // Offset. - results.push_back(rewriter.create(loc, offset)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset)); for (auto size : sizes) - results.push_back(rewriter.create(loc, size)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, size)); for (auto stride : strides) - results.push_back(rewriter.create(loc, stride)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride)); rewriter.replaceOp(op, results); return success(); @@ -1009,7 +1009,7 @@ class ExtractStridedMetadataOpReinterpretCastFolder results.resize_for_overwrite(rank * 2 + 2); auto newExtractStridedMetadata = - rewriter.create( + memref::ExtractStridedMetadataOp::create(rewriter, loc, reinterpretCastOp.getSource()); // Register the base_buffer. @@ -1083,7 +1083,7 @@ class ExtractStridedMetadataOpCastFolder results.resize_for_overwrite(rank * 2 + 2); auto newExtractStridedMetadata = - rewriter.create(loc, + memref::ExtractStridedMetadataOp::create(rewriter, loc, castOp.getSource()); // Register the base_buffer. 
@@ -1143,7 +1143,7 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder if (!memSpaceCastOp) return failure(); auto newExtractStridedMetadata = - rewriter.create( + memref::ExtractStridedMetadataOp::create(rewriter, loc, memSpaceCastOp.getSource()); SmallVector results(newExtractStridedMetadata.getResults()); // As with most other strided metadata rewrite patterns, don't introduce @@ -1158,7 +1158,7 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder MemRefType::Builder newTypeBuilder(baseBufferType); newTypeBuilder.setMemorySpace( memSpaceCastOp.getResult().getType().getMemorySpace()); - results[0] = rewriter.create( + results[0] = memref::MemorySpaceCastOp::create(rewriter, loc, Type{newTypeBuilder}, baseBuffer); } else { results[0] = nullptr; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index 2f5c9436fb8c7..65f398f6dc5a8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -42,7 +42,7 @@ static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter, memref::LoadOp loadOp, Value srcMemRef, ArrayRef indices) { Location loc = loadOp.getLoc(); - return rewriter.create(loc, srcMemRef, indices, + return memref::LoadOp::create(rewriter, loc, srcMemRef, indices, loadOp.getNontemporal()); } @@ -72,7 +72,7 @@ static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter, memref::StoreOp storeOp, Value srcMemRef, ArrayRef indices) { Location loc = storeOp.getLoc(); - return rewriter.create(loc, storeOp.getValueToStore(), + return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(), srcMemRef, indices, storeOp.getNontemporal()); } @@ -104,7 +104,7 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter, Value srcMemRef, ArrayRef indices) { Location loc = ldMatrixOp.getLoc(); - return rewriter.create( + return nvgpu::LdMatrixOp::create(rewriter, 
loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles()); } @@ -132,7 +132,7 @@ rebuildTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp transferReadOp, Value srcMemRef, ArrayRef indices) { Location loc = transferReadOp.getLoc(); - return rewriter.create( + return vector::TransferReadOp::create(rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices, transferReadOp.getPermutationMap(), transferReadOp.getPadding(), transferReadOp.getMask(), transferReadOp.getInBoundsAttr()); @@ -150,7 +150,7 @@ rebuildTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp transferWriteOp, Value srcMemRef, ArrayRef indices) { Location loc = transferWriteOp.getLoc(); - return rewriter.create( + return vector::TransferWriteOp::create(rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices, transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(), transferWriteOp.getInBoundsAttr()); @@ -183,7 +183,7 @@ getGenericOpViewSizeForEachDim(RewriterBase &rewriter, LoadStoreLikeOp loadStoreLikeOp) { Location loc = loadStoreLikeOp.getLoc(); auto extractStridedMetadataOp = - rewriter.create( + memref::ExtractStridedMetadataOp::create(rewriter, loc, getSrcMemRef(loadStoreLikeOp)); SmallVector srcSizes = extractStridedMetadataOp.getConstifiedMixedSizes(); @@ -267,12 +267,12 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern { // apply them properly to the input indices. // Therefore the strides multipliers are simply ones. auto subview = - rewriter.create(loc, /*source=*/srcMemRef, + memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef, /*offsets=*/indices, /*sizes=*/sizes, /*strides=*/ones); // Rewrite the load/store with the subview as the base pointer. 
SmallVector zeros(loadStoreRank, - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices( rewriter, loadStoreLikeOp, subview.getResult(), zeros); rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index e9729a4766a0a..1a73cc6dd948a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -45,7 +45,7 @@ using namespace mlir; static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in) { if (Attribute offsetAttr = dyn_cast(in)) { - return rewriter.create( + return arith::ConstantIndexOp::create(rewriter, loc, cast(offsetAttr).getInt()); } return cast(in); @@ -65,7 +65,7 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, } memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, source); + memref::ExtractStridedMetadataOp::create(rewriter, loc, source); auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); OpFoldResult linearizedIndices; @@ -79,7 +79,7 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, getAsOpFoldResult(indices)); return std::make_pair( - rewriter.create( + memref::ReinterpretCastOp::create(rewriter, loc, source, /* offset = */ linearizedInfo.linearizedOffset, /* shapes = */ @@ -116,7 +116,7 @@ template static void castAllocResult(T oper, T newOper, Location loc, PatternRewriter &rewriter) { memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, oper); + memref::ExtractStridedMetadataOp::create(rewriter, loc, oper); rewriter.replaceOpWithNewOp( oper, cast(oper.getType()), newOper, /*offset=*/rewriter.getIndexAttr(0), @@ -130,62 +130,62 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, Location loc = op->getLoc(); 
llvm::TypeSwitch(op.getOperation()) .template Case([&](auto oper) { - auto newAlloc = rewriter.create( + auto newAlloc = memref::AllocOp::create(rewriter, loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloc, loc, rewriter); }) .template Case([&](auto oper) { - auto newAlloca = rewriter.create( + auto newAlloca = memref::AllocaOp::create(rewriter, loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloca, loc, rewriter); }) .template Case([&](auto op) { - auto newLoad = rewriter.create( + auto newLoad = memref::LoadOp::create(rewriter, loc, op->getResultTypes(), flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) .template Case([&](auto op) { - auto newStore = rewriter.create( + auto newStore = memref::StoreOp::create(rewriter, loc, op->getOperands().front(), flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) .template Case([&](auto op) { - auto newLoad = rewriter.create( + auto newLoad = vector::LoadOp::create(rewriter, loc, op->getResultTypes(), flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) .template Case([&](auto op) { - auto newStore = rewriter.create( + auto newStore = vector::StoreOp::create(rewriter, loc, op->getOperands().front(), flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) .template Case([&](auto op) { - auto newMaskedLoad = rewriter.create( + auto newMaskedLoad = vector::MaskedLoadOp::create(rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), op.getPassThru()); newMaskedLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedLoad.getResult()); }) .template Case([&](auto op) { - auto newMaskedStore = rewriter.create( + auto newMaskedStore = vector::MaskedStoreOp::create(rewriter, loc, flatMemref, 
ValueRange{offset}, op.getMask(), op.getValueToStore()); newMaskedStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedStore); }) .template Case([&](auto op) { - auto newTransferRead = rewriter.create( + auto newTransferRead = vector::TransferReadOp::create(rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); rewriter.replaceOp(op, newTransferRead.getResult()); }) .template Case([&](auto op) { - auto newTransferWrite = rewriter.create( + auto newTransferWrite = vector::TransferWriteOp::create(rewriter, loc, op.getVector(), flatMemref, ValueRange{offset}); rewriter.replaceOp(op, newTransferWrite); }) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 42c43ba8553a3..8052115610eaa 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -78,7 +78,7 @@ static LogicalResult resolveSourceIndicesExpandShape( llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; }); SmallVector groupIndices = llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; }); - Value collapsedIndex = rewriter.create( + Value collapsedIndex = affine::AffineLinearizeIndexOp::create(rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds); sourceIndices.push_back(collapsedIndex); } @@ -104,7 +104,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, ValueRange indices, SmallVectorImpl &sourceIndices) { // Note: collapse_shape requires a strided memref, we can do this. 
- auto metadata = rewriter.create( + auto metadata = memref::ExtractStridedMetadataOp::create(rewriter, loc, collapseShapeOp.getSrc()); SmallVector sourceSizes = metadata.getConstifiedMixedSizes(); for (auto [index, group] : @@ -119,7 +119,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, SmallVector basis = llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); - auto delinearize = rewriter.create( + auto delinearize = affine::AffineDelinearizeIndexOp::create(rewriter, loc, index, basis, /*hasOuterBound=*/true); llvm::append_range(sourceIndices, delinearize.getResults()); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 21361d2e9a2d7..33be43da47679 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -53,7 +53,7 @@ FailureOr memref::buildIndependentOp(OpBuilder &b, // Create a new memref::AllocaOp. Value newAllocaOp = - b.create(loc, newSizes, allocaOp.getType().getElementType()); + AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType()); // Create a memref::SubViewOp. 
SmallVector offsets(newSizes.size(), b.getIndexAttr(0)); @@ -73,10 +73,10 @@ propagateSubViewOp(RewriterBase &rewriter, MemRefType newResultType = SubViewOp::inferRankReducedResultType( op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); - Value newSubview = rewriter.create( + Value newSubview = SubViewOp::create(rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); - auto newConversionOp = rewriter.create( + auto newConversionOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(), op.getType(), newSubview); rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0)); return newConversionOp; @@ -108,7 +108,7 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter, SmallVector unrealizedConversions; for (const auto &it : llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { - unrealizedConversions.push_back(rewriter.create( + unrealizedConversions.push_back(UnrealizedConversionCastOp::create(rewriter, to->getLoc(), std::get<0>(it.value()).getType(), std::get<1>(it.value()))); rewriter.replaceAllUsesWith(from->getResult(it.index()), diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index c475d92e0658e..a1b8821f38e7e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -64,7 +64,7 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter, subviewUse.getType().getShape(), cast(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); - Value newSubview = rewriter.create( + Value newSubview = memref::SubViewOp::create(rewriter, subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); @@ -178,7 +178,7 @@ 
mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, Location loc = allocOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(allocOp); - auto mbAlloc = rewriter.create( + auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType, ValueRange{}, allocOp->getAttrs()); LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); @@ -212,7 +212,7 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, // Strides is [1, 1 ... 1 ]. MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( originalShape, mbMemRefType, offsets, sizes, strides); - Value subview = rewriter.create(loc, dstMemref, mbAlloc, + Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); @@ -225,7 +225,7 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(deallocOp); auto newDeallocOp = - rewriter.create(deallocOp->getLoc(), mbAlloc); + memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc); (void)newDeallocOp; LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); rewriter.eraseOp(deallocOp); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index d6fcb8d9f0501..40dfad91b8959 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -278,7 +278,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, if (!callOp) continue; Operation *newCallOp = - builder.create(userOp->getLoc(), callOp.getCalleeAttr(), + func::CallOp::create(builder, userOp->getLoc(), callOp.getCalleeAttr(), resultTypes, userOp->getOperands()); bool replacingMemRefUsesFailed = false; bool returnTypeChanged = false; diff --git 
a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp index e6b9e2f7e8213..ea7f394f48ba4 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp @@ -116,10 +116,10 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter, // Update the type. newRes.setType(reifiedTy); if (isa(reifiedTy)) { - newResults.push_back(rewriter.create(loc, oldTy, newRes)); + newResults.push_back(tensor::CastOp::create(rewriter, loc, oldTy, newRes)); } else { assert(isa(reifiedTy) && "expected a memref type"); - newResults.push_back(rewriter.create(loc, oldTy, newRes)); + newResults.push_back(memref::CastOp::create(rewriter, loc, oldTy, newRes)); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 89a3895d06ba5..6a81a15f30e47 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -69,7 +69,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern { Location loc = dimOp->getLoc(); rewriter.replaceOpWithNewOp( dimOp, resultShape, - rewriter.create(loc, *dimIndex).getResult()); + arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult()); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 7bf7c7b8e024c..e15bf379840b4 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -41,15 +41,15 @@ struct AssumeAlignmentOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto assumeOp = cast(op); - Value ptr = builder.create( + Value ptr = 
ExtractAlignedPointerAsIndexOp::create(builder, loc, assumeOp.getMemref()); - Value rest = builder.create( + Value rest = arith::RemUIOp::create(builder, loc, ptr, - builder.create(loc, assumeOp.getAlignment())); - Value isAligned = builder.create( + arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment())); + Value isAligned = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest, - builder.create(loc, 0)); - builder.create( + arith::ConstantIndexOp::create(builder, loc, 0)); + cf::AssertOp::create(builder, loc, isAligned, RuntimeVerifiableOpInterface::generateErrorMessage( op, "memref is not aligned to " + @@ -72,12 +72,12 @@ struct CastOpInterface if (isa(srcType)) { // Check rank. - Value srcRank = builder.create(loc, castOp.getSource()); + Value srcRank = RankOp::create(builder, loc, castOp.getSource()); Value resultRank = - builder.create(loc, resultType.getRank()); - Value isSameRank = builder.create( + arith::ConstantIndexOp::create(builder, loc, resultType.getRank()); + Value isSameRank = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); - builder.create( + cf::AssertOp::create(builder, loc, isSameRank, RuntimeVerifiableOpInterface::generateErrorMessage(op, "rank mismatch")); @@ -96,8 +96,8 @@ struct CastOpInterface MemRefType::get(dynamicShape, resultType.getElementType(), stridedLayout, resultType.getMemorySpace()); Value helperCast = - builder.create(loc, dynStridesType, castOp.getSource()); - auto metadataOp = builder.create(loc, helperCast); + CastOp::create(builder, loc, dynStridesType, castOp.getSource()); + auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, helperCast); // Check dimension sizes. 
for (const auto &it : llvm::enumerate(resultType.getShape())) { @@ -111,12 +111,12 @@ struct CastOpInterface continue; Value srcDimSz = - builder.create(loc, castOp.getSource(), it.index()); + DimOp::create(builder, loc, castOp.getSource(), it.index()); Value resultDimSz = - builder.create(loc, it.value()); - Value isSameSz = builder.create( + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameSz = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); - builder.create( + cf::AssertOp::create(builder, loc, isSameSz, RuntimeVerifiableOpInterface::generateErrorMessage( op, "size mismatch of dim " + std::to_string(it.index()))); @@ -133,10 +133,10 @@ struct CastOpInterface // Static/dynamic offset -> dynamic offset does not need verification. Value srcOffset = metadataOp.getResult(1); Value resultOffsetVal = - builder.create(loc, resultOffset); - Value isSameOffset = builder.create( + arith::ConstantIndexOp::create(builder, loc, resultOffset); + Value isSameOffset = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); - builder.create( + cf::AssertOp::create(builder, loc, isSameOffset, RuntimeVerifiableOpInterface::generateErrorMessage( op, "offset mismatch")); @@ -151,10 +151,10 @@ struct CastOpInterface Value srcStride = metadataOp.getResult(2 + resultType.getRank() + it.index()); Value resultStrideVal = - builder.create(loc, it.value()); - Value isSameStride = builder.create( + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameStride = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); - builder.create( + cf::AssertOp::create(builder, loc, isSameStride, RuntimeVerifiableOpInterface::generateErrorMessage( op, "stride mismatch of dim " + std::to_string(it.index()))); @@ -187,7 +187,7 @@ struct CopyOpInterface auto getDimSize = [&](Value memRef, MemRefType type, int64_t dim) -> Value { return type.isDynamicDim(dim) - ? 
builder.create(loc, memRef, dim).getResult() + ? DimOp::create(builder, loc, memRef, dim).getResult() : builder .create(loc, type.getDimSize(dim)) @@ -195,9 +195,9 @@ struct CopyOpInterface }; Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i); Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i); - Value sameDimSize = builder.create( + Value sameDimSize = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim); - builder.create( + cf::AssertOp::create(builder, loc, sameDimSize, RuntimeVerifiableOpInterface::generateErrorMessage( op, "size of " + std::to_string(i) + @@ -212,9 +212,9 @@ struct DimOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto dimOp = cast(op); - Value rank = builder.create(loc, dimOp.getSource()); - Value zero = builder.create(loc, 0); - builder.create( + Value rank = RankOp::create(builder, loc, dimOp.getSource()); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + cf::AssertOp::create(builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), RuntimeVerifiableOpInterface::generateErrorMessage( op, "index is out of bounds")); @@ -238,7 +238,7 @@ struct LoadStoreOpInterface } auto indices = loadStoreOp.getIndices(); - auto zero = builder.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); Value assertCond; for (auto i : llvm::seq(0, rank)) { Value dimOp = builder.createOrFold(loc, memref, i); @@ -248,7 +248,7 @@ struct LoadStoreOpInterface i > 0 ? 
builder.createOrFold(loc, assertCond, inBounds) : inBounds; } - builder.create( + cf::AssertOp::create(builder, loc, assertCond, RuntimeVerifiableOpInterface::generateErrorMessage( op, "out-of-bounds access")); @@ -266,10 +266,10 @@ struct SubViewOpInterface // For each dimension, assert that: // 0 <= offset < dim_size // 0 <= offset + (size - 1) * stride < dim_size - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); auto metadataOp = - builder.create(loc, subView.getSource()); + ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); @@ -282,20 +282,20 @@ struct SubViewOpInterface Value dimSize = metadataOp.getSizes()[i]; Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - builder.create( + cf::AssertOp::create(builder, loc, offsetInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "offset " + std::to_string(i) + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. 
- Value sizeMinusOne = builder.create(loc, size, one); + Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = - builder.create(loc, sizeMinusOne, stride); + arith::MulIOp::create(builder, loc, sizeMinusOne, stride); Value lastPos = - builder.create(loc, offset, sizeMinusOneTimesStride); + arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - builder.create( + cf::AssertOp::create(builder, loc, lastPosInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "subview runs out-of-bounds along dimension " + @@ -316,7 +316,7 @@ struct ExpandShapeOpInterface for (const auto &it : llvm::enumerate(expandShapeOp.getReassociationIndices())) { Value srcDimSz = - builder.create(loc, expandShapeOp.getSrc(), it.index()); + DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index()); int64_t groupSz = 1; bool foundDynamicDim = false; for (int64_t resultDim : it.value()) { @@ -331,14 +331,14 @@ struct ExpandShapeOpInterface groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); } Value staticResultDimSz = - builder.create(loc, groupSz); + arith::ConstantIndexOp::create(builder, loc, groupSz); // staticResultDimSz must divide srcDimSz evenly. 
Value mod = - builder.create(loc, srcDimSz, staticResultDimSz); - Value isModZero = builder.create( + arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz); + Value isModZero = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, mod, - builder.create(loc, 0)); - builder.create( + arith::ConstantIndexOp::create(builder, loc, 0)); + cf::AssertOp::create(builder, loc, isModZero, RuntimeVerifiableOpInterface::generateErrorMessage( op, "static result dims in reassoc group do not " diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 5a36984c9013c..7167289ec0e14 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -92,7 +93,7 @@ SmallVector mlir::mesh::getMixedAsValues(OpBuilder b, values.emplace_back(*(dyn++)); } else { TypedAttr val = type == i64 ? 
b.getI64IntegerAttr(s) : b.getIndexAttr(s); - values.emplace_back(b.create(loc, type, val)); + values.emplace_back(arith::ConstantOp::create(b, loc, type, val)); } } return values; @@ -317,10 +318,10 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, if (!newShardOp) { auto shardingOp = - builder.create(operandValue.getLoc(), sharding); - newShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ false); + ShardingOp::create(builder, operandValue.getLoc(), sharding); + newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue, + shardingOp, + /*annotate_for_users*/ false); } operandValue.replaceUsesWithIf( newShardOp, [operandOp, operandValue](OpOperand &use) { @@ -331,9 +332,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, return; } - auto newShardOp2 = builder.create(operandValue.getLoc(), newShardOp, - newShardOp.getSharding(), - /*annotate_for_users*/ true); + auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp, + newShardOp.getSharding(), + /*annotate_for_users*/ true); newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2); } @@ -379,10 +380,10 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, builder.setInsertionPoint(operandOp); auto shardingOp = - builder.create(operand.get().getLoc(), sharding); + ShardingOp::create(builder, operand.get().getLoc(), sharding); auto newShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ true); + ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp, + /*annotate_for_users*/ true); IRRewriter rewriter(builder); rewriter.replaceUsesWithIf( operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { @@ -396,8 +397,8 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, builder.setInsertionPoint(newShardOp); auto newPreceedingShardOp = - 
builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ false); + ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp, + /*annotate_for_users*/ false); rewriter.replaceUsesWithIf( newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) { return use.getOwner() == newShardOp.getOperation(); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp index 4c14b1c0ea4bb..b5b3a4d21f2af 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp @@ -92,7 +92,7 @@ struct MeshShapeFolder newShapeOpMeshAxes.push_back(opMeshAxes[i]); } else { // Fold static mesh axes. - newResults[i] = builder.create( + newResults[i] = arith::ConstantOp::create(builder, builder.getIndexAttr(meshAxisSize)); } } @@ -100,7 +100,7 @@ struct MeshShapeFolder // Leave only the dynamic mesh axes to be queried. if (!newShapeOpMeshAxes.empty()) { MeshShapeOp newShapeOp = - builder.create(mesh.getSymName(), newShapeOpMeshAxes); + MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes); for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; } diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index a0284b093da94..6476e8b689baa 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -267,7 +267,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis); ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); - Value allGatherResult = builder.create( + Value allGatherResult = AllGatherOp::create(builder, 
RankedTensorType::get(allGatherResultShape.getShape(), allGatherResultShape.getElementType()), mesh.getSymName(), SmallVector({splitMeshAxis}), sourceShard, @@ -275,7 +275,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = cast>( - builder.create(targetShape, allGatherResult).getResult()); + tensor::CastOp::create(builder, targetShape, allGatherResult).getResult()); return {targetShard, targetSharding}; } @@ -400,7 +400,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, targetTensorAxis); - Value allToAllResult = builder.create( + Value allToAllResult = AllToAllOp::create(builder, RankedTensorType::get(allToAllResultShape.getShape(), allToAllResultShape.getElementType()), mesh.getSymName(), SmallVector({meshAxis}), sourceShard, @@ -408,7 +408,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = cast>( - builder.create(targetShape, allToAllResult).getResult()); + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); return {targetShard, targetSharding}; } @@ -479,13 +479,13 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // Extract core from source and copy into destination core. 
auto noVals = ValueRange{}; - auto initVal = builder.create( + auto initVal = tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape, sourceShard.getType().getElementType()); - auto core = builder.create( + auto core = tensor::ExtractSliceOp::create(builder, sourceShard.getLoc(), RankedTensorType::get(coreShape, sourceShard.getType().getElementType()), sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides); - auto initOprnd = builder.create( + auto initOprnd = tensor::InsertSliceOp::create(builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs, coreShape, strides); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index f08ef75d8a004..c85394924608d 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -49,10 +49,10 @@ struct ProcessMultiIndexOpLowering ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); - Value linearIndex = builder.create(mesh); - ValueRange meshShape = builder.create(mesh).getResults(); + Value linearIndex = ProcessLinearIndexOp::create(builder, mesh); + ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults(); SmallVector completeMultiIndex = - builder.create(linearIndex, meshShape) + affine::AffineDelinearizeIndexOp::create(builder, linearIndex, meshShape) .getMultiIndex(); SmallVector multiIndex; ArrayRef opMeshAxes = op.getAxes(); @@ -101,32 +101,32 @@ struct AllSliceOpLowering ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); - Value zero = builder.create(builder.getIndexAttr(0)); + Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0)); Operation::result_range processInGroupMultiIndex = - builder.create(mesh.getSymName(), op.getMeshAxes()) + ProcessMultiIndexOp::create(builder, mesh.getSymName(), op.getMeshAxes()) .getResults(); 
Operation::result_range processGroupShape = - builder.create(mesh.getSymName(), op.getMeshAxes()) + MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes()) .getResult(); Value processGroupSize = createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); int64_t sliceAxis = op.getSliceAxis().getSExtValue(); Value operandSliceAxisSize = - builder.create(op.getOperand(), sliceAxis); + tensor::DimOp::create(builder, op.getOperand(), sliceAxis); Value operandSliceAxisSizeModProcessGroupSize = - builder.create(operandSliceAxisSize, processGroupSize); - Value isTargetShapeExactlyDivisible = builder.create( + arith::RemUIOp::create(builder, operandSliceAxisSize, processGroupSize); + Value isTargetShapeExactlyDivisible = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize, zero); - builder.create(isTargetShapeExactlyDivisible, + cf::AssertOp::create(builder, isTargetShapeExactlyDivisible, "Slicing a tensor with axis size that is " "not exactly divisible by the " "mesh process group size is not supported."); Value resultSliceAxisSize = - builder.create(operandSliceAxisSize, processGroupSize); + arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); @@ -139,7 +139,7 @@ struct AllSliceOpLowering if (i == sliceAxis) { sizes.emplace_back(resultSliceAxisSize); } else { - Value dimSize = builder.create(op.getOperand(), i); + Value dimSize = tensor::DimOp::create(builder, op.getOperand(), i); sizes.emplace_back(dimSize); } } @@ -152,10 +152,10 @@ struct AllSliceOpLowering resultSliceAxisSize); SmallVector strides( operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1)); - Value slice = builder.create( + Value slice = tensor::ExtractSliceOp::create(builder, op.getOperand(), offsets, sizes, strides); Value newResult = - 
builder.create(op.getResult().getType(), slice); + tensor::CastOp::create(builder, op.getResult().getType(), slice); rewriter.replaceAllUsesWith(op.getResult(), newResult); return success(); @@ -201,7 +201,7 @@ TypedValue createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, ImplicitLocOpBuilder &builder) { Operation::result_range meshShape = - builder.create(mesh, axes).getResults(); + mesh::MeshShapeOp::create(builder, mesh, axes).getResults(); return cast>(arith::createProduct( builder, builder.getLoc(), llvm::to_vector_of(meshShape), builder.getIndexType())); @@ -212,13 +212,13 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, ArrayRef meshAxes, ImplicitLocOpBuilder &builder) { Operation::result_range processGroupShape = - builder.create(mesh, meshAxes).getResult(); + MeshShapeOp::create(builder, mesh, meshAxes).getResult(); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); auto res = dyn_cast(processInGroupLinearIndex); if (!res) - res = builder.create( + res = arith::ConstantIndexOp::create(builder, cast(cast(processInGroupLinearIndex)).getInt()); return cast>(res); } @@ -227,7 +227,7 @@ TypedValue createProcessLinearIndex(StringRef mesh, ArrayRef meshAxes, ImplicitLocOpBuilder &builder) { return createProcessLinearIndex( - mesh, builder.create(mesh, meshAxes).getResults(), + mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(), meshAxes, builder); } } // namespace mlir::mesh diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index abbdb6a0f53ec..4a89699b80023 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include 
"mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index d2c94b124cdfb..85317a68c9f00 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -334,13 +334,13 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, // Location loc = asyncCopyOp->getLoc(); Value dstElements = - rewriter.create(loc, asyncCopyOp.getDstElementsAttr()); + arith::ConstantOp::create(rewriter, loc, asyncCopyOp.getDstElementsAttr()); Value originalSrcElement = asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements; - Value c0Index = rewriter.create(loc, 0); - auto srcElements = rewriter.create( + Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto srcElements = arith::SelectOp::create(rewriter, loc, predicate, originalSrcElement, c0Index); - auto asyncCopyZeroFillOp = rewriter.create( + auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, @@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, for (auto indexing : indexings) { Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); - auto load = b.create(loc, memref, ValueRange{row, col}); + auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col}); res.push_back(load); } return res; @@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); - Value res = 
b.create(loc, vt, loads[0]); + Value res = vector::SplatOp::create(b, loc, vt, loads[0]); foreachIndividualVectorElement( res, /*applyFn=*/ @@ -697,7 +697,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( }, /*reduceFn=*/ [&](Value v, int64_t linearIdx, ArrayRef indices) { - res = b.create(loc, v, res, indices); + res = vector::InsertOp::create(b, loc, v, res, indices); }); return res; @@ -715,7 +715,7 @@ SmallVector MmaSyncBuilder::buildMemRefStores( Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); Operation *store = - b.create(loc, val, memref, ValueRange{row, col}); + memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col}); res.push_back(store); } return res; @@ -730,7 +730,7 @@ SmallVector MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( vectorToStore, /*applyFn=*/ [&](Value v, int64_t linearIdx, ArrayRef indices) { - return b.create(loc, vectorToStore, indices); + return vector::ExtractOp::create(b, loc, vectorToStore, indices); }, /*reduceFn=*/ [&](Value v, int64_t linearIdx, ArrayRef indices) { @@ -810,7 +810,7 @@ FailureOr MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { rhsIndexFn, rhsShape); Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, resIndexFn, resShape); - res = b.create(loc, lhs, rhs, res, info.mmaShape, + res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled); buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, resShape); @@ -832,7 +832,7 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( } Location loc = linalgOp.getLoc(); // TODO: more robust computation of laneId, for now assume a single warp. 
- Value laneId = rewriter.create( + Value laneId = gpu::ThreadIdOp::create(rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x); if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) fail = false; @@ -897,12 +897,12 @@ SmallVector HopperBuilder::buildPredicateLoadsOnThread0( ArrayRef> sharedMemBuffers, TypedValue barrier) { SmallVector loadOps; - Value zero = rewriter.create(loc, 0); - Value tidx = rewriter.create(loc, gpu::Dimension::x); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x); Value cond = - rewriter.create(loc, arith::CmpIPredicate::eq, tidx, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, tidx, zero); // clang-format off - rewriter.create( + scf::IfOp::create(rewriter, /*location=*/loc, /*conditional=*/cond, /*thenBuilder=*/ @@ -917,14 +917,14 @@ SmallVector HopperBuilder::buildPredicateLoadsOnThread0( // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. // This may or may not have perf implications. buildBarrierArriveTx(barrier, sizes); - rewriter.create(loc); + scf::YieldOp::create(rewriter, loc); }, /*elseBuilder=*/ [&](OpBuilder &lb, Location loc) { // TODO: is this for no-thread divergence? // Should we just yield the size and hoist? 
buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0)); - rewriter.create(loc); + scf::YieldOp::create(rewriter, loc); }); // clang-format on return loadOps; @@ -939,14 +939,14 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { TypedValue HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value barrier = rewriter.create( + Value barrier = nvgpu::MBarrierCreateOp::create(rewriter, loc, nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); - Value zero = rewriter.create(loc, 0); - rewriter.create( + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + nvgpu::MBarrierInitOp::create(rewriter, loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero, Value()); - rewriter.create(loc); + gpu::BarrierOp::create(rewriter, loc); return cast>(barrier); } @@ -955,7 +955,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue memref, gpu::LaunchOp launchOp) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launchOp); - Value unrankedMemRef = rewriter.create( + Value unrankedMemRef = memref::CastOp::create(rewriter, loc, UnrankedMemRefType::get(memref.getType().getElementType(), memref.getType().getMemorySpace()), @@ -966,7 +966,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue memref, getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value desc = rewriter.create( + Value desc = nvgpu::TmaCreateDescriptorOp::create(rewriter, loc, nvgpu::TensorMapDescriptorType::get( rewriter.getContext(), @@ -985,8 +985,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad( TypedValue barrier, SmallVectorImpl &loadOps) { MLIRContext *ctx = rewriter.getContext(); - Value zero = rewriter.create(loc, 0); - Operation *loadOp = rewriter.create( + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 
0); + Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero, Value(), Value()); loadOps.push_back(loadOp); @@ -1012,22 +1012,22 @@ void HopperBuilder::buildBarrierArriveTx( OpFoldResult size = affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes); Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size); - Value zero = rewriter.create(loc, 0); - rewriter.create(loc, barrier, sizeVal, zero, + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero, Value()); } void HopperBuilder::buildTryWaitParity( TypedValue barrier) { Type i1 = rewriter.getI1Type(); - Value parity = rewriter.create(loc, i1, 0); + Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0); // 10M is an arbitrary, not too small or too big number to specify the number // of ticks before retry. // TODO: hoist this in a default dialect constant. 
Value ticksBeforeRetry = - rewriter.create(loc, 10000000); - Value zero = rewriter.create(loc, 0); - rewriter.create(loc, barrier, parity, + arith::ConstantIndexOp::create(rewriter, loc, 10000000); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity, ticksBeforeRetry, zero); } diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp index 10bc1993ffd96..6a167b1cb157b 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp @@ -110,17 +110,17 @@ static Value buildNumReadElements(OpBuilder &b, Location loc, for (auto [pos, sz] : llvm::zip(transferMask->extractPosition, transferMask->createMaskOp->getOperands())) { Value cmp = - b.create(loc, arith::CmpIPredicate::slt, - b.create(loc, pos), sz); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, + arith::ConstantIndexOp::create(b, loc, pos), sz); if (!cond) { cond = cmp; continue; } - cond = b.create(loc, cmp, cond); + cond = arith::AndIOp::create(b, loc, cmp, cond); } - return b.create( + return arith::SelectOp::create(b, loc, cond, transferMask->createMaskOp->getOperands().back(), - b.create(loc, 0)); + arith::ConstantIndexOp::create(b, loc, 0)); } /// Return "true" if the conversion to async copy is supported by "async copy". @@ -252,7 +252,7 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op, int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * numElements) / 8; // bypass_l1 only possible with 16 byte transfer. 
- Value token = rewriter.create( + Value token = nvgpu::DeviceAsyncCopyOp::create(rewriter, writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp), /*src=*/loadBase, @@ -265,10 +265,10 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op, } // Create the group and wait for it right after. - Value groupToken = rewriter.create( + Value groupToken = nvgpu::DeviceAsyncCreateGroupOp::create(rewriter, op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), tokens); - rewriter.create(op->getLoc(), groupToken, + nvgpu::DeviceAsyncWaitOp::create(rewriter, op->getLoc(), groupToken, nullptr); // Clean up old stores. for (Operation *writeOp : group) diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp index 72f7296a865f8..3c6b189606c6f 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp @@ -75,27 +75,27 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc, int64_t mask = (1LL << (m - n)) - 1; if (permuteEveryN > 1) mask = mask << llvm::Log2_64(permuteEveryN); - Value srcBits = b.create(loc, mask); - srcBits = b.create(loc, src, srcBits); + Value srcBits = arith::ConstantIndexOp::create(b, loc, mask); + srcBits = arith::AndIOp::create(b, loc, src, srcBits); // Use the src bits to permute the target bits b[N:M] containing the // vector offset. 
if (permuteEveryN > 1) { int64_t shlBits = n - llvm::Log2_64(permuteEveryN); if (shlBits > 0) { - Value finalShiftVal = b.create(loc, shlBits); + Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, shlBits); srcBits = b.createOrFold(loc, srcBits, finalShiftVal); } else if (shlBits < 0) { - Value finalShiftVal = b.create(loc, -1 * shlBits); + Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, -1 * shlBits); srcBits = b.createOrFold(loc, srcBits, finalShiftVal); } } else { - Value finalShiftVal = b.create(loc, n); + Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, n); srcBits = b.createOrFold(loc, srcBits, finalShiftVal); } Value permutedVectorIdx = - b.create(loc, indices[tgtDim], srcBits); + arith::XOrIOp::create(b, loc, indices[tgtDim], srcBits); return permutedVectorIdx; } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 80c807e774a7e..d4707619fa273 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index ffc84781f77ff..8fa59488678ee 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index d3f7c9798b9b8..ffe53bbe80607 100644 --- 
a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp index 9b1f11d835282..5cc84ab29e246 100644 --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/FunctionImplementation.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c488144508128..3695262155387 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Ptr/IR/PtrOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index e23a0d6aba825..cd4198b431669 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 793db73575b4f..a378b09c27c41 100644 --- 
a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -72,7 +72,7 @@ Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, // Create tensor splat auto tensorConstant = - builder.create(loc, scalar, referenceShape); + tensor::SplatOp::create(builder, loc, scalar, referenceShape); return tensorConstant; } @@ -94,21 +94,21 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, // Get unranked input shape and total size auto *context = builder.getContext(); auto shapeType = shape::getExtentTensorType(context); - auto inputShape = builder.create(loc, shapeType, input); - Value inputSize = builder.create( + auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input); + Value inputSize = shape::NumElementsOp::create(builder, loc, builder.getIndexType(), inputShape); // Turn input size into 1D tensor auto flatShapeType = shape::getExtentTensorType(context, 1); auto flatInputShape = - builder.create(loc, flatShapeType, inputSize); + tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize); // Reshape input tensor into 1D auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto flatInputType = RankedTensorType::get({ShapedType::kDynamic}, elementType); - auto flatInput = builder.create(loc, flatInputType, input, + auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input, flatInputShape); return std::make_pair(flatInput, inputShape); } @@ -142,30 +142,30 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, auto *context = builder.getContext(); auto indexType = builder.getIndexType(); auto shapeType = shape::getExtentTensorType(context); - auto inputShape = builder.create(loc, shapeType, input); + auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input); // Get shape and sizes on left and right of axis - auto axisValue = builder.create(loc, axis); 
- auto axisNextValue = builder.create(loc, axis + 1); + auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis); + auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1); auto shapeLeft = builder .create(loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) .getResult(0); auto sizeLeft = - builder.create(loc, indexType, shapeLeft); + shape::NumElementsOp::create(builder, loc, indexType, shapeLeft); auto shapeRight = builder .create(loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) .getResult(1); auto sizeRight = - builder.create(loc, indexType, shapeRight); + shape::NumElementsOp::create(builder, loc, indexType, shapeRight); // Compute flat input shape as a 3-element 1D tensor - auto axisSizeValue = builder.create(loc, axisSize); + auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize); auto flatShapeType = shape::getExtentTensorType(context, 3); - auto flatInputShape = builder.create( + auto flatInputShape = tensor::FromElementsOp::create(builder, loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight}); // Reshape input to 3D tensor @@ -173,7 +173,7 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, auto elementType = inputType.getElementType(); auto flatInputType = RankedTensorType::get( {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); - auto flatInput = builder.create(loc, flatInputType, input, + auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input, flatInputShape); return std::make_pair(flatInput, inputShape); @@ -192,7 +192,7 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto unrankedType = UnrankedTensorType::get(elementType); - return builder.create(loc, unrankedType, input, + return tensor::ReshapeOp::create(builder, loc, unrankedType, input, inputShape); } @@ -215,7 +215,7 
@@ Value materializePerChannelScales(OpBuilder &builder, Location loc, auto tensorType = RankedTensorType::get({(int64_t)scales.size()}, expressedType); auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); - return builder.create(loc, tensorType, scalesAttr); + return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr); } // Create a tensor constant containing all zero points in a per-channel @@ -239,7 +239,7 @@ Value materializePerChannelZeroPoints( auto tensorType = RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); - return builder.create(loc, tensorType, zeroPointsAttr); + return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr); } // Create a tensor constant containing all scales in a sub-channel quantized @@ -263,7 +263,7 @@ Value materializeSubChannelScales( auto tensorType = RankedTensorType::get(scales.getType().getShape(), expressedType); auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); - return builder.create(loc, tensorType, scalesAttr); + return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr); } // Create a tensor constant containing all zero points in a sub-channel @@ -287,7 +287,7 @@ Value materializeSubChannelZeroPoints( auto tensorType = RankedTensorType::get(zeroPoints.getType().getShape(), storageType); auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); - return builder.create(loc, tensorType, zeroPointsAttr); + return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr); } // Clamp the given scalar or tensor input using the storage bounds encoded in @@ -314,9 +314,9 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, // Materialize bounds auto inputType = input.getType(); auto storageType = quantizedType.getStorageType(); - auto storageMinScalar = builder.create( + auto storageMinScalar = 
arith::ConstantIntOp::create(builder, loc, storageType, quantizedType.getStorageTypeMin()); - auto storageMaxScalar = builder.create( + auto storageMaxScalar = arith::ConstantIntOp::create(builder, loc, storageType, quantizedType.getStorageTypeMax()); auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, inputType, inputShape); @@ -325,11 +325,11 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, // Clamp if (quantizedType.isSigned()) { - input = builder.create(loc, input, storageMin); - input = builder.create(loc, input, storageMax); + input = arith::MaxSIOp::create(builder, loc, input, storageMin); + input = arith::MinSIOp::create(builder, loc, input, storageMax); } else { - input = builder.create(loc, input, storageMin); - input = builder.create(loc, input, storageMax); + input = arith::MaxUIOp::create(builder, loc, input, storageMin); + input = arith::MinUIOp::create(builder, loc, input, storageMax); } return input; } @@ -338,16 +338,16 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, Type resultType, bool isSigned) { if (isSigned) - return builder.create(loc, resultType, input); - return builder.create(loc, resultType, input); + return arith::FPToSIOp::create(builder, loc, resultType, input); + return arith::FPToUIOp::create(builder, loc, resultType, input); } // Emit op 'arith.sitofp' or 'arith.uitofp'. Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, Type resultType, bool isSigned) { if (isSigned) - return builder.create(loc, resultType, input); - return builder.create(loc, resultType, input); + return arith::SIToFPOp::create(builder, loc, resultType, input); + return arith::UIToFPOp::create(builder, loc, resultType, input); } // Quantize a scalar or ranked tensor value. 
The stored value is clamped using @@ -362,7 +362,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape); // Scale input - auto scaledValue = builder.create(loc, input, scale); + auto scaledValue = arith::DivFOp::create(builder, loc, input, scale); // Skip unnecessary computations if no zero point is given Value storedValueFloat = scaledValue; @@ -377,7 +377,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, // Add zero point to stored value storedValueFloat = - builder.create(loc, scaledValue, zeroPoint); + arith::AddFOp::create(builder, loc, scaledValue, zeroPoint); } // Convert stored value to storage type @@ -418,11 +418,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input, quantizedType.isSigned()); // Subtract zero point to stored value - result = builder.create(loc, result, zeroPoint); + result = arith::SubFOp::create(builder, loc, result, zeroPoint); } // Multiply by scale - result = builder.create(loc, result, scale); + result = arith::MulFOp::create(builder, loc, result, scale); return result; } @@ -477,11 +477,11 @@ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, auto storageType = quantizedType.getStorageType(); auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale()); - auto scale = builder.create(loc, expressedType, scaleAttr); + auto scale = arith::ConstantOp::create(builder, loc, expressedType, scaleAttr); auto zeroPointAttr = builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); auto zeroPoint = - builder.create(loc, storageType, zeroPointAttr); + arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr); auto inputShape = getScalarOrTensorShape(builder, loc, input); return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, @@ -546,7 +546,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, ? 
quantizedType.getStorageType() : quantizedType.getExpressedType(); auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, elementType); + Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); @@ -572,7 +572,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, convertRanked(builder, loc, op, input, {}, scale, zeroPoint, quantizedType); - builder.create(loc, result); + linalg::YieldOp::create(builder, loc, result); }) .getResult(0); @@ -642,7 +642,7 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, ? quantizedType.getStorageType() : quantizedType.getExpressedType(); auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, elementType); + Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); @@ -675,7 +675,7 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, convertRanked(builder, loc, op, input, {}, scale, zeroPoint, quantizedType); - builder.create(loc, result); + linalg::YieldOp::create(builder, loc, result); }) .getResult(0); @@ -729,7 +729,7 @@ struct DequantizeCastOpConversion // Convert quantized input to storage type auto storageScalarOrTensorType = getScalarOrTensorType(quantizedType.getStorageType(), input.getType()); - input = rewriter.create( + input = quant::StorageCastOp::create(rewriter, loc, storageScalarOrTensorType, input); auto result = convertQuantized(rewriter, loc, op, input, quantizedType); diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp index 4009faa21576d..f5e0aef365ff7 100644 --- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp 
@@ -46,7 +46,7 @@ class QuantizedTypeConverter : public TypeConverter { static Value materializeConversion(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - return builder.create(loc, type, + return quant::StorageCastOp::create(builder, loc, type, llvm::getSingleElement(inputs)); } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 5a3bd984530db..8ead566703ab4 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -85,7 +86,7 @@ void SCFDialect::initialize() { /// Default callback for IfOp builders. Inserts a yield without arguments. void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { - builder.create(loc); + scf::YieldOp::create(builder, loc); } /// Verifies that the first block of the given `region` is terminated by a @@ -241,13 +242,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern { Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator()); rewriter.setInsertionPointToEnd(prevBlock); - rewriter.create(op.getLoc(), &op.getRegion().front()); + cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front()); for (Block &blk : op.getRegion()) { if (YieldOp yieldOp = dyn_cast(blk.getTerminator())) { rewriter.setInsertionPoint(yieldOp); - rewriter.create(yieldOp.getLoc(), postBlock, - yieldOp.getResults()); + cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock, + yieldOp.getResults()); rewriter.eraseOp(yieldOp); } } @@ -557,8 +558,8 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, rewriter.setInsertionPoint(getOperation()); auto inits = llvm::to_vector(getInitArgs()); inits.append(newInitOperands.begin(), newInitOperands.end()); - 
scf::ForOp newLoop = rewriter.create( - getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, + scf::ForOp newLoop = scf::ForOp::create( + rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); @@ -673,8 +674,8 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { Value dst = parallelInsertSliceOp.getDest(); Value src = parallelInsertSliceOp.getSource(); if (llvm::isa(src.getType())) { - results.push_back(rewriter.create( - forallOp.getLoc(), dst.getType(), src, dst, + results.push_back(tensor::InsertSliceOp::create( + rewriter, forallOp.getLoc(), dst.getType(), src, dst, parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), parallelInsertSliceOp.getStrides(), parallelInsertSliceOp.getStaticOffsets(), @@ -722,8 +723,8 @@ LoopNest mlir::scf::buildLoopNest( ValueRange currentIterArgs = iterArgs; Location currentLoc = loc; for (unsigned i = 0, e = lbs.size(); i < e; ++i) { - auto loop = builder.create( - currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, + auto loop = scf::ForOp::create( + builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange args) { ivs.push_back(iv); @@ -742,7 +743,7 @@ LoopNest mlir::scf::buildLoopNest( // For all loops but the innermost, yield the results of the nested loop. 
for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) { builder.setInsertionPointToEnd(loops[i].getBody()); - builder.create(loc, loops[i + 1].getResults()); + scf::YieldOp::create(builder, loc, loops[i + 1].getResults()); } // In the body of the innermost loop, call the body building function if any @@ -756,7 +757,7 @@ LoopNest mlir::scf::buildLoopNest( "loop nest body must return as many values as loop has iteration " "arguments"); builder.setInsertionPointToEnd(loops.back().getBody()); - builder.create(loc, results); + scf::YieldOp::create(builder, loc, results); // Return the loops. ValueVector nestResults; @@ -801,8 +802,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, } // 2. Create the new forOp shell. - scf::ForOp newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + scf::ForOp newForOp = scf::ForOp::create( + rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newIterOperands); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -831,7 +832,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, clonedYieldOp.getOperand(yieldIdx)); SmallVector newYieldOperands = clonedYieldOp.getOperands(); newYieldOperands[yieldIdx] = castOut; - rewriter.create(newForOp.getLoc(), newYieldOperands); + scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands); rewriter.eraseOp(clonedYieldOp); // 6. Inject an outgoing cast op after the forOp. 
@@ -926,9 +927,9 @@ struct ForOpIterArgsFolder : public OpRewritePattern { if (!canonicalize) return failure(); - scf::ForOp newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newIterArgs); + scf::ForOp newForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), newIterArgs); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -970,8 +971,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern { for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx) if (keepMask[idx]) filteredOperands.push_back(mergedTerminator.getOperand(idx)); - rewriter.create(mergedTerminator.getLoc(), - filteredOperands); + scf::YieldOp::create(rewriter, mergedTerminator.getLoc(), + filteredOperands); }; rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); @@ -1111,7 +1112,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern { op, replaceAndCastForOpIterArg( rewriter, op, iterOpOperand, incomingCast.getSource(), [](OpBuilder &b, Location loc, Type type, Value source) { - return b.create(loc, type, source); + return tensor::CastOp::create(b, loc, type, source); })); return success(); } @@ -1685,8 +1686,8 @@ struct ForallOpIterArgsFolder : public OpRewritePattern { // Step 3. Create a new scf.forall op with the new shared_outs' operands // fetched earlier - auto newForallOp = rewriter.create( - forallOp.getLoc(), forallOp.getMixedLowerBound(), + auto newForallOp = scf::ForallOp::create( + rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, forallOp.getMapping(), /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); @@ -1782,9 +1783,9 @@ struct ForallOpSingleOrZeroIterationDimsFolder // Replace the loop by a lower-dimensional loop. 
ForallOp newOp; - newOp = rewriter.create(loc, newMixedLowerBounds, - newMixedUpperBounds, newMixedSteps, - op.getOutputs(), std::nullopt, nullptr); + newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds, + newMixedUpperBounds, newMixedSteps, + op.getOutputs(), std::nullopt, nullptr); newOp.getBodyRegion().getBlocks().clear(); // The new loop needs to keep all attributes from the old one, except for // "operandSegmentSizes" and static loop bound attributes which capture @@ -1867,16 +1868,17 @@ struct FoldTensorCastOfOutputIntoForallOp // Create new loop. Location loc = forallOp.getLoc(); - auto newForallOp = rewriter.create( - loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), - forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(), + auto newForallOp = ForallOp::create( + rewriter, loc, forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), + newOutputTensors, forallOp.getMapping(), [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) { auto castBlockArgs = llvm::to_vector(bbArgs.take_back(forallOp->getNumResults())); for (auto [index, cast] : tensorCastProducers) { Value &oldTypeBBArg = castBlockArgs[index]; - oldTypeBBArg = nestedBuilder.create( - nestedLoc, cast.dstType, oldTypeBBArg); + oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc, + cast.dstType, oldTypeBBArg); } // Move old body into new parallel loop. 
@@ -1902,8 +1904,8 @@ struct FoldTensorCastOfOutputIntoForallOp SmallVector castResults = newForallOp.getResults(); for (auto &item : tensorCastProducers) { Value &oldTypeResult = castResults[item.first]; - oldTypeResult = rewriter.create(loc, item.second.dstType, - oldTypeResult); + oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType, + oldTypeResult); } rewriter.replaceOp(forallOp, castResults); return success(); @@ -2309,7 +2311,7 @@ struct RemoveUnusedResults : public OpRewritePattern { // Create a replacement operation with empty then and else regions. auto newOp = - rewriter.create(op.getLoc(), newTypes, op.getCondition()); + IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition()); rewriter.createBlock(&newOp.getThenRegion()); rewriter.createBlock(&newOp.getElseRegion()); @@ -2372,8 +2374,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern { if (nonHoistable.size() == op->getNumResults()) return failure(); - IfOp replacement = rewriter.create(op.getLoc(), nonHoistable, cond, - /*withElseRegion=*/false); + IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond, + /*withElseRegion=*/false); if (replacement.thenBlock()) rewriter.eraseBlock(replacement.thenBlock()); replacement.getThenRegion().takeBody(op.getThenRegion()); @@ -2398,8 +2400,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern { } else if (trueVal == falseVal) results[it.index()] = trueVal; else - results[it.index()] = rewriter.create( - op.getLoc(), cond, trueVal, falseVal); + results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(), + cond, trueVal, falseVal); } rewriter.setInsertionPointToEnd(replacement.thenBlock()); @@ -2450,8 +2452,8 @@ struct ConditionPropagation : public OpRewritePattern { changed = true; if (!constantTrue) - constantTrue = rewriter.create( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, 
rewriter.getIntegerAttr(i1Ty, 1)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantTrue); }); @@ -2460,8 +2462,8 @@ struct ConditionPropagation : public OpRewritePattern { changed = true; if (!constantFalse) - constantFalse = rewriter.create( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); + constantFalse = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantFalse); }); @@ -2547,8 +2549,8 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern { if (!trueVal && falseVal) { if (!opResult.use_empty()) { Dialect *constDialect = trueResult.getDefiningOp()->getDialect(); - Value notCond = rewriter.create( - op.getLoc(), op.getCondition(), + Value notCond = arith::XOrIOp::create( + rewriter, op.getLoc(), op.getCondition(), constDialect ->materializeConstant(rewriter, rewriter.getIntegerAttr(i1Ty, 1), i1Ty, @@ -2661,8 +2663,8 @@ struct CombineIfs : public OpRewritePattern { SmallVector mergedTypes(prevIf.getResultTypes()); llvm::append_range(mergedTypes, nextIf.getResultTypes()); - IfOp combinedIf = rewriter.create( - nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false); + IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes, + prevIf.getCondition(), /*hasElse=*/false); rewriter.eraseBlock(&combinedIf.getThenRegion().back()); rewriter.inlineRegionBefore(prevIf.getThenRegion(), @@ -2677,7 +2679,7 @@ struct CombineIfs : public OpRewritePattern { SmallVector mergedYields(thenYield.getOperands()); llvm::append_range(mergedYields, thenYield2.getOperands()); - rewriter.create(thenYield2.getLoc(), mergedYields); + YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields); rewriter.eraseOp(thenYield); rewriter.eraseOp(thenYield2); } @@ -2701,7 +2703,7 @@ struct CombineIfs : public OpRewritePattern { SmallVector mergedElseYields(elseYield.getOperands()); llvm::append_range(mergedElseYields, 
elseYield2.getOperands()); - rewriter.create(elseYield2.getLoc(), mergedElseYields); + YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields); rewriter.eraseOp(elseYield); rewriter.eraseOp(elseYield2); } @@ -2823,9 +2825,9 @@ struct CombineNestedIfs : public OpRewritePattern { } Location loc = op.getLoc(); - Value newCondition = rewriter.create( - loc, op.getCondition(), nestedIf.getCondition()); - auto newIf = rewriter.create(loc, op.getResultTypes(), newCondition); + Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(), + nestedIf.getCondition()); + auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition); Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion()); SmallVector results; @@ -2833,8 +2835,9 @@ struct CombineNestedIfs : public OpRewritePattern { rewriter.setInsertionPoint(newIf); for (auto idx : elseYieldsToUpgradeToSelect) - results[idx] = rewriter.create( - op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]); + results[idx] = + arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(), + thenYield[idx], elseYield[idx]); rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock); rewriter.setInsertionPointToEnd(newIf.thenBlock()); @@ -2842,7 +2845,7 @@ struct CombineNestedIfs : public OpRewritePattern { if (!elseYield.empty()) { rewriter.createBlock(&newIf.getElseRegion()); rewriter.setInsertionPointToEnd(newIf.elseBlock()); - rewriter.create(loc, elseYield); + YieldOp::create(rewriter, loc, elseYield); } rewriter.replaceOp(op, results); return success(); @@ -3160,8 +3163,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder } // Replace the parallel loop by lower-dimensional parallel loop. auto newOp = - rewriter.create(op.getLoc(), newLowerBounds, newUpperBounds, - newSteps, op.getInitVals(), nullptr); + ParallelOp::create(rewriter, op.getLoc(), newLowerBounds, + newUpperBounds, newSteps, op.getInitVals(), nullptr); // Erase the empty block that was inserted by the builder. 
rewriter.eraseBlock(newOp.getBody()); // Clone the loop body and remap the block arguments of the collapsed loops @@ -3541,8 +3544,8 @@ struct WhileConditionTruth : public OpRewritePattern { if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) { if (!std::get<1>(yieldedAndBlockArgs).use_empty()) { if (!constantTrue) - constantTrue = rewriter.create( - op.getLoc(), term.getCondition().getType(), + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), term.getCondition().getType(), rewriter.getBoolAttr(true)); rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs), @@ -3684,8 +3687,8 @@ struct RemoveLoopInvariantArgsFromBeforeBlock rewriter.replaceOpWithNewOp(yieldOp, newYieldOpArgs); } - auto newWhile = - rewriter.create(op.getLoc(), op.getResultTypes(), newInitArgs); + auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(), + newInitArgs); Block &newBeforeBlock = *rewriter.createBlock( &newWhile.getBefore(), /*insertPt*/ {}, @@ -3807,8 +3810,8 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern { newCondOpArgs); } - auto newWhile = rewriter.create(op.getLoc(), newAfterBlockType, - op.getOperands()); + auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType, + op.getOperands()); Block &newAfterBlock = *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, @@ -3914,7 +3917,7 @@ struct WhileUnusedResult : public OpRewritePattern { } auto newWhile = - rewriter.create(op.getLoc(), newResultTypes, op.getInits()); + WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits()); Block &newAfterBlock = *rewriter.createBlock( &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); @@ -4043,8 +4046,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern { Location loc = op.getLoc(); auto newWhileOp = - rewriter.create(loc, op.getResultTypes(), newInits, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); + WhileOp::create(rewriter, loc, op.getResultTypes(), newInits, + 
/*beforeBody*/ nullptr, /*afterBody*/ nullptr); Block &newBeforeBlock = *newWhileOp.getBeforeBody(); Block &newAfterBlock = *newWhileOp.getAfterBody(); @@ -4091,9 +4094,10 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern { ValueRange argsRange(newArgs); Location loc = op.getLoc(); - auto newWhileOp = rewriter.create( - loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr, - /*afterBody*/ nullptr); + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(), + /*beforeBody*/ nullptr, + /*afterBody*/ nullptr); Block &newBeforeBlock = *newWhileOp.getBeforeBody(); Block &newAfterBlock = *newWhileOp.getAfterBody(); @@ -4187,8 +4191,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern { for (auto &&[i, j] : llvm::enumerate(*mapping)) newResultTypes[j] = loop.getResult(i).getType(); - auto newLoop = rewriter.create( - loop.getLoc(), newResultTypes, loop.getInits(), + auto newLoop = WhileOp::create( + rewriter, loop.getLoc(), newResultTypes, loop.getInits(), /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); auto newBefore = newLoop.getBeforeBody(); auto newAfter = newLoop.getAfterBody(); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 57c27231f2144..145afac6df15e 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" using namespace mlir; @@ -163,7 +164,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); scf::ExecuteRegionOp executeRegionOp = - b.create(op->getLoc(), op->getResultTypes()); + scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes()); { 
OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); @@ -172,7 +173,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, assert(clonedRegion.empty() && "expected empty region"); b.inlineRegionBefore(op->getRegions().front(), clonedRegion, clonedRegion.end()); - b.create(op->getLoc(), clonedOp->getResults()); + scf::YieldOp::create(b, op->getLoc(), clonedOp->getResults()); } b.replaceOp(op, executeRegionOp.getResults()); return executeRegionOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index d36d91249ed36..0214ce3090656 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -42,7 +42,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) { // iter_arg's layout map must be changed (see uses of `castBuffer`). assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && "scf.while op bufferization: cast incompatible"); - return b.create(buffer.getLoc(), type, buffer).getResult(); + return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult(); } /// Helper function for loop bufferization. Return "true" if the given value @@ -191,7 +191,7 @@ struct ExecuteRegionOpInterface // Create new op and move over region. auto newOp = - rewriter.create(op->getLoc(), newResultTypes); + scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes); newOp.getRegion().takeBody(executeRegionOp.getRegion()); // Bufferize every block. 
@@ -205,7 +205,7 @@ struct ExecuteRegionOpInterface SmallVector newResults; for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { if (isa(it.value())) { - newResults.push_back(rewriter.create( + newResults.push_back(bufferization::ToTensorOp::create(rewriter, executeRegionOp.getLoc(), it.value(), newOp->getResult(it.index()))); } else { @@ -261,7 +261,7 @@ struct IfOpInterface // Create new op. rewriter.setInsertionPoint(ifOp); auto newIfOp = - rewriter.create(ifOp.getLoc(), newTypes, ifOp.getCondition(), + scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes, ifOp.getCondition(), /*withElseRegion=*/true); // Move over then/else blocks. @@ -374,7 +374,7 @@ struct IndexSwitchOpInterface // Create new op. rewriter.setInsertionPoint(switchOp); - auto newSwitchOp = rewriter.create( + auto newSwitchOp = scf::IndexSwitchOp::create(rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(), switchOp.getCases().size()); @@ -769,7 +769,7 @@ struct ForOpInterface } // Construct a new scf.for op with memref instead of tensor values. - auto newForOp = rewriter.create( + auto newForOp = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), castedInitArgs); newForOp->setAttrs(forOp->getAttrs()); @@ -1005,7 +1005,7 @@ struct WhileOpInterface // Construct a new scf.while op with memref instead of tensor values. ValueRange argsRangeBefore(castedInitArgs); TypeRange argsTypesBefore(argsRangeBefore); - auto newWhileOp = rewriter.create( + auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(), argsTypesAfter, castedInitArgs); // Add before/after regions to the new op. 
@@ -1265,7 +1265,7 @@ struct ForallOpInterface forallOp.getBody()->getArguments().drop_front(rank), buffers)) { BlockArgument bbArg = std::get<0>(it); Value buffer = std::get<1>(it); - Value bufferAsTensor = rewriter.create( + Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(), bbArg.getType(), buffer); bbArg.replaceAllUsesWith(bufferAsTensor); } @@ -1274,7 +1274,7 @@ struct ForallOpInterface // introduced terminator. rewriter.setInsertionPoint(forallOp); ForallOp newForallOp; - newForallOp = rewriter.create( + newForallOp = ForallOp::create(rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), /*outputs=*/ValueRange(), forallOp.getMapping()); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index 3e93dc80b18ec..bf4eba3319a4d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -50,7 +50,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern { SmallVector initArgs; initArgs.push_back(forOp.getLowerBound()); llvm::append_range(initArgs, forOp.getInitArgs()); - auto whileOp = rewriter.create(forOp.getLoc(), lcvTypes, initArgs, + auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs, forOp->getAttrs()); // 'before' region contains the loop condition and forwarding of iteration @@ -58,10 +58,10 @@ struct ForLoopLoweringPattern : public OpRewritePattern { auto *beforeBlock = rewriter.createBlock( &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); - auto cmpOp = rewriter.create( + auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt, beforeBlock->getArgument(0), forOp.getUpperBound()); - rewriter.create(whileOp.getLoc(), cmpOp.getResult(), + scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(), 
beforeBlock->getArguments()); // Inline for-loop body into an executeRegion operation in the "after" @@ -72,7 +72,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern { // Add induction variable incrementation rewriter.setInsertionPointToEnd(afterBlock); - auto ivIncOp = rewriter.create( + auto ivIncOp = arith::AddIOp::create(rewriter, whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep()); // Rewrite uses of the for-loop block arguments to the new while-loop diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp index 44e6840b03a3d..b95604fa44cb9 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp @@ -40,7 +40,7 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter, SmallVector steps = forallOp.getStep(rewriter); // Create empty scf.parallel op. - auto parallelOp = rewriter.create(loc, lbs, ubs, steps); + auto parallelOp = scf::ParallelOp::create(rewriter, loc, lbs, ubs, steps); rewriter.eraseBlock(&parallelOp.getRegion().front()); rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), parallelOp.getRegion().begin()); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index bcecef5e6e0a9..f1ede64d01e0f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -279,24 +279,24 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { if (dynamicLoop) { Type t = ub.getType(); // pred = ub > lb + (i * step) - Value iv = rewriter.create( + Value iv = arith::AddIOp::create(rewriter, loc, lb, - rewriter.create( + arith::MulIOp::create(rewriter, loc, step, - rewriter.create( + arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(t, i)))); - predicates[i] = rewriter.create( + predicates[i] = arith::CmpIOp::create(rewriter,
loc, arith::CmpIPredicate::slt, iv, ub); } // special handling for induction variable as the increment is implicit. // iv = lb + i * step Type t = lb.getType(); - Value iv = rewriter.create( + Value iv = arith::AddIOp::create(rewriter, loc, lb, - rewriter.create( + arith::MulIOp::create(rewriter, loc, step, - rewriter.create(loc, + arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(t, i)))); setValueMapping(forOp.getInductionVar(), iv, i); for (Operation *op : opOrder) { @@ -332,7 +332,7 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { Value prevValue = valueMapping [forOp.getRegionIterArgs()[operand.getOperandNumber()]] [i - stages[op]]; - source = rewriter.create( + source = arith::SelectOp::create(rewriter, loc, predicates[predicateIdx], source, prevValue); } setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], @@ -444,14 +444,14 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( Type t = ub.getType(); Location loc = forOp.getLoc(); // newUb = ub - maxStage * step - Value maxStageValue = rewriter.create( + Value maxStageValue = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(t, maxStage)); Value maxStageByStep = - rewriter.create(loc, step, maxStageValue); - newUb = rewriter.create(loc, ub, maxStageByStep); + arith::MulIOp::create(rewriter, loc, step, maxStageValue); + newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep); } auto newForOp = - rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb, forOp.getStep(), newLoopArg); // When there are no iter args, the loop body terminator will be created. // Since we always create it below, remove the terminator if it was created. 
@@ -483,14 +483,14 @@ LogicalResult LoopPipelinerInternal::createKernel( Type t = ub.getType(); for (unsigned i = 0; i < maxStage; i++) { // c = ub - (maxStage - i) * step - Value c = rewriter.create( + Value c = arith::SubIOp::create(rewriter, loc, ub, - rewriter.create( + arith::MulIOp::create(rewriter, loc, step, - rewriter.create( + arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); - Value pred = rewriter.create( + Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(), arith::CmpIPredicate::slt, newForOp.getInductionVar(), c); predicates[i] = pred; @@ -515,12 +515,12 @@ LogicalResult LoopPipelinerInternal::createKernel( // offset = (maxStage - stages[op]) * step Type t = step.getType(); - Value offset = rewriter.create( + Value offset = arith::MulIOp::create(rewriter, forOp.getLoc(), step, - rewriter.create( + arith::ConstantOp::create(rewriter, forOp.getLoc(), rewriter.getIntegerAttr(t, maxStage - stages[op]))); - Value iv = rewriter.create( + Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(), newForOp.getInductionVar(), offset); nestedNewOp->setOperand(operand->getOperandNumber(), iv); rewriter.setInsertionPointAfter(newOp); @@ -594,7 +594,7 @@ LogicalResult LoopPipelinerInternal::createKernel( auto defStage = stages.find(def); if (defStage != stages.end() && defStage->second < maxStage) { Value pred = predicates[defStage->second]; - source = rewriter.create( + source = arith::SelectOp::create(rewriter, pred.getLoc(), pred, source, newForOp.getBody() ->getArguments()[yieldOperand.getOperandNumber() + 1]); @@ -638,7 +638,7 @@ LogicalResult LoopPipelinerInternal::createKernel( maxStage - defStage->second + 1); } } - rewriter.create(forOp.getLoc(), yieldOperands); + scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands); return success(); } @@ -652,7 +652,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // removed by dead code if not used. 
auto createConst = [&](int v) { - return rewriter.create(loc, + return arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(t, v)); }; @@ -661,41 +661,41 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step Value zero = createConst(0); Value one = createConst(1); - Value stepLessZero = rewriter.create( + Value stepLessZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, step, zero); Value stepDecr = - rewriter.create(loc, stepLessZero, one, createConst(-1)); + arith::SelectOp::create(rewriter, loc, stepLessZero, one, createConst(-1)); - Value rangeDiff = rewriter.create(loc, ub, lb); - Value rangeIncrStep = rewriter.create(loc, rangeDiff, step); + Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb); + Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step); Value rangeDecr = - rewriter.create(loc, rangeIncrStep, stepDecr); - Value totalIterations = rewriter.create(loc, rangeDecr, step); + arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr); + Value totalIterations = arith::DivSIOp::create(rewriter, loc, rangeDecr, step); // If total_iters < max_stage, start the epilogue at zero to match the // ramp-up in the prologue. // start_iter = max(0, total_iters - max_stage) - Value iterI = rewriter.create(loc, totalIterations, + Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations, createConst(maxStage)); - iterI = rewriter.create(loc, zero, iterI); + iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI); // Capture predicates for dynamic loops. 
SmallVector predicates(maxStage + 1); for (int64_t i = 1; i <= maxStage; i++) { // newLastIter = lb + step * iterI - Value newlastIter = rewriter.create( - loc, lb, rewriter.create(loc, step, iterI)); + Value newlastIter = arith::AddIOp::create(rewriter, + loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI)); setValueMapping(forOp.getInductionVar(), newlastIter, i); // increment to next iterI - iterI = rewriter.create(loc, iterI, one); + iterI = arith::AddIOp::create(rewriter, loc, iterI, one); if (dynamicLoop) { // Disable stages when `i` is greater than total_iters. // pred = total_iters >= i - predicates[i] = rewriter.create( + predicates[i] = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); } } @@ -758,7 +758,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, unsigned nextVersion = currentVersion + 1; Value pred = predicates[currentVersion]; Value prevValue = valueMapping[mapVal][currentVersion]; - auto selOp = rewriter.create(loc, pred, pair.value(), + auto selOp = arith::SelectOp::create(rewriter, loc, pred, pair.value(), prevValue); returnValues[ri] = selOp; if (nextVersion <= maxStage) diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index b71ec985fa6a1..b24e276272c2c 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -64,13 +64,13 @@ static void specializeParallelLoopForUnrolling(ParallelOp op) { Value cond; for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) { Value constant = - b.create(op.getLoc(), std::get<1>(bound)); - Value cmp = b.create(op.getLoc(), arith::CmpIPredicate::eq, + arith::ConstantIndexOp::create(b, op.getLoc(), std::get<1>(bound)); + Value cmp = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq, std::get<0>(bound), constant); - cond = cond ? 
b.create(op.getLoc(), cond, cmp) : cmp; + cond = cond ? arith::AndIOp::create(b, op.getLoc(), cond, cmp) : cmp; map.map(std::get<0>(bound), constant); } - auto ifOp = b.create(op.getLoc(), cond, /*withElseRegion=*/true); + auto ifOp = scf::IfOp::create(b, op.getLoc(), cond, /*withElseRegion=*/true); ifOp.getThenBodyBuilder().clone(*op.getOperation(), map); ifOp.getElseBodyBuilder().clone(*op.getOperation()); op.erase(); @@ -95,11 +95,11 @@ static void specializeForLoopForUnrolling(ForOp op) { OpBuilder b(op); IRMapping map; - Value constant = b.create(op.getLoc(), minConstant); - Value cond = b.create(op.getLoc(), arith::CmpIPredicate::eq, + Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant); + Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq, bound, constant); map.map(bound, constant); - auto ifOp = b.create(op.getLoc(), cond, /*withElseRegion=*/true); + auto ifOp = scf::IfOp::create(b, op.getLoc(), cond, /*withElseRegion=*/true); ifOp.getThenBodyBuilder().clone(*op.getOperation(), map); ifOp.getElseBodyBuilder().clone(*op.getOperation()); op.erase(); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index ad1267381c4f2..013c3b69620e2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -190,7 +190,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, IRRewriter b(builder); b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create( + auto newSecondPloop = ParallelOp::create(b, secondPloop.getLoc(), secondPloop.getLowerBound(), secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); @@ -212,7 +212,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - auto newReduceOp = 
b.create(term2.getLoc(), newReduceArgs); + auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs); for (auto &&[i, reg] : llvm::enumerate(llvm::concat( term1.getReductions(), term2.getReductions()))) { diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index 66f7bc27f82ff..2ca3b1e366bdf 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -58,27 +58,27 @@ std::pair mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, bool noMinMaxBounds) { OpBuilder b(op); - auto zero = b.create(op.getLoc(), 0); + auto zero = arith::ConstantIndexOp::create(b, op.getLoc(), 0); SmallVector tileSizeConstants; tileSizeConstants.reserve(op.getUpperBound().size()); for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) { if (i < tileSizes.size()) tileSizeConstants.push_back( - b.create(op.getLoc(), tileSizes[i])); + arith::ConstantIndexOp::create(b, op.getLoc(), tileSizes[i])); else // Just pick 1 for the remaining dimensions. tileSizeConstants.push_back( - b.create(op.getLoc(), 1)); + arith::ConstantIndexOp::create(b, op.getLoc(), 1)); } // Create the outer loop with adjusted steps. SmallVector newSteps; newSteps.reserve(op.getStep().size()); for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) { - newSteps.push_back(b.create(op.getLoc(), std::get<0>(step), + newSteps.push_back(arith::MulIOp::create(b, op.getLoc(), std::get<0>(step), std::get<1>(step))); } - auto outerLoop = b.create(op.getLoc(), op.getLowerBound(), + auto outerLoop = ParallelOp::create(b, op.getLoc(), op.getLowerBound(), op.getUpperBound(), newSteps); b.setInsertionPointToStart(outerLoop.getBody()); @@ -130,10 +130,10 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, // Otherwise, we dynamically compute the bound for // each iteration of the outer loop. 
newBounds.push_back( - b.create(op.getLoc(), b.getIndexType(), minMap, + affine::AffineMinOp::create(b, op.getLoc(), b.getIndexType(), minMap, ValueRange{newStep, upperBound, iv})); } - auto innerLoop = b.create( + auto innerLoop = ParallelOp::create(b, op.getLoc(), SmallVector(newBounds.size(), zero), newBounds, op.getStep()); @@ -141,20 +141,20 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, b.setInsertionPointToStart(innerLoop.getBody()); // Insert in-bound check Value inbound = - b.create(op.getLoc(), b.getIntegerType(1), 1); + arith::ConstantIntOp::create(b, op.getLoc(), b.getIntegerType(1), 1); for (auto [outerUpperBound, outerIV, innerIV, innerStep] : llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(), innerLoop.getInductionVars(), innerLoop.getStep())) { // %in_bound = %in_bound && // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) - Value index = b.create( - op.getLoc(), b.create(op.getLoc(), innerIV, innerStep), + Value index = arith::AddIOp::create(b, + op.getLoc(), arith::MulIOp::create(b, op.getLoc(), innerIV, innerStep), outerIV); - Value dimInbound = b.create( + Value dimInbound = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); - inbound = b.create(op.getLoc(), inbound, dimInbound); + inbound = arith::AndIOp::create(b, op.getLoc(), inbound, dimInbound); } - auto ifInbound = b.create(op.getLoc(), + auto ifInbound = IfOp::create(b, op.getLoc(), /*resultTypes*/ ArrayRef{}, inbound, /*hasElseRegion*/ false); ifInbound.getThenRegion().takeBody(op.getRegion()); @@ -162,12 +162,12 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, // Replace the scf.reduce terminator with an scf.yield terminator. 
Operation *reduceOp = thenBlock.getTerminator(); b.setInsertionPointToEnd(&thenBlock); - b.create(reduceOp->getLoc()); + scf::YieldOp::create(b, reduceOp->getLoc()); reduceOp->erase(); b.setInsertionPointToStart(innerLoop.getBody()); for (const auto &ivs : llvm::enumerate(llvm::zip( innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { - auto newIndex = b.create( + auto newIndex = arith::AddIOp::create(b, op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); thenBlock.getArgument(ivs.index()) .replaceAllUsesExcept(newIndex, newIndex); @@ -179,7 +179,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, for (auto ivs : llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) { Value innerIndex = std::get<0>(ivs); - auto newIndex = b.create(op.getLoc(), std::get<0>(ivs), + auto newIndex = arith::AddIOp::create(b, op.getLoc(), std::get<0>(ivs), std::get<1>(ivs)); innerIndex.replaceAllUsesExcept(newIndex, newIndex); } diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 09326242eec2a..41895e46d953d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -112,7 +112,7 @@ class ConvertForOpTypes // We can not do clone as the number of result types after conversion // might be different. 
- ForOp newOp = rewriter.create( + ForOp newOp = ForOp::create(rewriter, op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()), llvm::getSingleElement(adaptor.getUpperBound()), llvm::getSingleElement(adaptor.getStep()), @@ -142,7 +142,7 @@ class ConvertIfOpTypes ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - IfOp newOp = rewriter.create( + IfOp newOp = IfOp::create(rewriter, op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()), true); newOp->setAttrs(op->getAttrs()); @@ -171,7 +171,7 @@ class ConvertWhileOpTypes std::optional convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - auto newOp = rewriter.create(op.getLoc(), dstTypes, + auto newOp = WhileOp::create(rewriter, op.getLoc(), dstTypes, flattenValues(adaptor.getOperands())); for (auto i : {0u, 1u}) { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 995120ad8680e..f995f324a0dbd 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -450,7 +450,7 @@ static LogicalResult generateLoopNestUsingForOp( SmallVector ivs; for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { auto loop = - rewriter.create(loc, lb, ub, step, destinationTensors, + scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors, [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, ValueRange /*iterArgs*/) {}); loops.push_back(loop); @@ -479,12 +479,12 @@ static LogicalResult generateLoopNestUsingForOp( resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - auto insertSlice = rewriter.create( + auto insertSlice = tensor::InsertSliceOp::create(rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize, resultStride); yieldedValues.push_back(insertSlice); } - rewriter.create(loc, yieldedValues); + 
scf::YieldOp::create(rewriter, loc, yieldedValues); // Add the scf.yield operations for all the outer loops. for (auto [outerLoop, innerLoop] : @@ -492,7 +492,7 @@ static LogicalResult generateLoopNestUsingForOp( MutableArrayRef(loops).drop_front())) { rewriter.setInsertionPointToEnd( cast(outerLoop.getOperation()).getBody()); - rewriter.create(outerLoop.getLoc(), innerLoop->getResults()); + scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults()); } return success(); } @@ -533,13 +533,13 @@ static LogicalResult generateLoopNestUsingForallOp( continue; nonZeroNumThreads.push_back(nt); } - forallOp = rewriter.create(loc, nonZeroNumThreads, + forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads, destinationTensors, mappingAttr); } else { SmallVector lbs, ubs, steps; std::tie(lbs, ubs, steps) = getLoopBounds(rewriter, loc, loopRanges, tileSizes); - forallOp = rewriter.create(loc, lbs, ubs, steps, + forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps, destinationTensors, mappingAttr); } loops.push_back(forallOp); @@ -561,7 +561,7 @@ static LogicalResult generateLoopNestUsingForallOp( SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - rewriter.create( + tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize, resultStride); } @@ -798,7 +798,7 @@ FailureOr yieldTiledValuesAndReplaceLoop( auto inits = llvm::to_vector(loopOp.getInitArgs()); inits.append(newInitOperands.begin(), newInitOperands.end()); - auto newLoop = rewriter.create( + auto newLoop = scf::ForOp::create(rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); @@ -829,7 +829,7 @@ FailureOr yieldTiledValuesAndReplaceLoop( resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - Value insert = rewriter.create( + Value insert = tensor::InsertSliceOp::create(rewriter, 
yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, resultStride); newYieldValues.push_back(insert); @@ -851,7 +851,7 @@ FailureOr yieldTiledValuesAndReplaceLoop( rewriter.setInsertionPoint(loopOp); auto inits = llvm::to_vector(loopOp.getOutputs()); inits.append(newInitOperands.begin(), newInitOperands.end()); - auto newLoop = rewriter.create( + auto newLoop = scf::ForallOp::create(rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), loopOp.getMixedStep(), inits, loopOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); @@ -884,7 +884,7 @@ FailureOr yieldTiledValuesAndReplaceLoop( tiledValues, regionIterArgs, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - rewriter.create( + tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, resultStride); } @@ -935,7 +935,7 @@ static LogicalResult addInitOperandsToLoopNest( // Create a new loop with the new init values for this loop. 
SmallVector newInits = llvm::to_vector(forLoop.getInitArgs()); newInits.append(newInitValues.begin(), newInitValues.end()); - auto newLoop = rewriter.create( + auto newLoop = scf::ForOp::create(rewriter, forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), forLoop.getStep(), newInits, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); @@ -1419,7 +1419,7 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( rewriter.setInsertionPoint(tiledDestStyleOp); for (const auto &&[index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) { - auto destSlice = rewriter.create( + auto destSlice = tensor::ExtractSliceOp::create(rewriter, loc, newRegionArg, offsetList[index], sizesList[index], SmallVector(offsetList[index].size(), rewriter.getIndexAttr(1))); @@ -2092,7 +2092,7 @@ cloneAsInsertSlice(RewriterBase &rewriter, template <> tensor::InsertSliceOp cloneAsInsertSlice( RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) { - return rewriter.create( + return tensor::InsertSliceOp::create(rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(), insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); @@ -2314,7 +2314,7 @@ mlir::scf::tileAndFuseConsumerOfSlices( rewriter.setInsertionPoint(tiledDestStyleOp); for (const auto &&[index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) { - auto destSlice = rewriter.create( + auto destSlice = tensor::ExtractSliceOp::create(rewriter, loc, newRegionArg, resultOffsets[index], resultSizes[index], SmallVector(resultOffsets[index].size(), rewriter.getIndexAttr(1))); @@ -2391,7 +2391,7 @@ mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); Value strideVal = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); - auto loop = rewriter.create(op.getLoc(), offsetVal, sizeVal, + auto loop = scf::ForOp::create(rewriter, op.getLoc(), 
offsetVal, sizeVal, strideVal, ValueRange{}); loops.push_back(loop); ivs.push_back(loop.getInductionVar()); diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index 29d6d2574a2be..14ded96516eab 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -191,7 +191,7 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, // dummy builder instead. auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; auto newLoop = - rewriter.create(loc, lb, ub, step, newArgs, emptyBuilder); + scf::ForOp::create(rewriter, loc, lb, ub, step, newArgs, emptyBuilder); Block *newBody = newLoop.getBody(); @@ -238,18 +238,18 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, rewriter.setInsertionPointAfter(newLoop); Value one; if (isa(step.getType())) { - one = rewriter.create(loc, 1); + one = arith::ConstantIndexOp::create(rewriter, loc, 1); } else { - one = rewriter.create(loc, step.getType(), 1); + one = arith::ConstantIntOp::create(rewriter, loc, step.getType(), 1); } - Value stepDec = rewriter.create(loc, step, one); - Value len = rewriter.create(loc, ub, lb); - len = rewriter.create(loc, len, stepDec); - len = rewriter.create(loc, len, step); - len = rewriter.create(loc, len, one); - Value res = rewriter.create(loc, len, step); - res = rewriter.create(loc, lb, res); + Value stepDec = arith::SubIOp::create(rewriter, loc, step, one); + Value len = arith::SubIOp::create(rewriter, loc, ub, lb); + len = arith::AddIOp::create(rewriter, loc, len, stepDec); + len = arith::DivSIOp::create(rewriter, loc, len, step); + len = arith::SubIOp::create(rewriter, loc, len, one); + Value res = arith::MulIOp::create(rewriter, loc, len, step); + res = arith::AddIOp::create(rewriter, loc, lb, res); // Reconstruct `scf.while` results, inserting final induction var value // into proper place. 
diff --git a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp index f829208ce8798..707ed836cf63e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp @@ -96,7 +96,7 @@ FailureOr mlir::scf::wrapWhileLoopInZeroTripCheck( condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); }); // Create rotated while loop. - auto newLoopOp = rewriter.create( + auto newLoopOp = scf::WhileOp::create(rewriter, whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs, [&](OpBuilder &builder, Location loc, ValueRange args) { // Rotate and move the loop body into before block. @@ -109,21 +109,21 @@ FailureOr mlir::scf::wrapWhileLoopInZeroTripCheck( }, [&](OpBuilder &builder, Location loc, ValueRange args) { // Pass through values. - builder.create(loc, args); + scf::YieldOp::create(builder, loc, args); }); // Create zero-trip-check and move the while loop in. - auto ifOp = rewriter.create( + auto ifOp = scf::IfOp::create(rewriter, whileOp.getLoc(), clonedCondition, [&](OpBuilder &builder, Location loc) { // Then runs the while loop. rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(), builder.getInsertionPoint()); - builder.create(loc, newLoopOp.getResults()); + scf::YieldOp::create(builder, loc, newLoopOp.getResults()); }, [&](OpBuilder &builder, Location loc) { // Else returns the results from precondition. 
- builder.create(loc, clonedCondArgs); + scf::YieldOp::create(builder, loc, clonedCondArgs); }); rewriter.replaceOp(whileOp, ifOp); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index f4047be68ccf2..6a387ddf8908a 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -153,7 +153,7 @@ FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes, originalTerminator->getOperandTypes()); auto outlinedFunc = - rewriter.create(loc, funcName, outlinedFuncType); + func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType); Block *outlinedFuncBody = outlinedFunc.addEntryBlock(); // Merge blocks while replacing the original block operands. @@ -168,7 +168,7 @@ FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, outlinedFuncBlockArgs.take_front(numOriginalBlockArguments)); // Explicitly set up a new ReturnOp terminator. rewriter.setInsertionPointToEnd(outlinedFuncBody); - rewriter.create(loc, originalTerminator->getResultTypes(), + func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(), originalTerminator->getOperands()); } @@ -185,7 +185,7 @@ FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, SmallVector callValues; llvm::append_range(callValues, newBlock->getArguments()); llvm::append_range(callValues, outlinedValues); - auto call = rewriter.create(loc, outlinedFunc, callValues); + auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues); if (callOp) *callOp = call; @@ -274,12 +274,12 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, assert(dividend.getType().isIntOrIndex() && "expected integer or index-typed value"); - Value divisorMinusOneCst = builder.create( + Value divisorMinusOneCst = arith::ConstantOp::create(builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1)); - Value divisorCst = builder.create( + 
Value divisorCst = arith::ConstantOp::create(builder, loc, builder.getIntegerAttr(dividend.getType(), divisor)); - Value sum = builder.create(loc, dividend, divisorMinusOneCst); - return builder.create(loc, sum, divisorCst); + Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst); + return arith::DivUIOp::create(builder, loc, sum, divisorCst); } // Build the IR that performs ceil division of a positive value by another @@ -290,11 +290,11 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, Value divisor) { assert(dividend.getType().isIntOrIndex() && "expected integer or index-typed value"); - Value cstOne = builder.create( + Value cstOne = arith::ConstantOp::create(builder, loc, builder.getOneAttr(dividend.getType())); - Value divisorMinusOne = builder.create(loc, divisor, cstOne); - Value sum = builder.create(loc, dividend, divisorMinusOne); - return builder.create(loc, sum, divisor); + Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne); + Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne); + return arith::DivUIOp::create(builder, loc, sum, divisor); } /// Returns the trip count of `forOp` if its' low bound, high bound and step are @@ -404,7 +404,7 @@ FailureOr mlir::loopUnrollByFactor( // Create constant for 'upperBoundUnrolled' and set epilogue loop flag. generateEpilogueLoop = upperBoundUnrolledCst < ubCst; if (generateEpilogueLoop) - upperBoundUnrolled = boundsBuilder.create( + upperBoundUnrolled = arith::ConstantOp::create(boundsBuilder, loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(), upperBoundUnrolledCst)); else @@ -413,7 +413,7 @@ FailureOr mlir::loopUnrollByFactor( // Create constant for 'stepUnrolled'. stepUnrolled = stepCst == stepUnrolledCst ? 
step - : boundsBuilder.create( + : arith::ConstantOp::create(boundsBuilder, loc, boundsBuilder.getIntegerAttr( step.getType(), stepUnrolledCst)); } else { @@ -423,22 +423,22 @@ FailureOr mlir::loopUnrollByFactor( auto lowerBound = forOp.getLowerBound(); auto upperBound = forOp.getUpperBound(); Value diff = - boundsBuilder.create(loc, upperBound, lowerBound); + arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound); Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step); - Value unrollFactorCst = boundsBuilder.create( + Value unrollFactorCst = arith::ConstantOp::create(boundsBuilder, loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor)); Value tripCountRem = - boundsBuilder.create(loc, tripCount, unrollFactorCst); + arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst); // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor) Value tripCountEvenMultiple = - boundsBuilder.create(loc, tripCount, tripCountRem); + arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem); // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step - upperBoundUnrolled = boundsBuilder.create( + upperBoundUnrolled = arith::AddIOp::create(boundsBuilder, loc, lowerBound, - boundsBuilder.create(loc, tripCountEvenMultiple, step)); + arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step)); // Scale 'step' by 'unrollFactor'. 
stepUnrolled = - boundsBuilder.create(loc, step, unrollFactorCst); + arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst); } UnrolledLoopInfo resultLoops; @@ -474,11 +474,11 @@ FailureOr mlir::loopUnrollByFactor( forOp.getBody(), forOp.getInductionVar(), unrollFactor, [&](unsigned i, Value iv, OpBuilder b) { // iv' = iv + step * i; - auto stride = b.create( + auto stride = arith::MulIOp::create(b, loc, step, - b.create(loc, + arith::ConstantOp::create(b, loc, b.getIntegerAttr(iv.getType(), i))); - return b.create(loc, iv, stride); + return arith::AddIOp::create(b, loc, iv, stride); }, annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. @@ -781,13 +781,13 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc, if (!isStepOne) { Value origStepValue = getValueOrCreateConstantIntOp(rewriter, loc, origStep); - scaled = rewriter.create(loc, normalizedIv, origStepValue); + scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue); preserve.insert(scaled.getDefiningOp()); } denormalizedIv = scaled; if (!isZeroBased) { Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb); - denormalizedIv = rewriter.create(loc, scaled, origLbValue); + denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue); preserve.insert(denormalizedIv.getDefiningOp()); } @@ -824,7 +824,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, continue; if (productOf) productOf = - rewriter.create(loc, productOf.value(), v).getResult(); + arith::MulIOp::create(rewriter, loc, productOf.value(), v).getResult(); else productOf = v; } @@ -851,7 +851,7 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc, if (linearizedIv.getType().isIndex()) { Operation *delinearizedOp = - rewriter.create(loc, linearizedIv, + affine::AffineDelinearizeIndexOp::create(rewriter, loc, linearizedIv, ubs); auto resultVals = 
llvm::map_to_vector( delinearizedOp->getResults(), [](OpResult r) -> Value { return r; }); @@ -874,7 +874,7 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc, if (!isUbOne.test(index)) { break; } - delinearizedIvs[index] = rewriter.create( + delinearizedIvs[index] = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(ub.getType())); numLeadingOneUbs++; } @@ -883,16 +883,16 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc, for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) { unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1; if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) { - previous = rewriter.create(loc, previous, ubs[idx + 1]); + previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]); preservedUsers.insert(previous.getDefiningOp()); } Value iv = previous; if (i != e - 1) { if (!isUbOne.test(idx)) { - iv = rewriter.create(loc, previous, ubs[idx]); + iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]); preservedUsers.insert(iv.getDefiningOp()); } else { - iv = rewriter.create( + iv = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType())); } } @@ -1093,12 +1093,12 @@ void mlir::collapseParallelLoops( // Combine iteration spaces. SmallVector lowerBounds, upperBounds, steps; - auto cst0 = rewriter.create(loc, 0); - auto cst1 = rewriter.create(loc, 1); + auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1); for (auto &sortedDimension : sortedDimensions) { - Value newUpperBound = rewriter.create(loc, 1); + Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1); for (auto idx : sortedDimension) { - newUpperBound = rewriter.create( + newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound, normalizedUpperBounds[idx]); } lowerBounds.push_back(cst0); @@ -1112,7 +1112,7 @@ void mlir::collapseParallelLoops( // value. 
The remainders then determine based on that range, which iteration // of the original induction value this represents. This is a normalized value // that is un-normalized already by the previous logic. - auto newPloop = rewriter.create( + auto newPloop = scf::ParallelOp::create(rewriter, loc, lowerBounds, upperBounds, steps, [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) { for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) { @@ -1123,14 +1123,14 @@ void mlir::collapseParallelLoops( unsigned idx = combinedDimensions[i][j]; // Determine the current induction value's current loop iteration - Value iv = insideBuilder.create( + Value iv = arith::RemSIOp::create(insideBuilder, loc, previous, normalizedUpperBounds[idx]); replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv, loops.getRegion()); // Remove the effect of the current induction value to prepare for // the next value. - previous = insideBuilder.create( + previous = arith::DivSIOp::create(insideBuilder, loc, previous, normalizedUpperBounds[idx]); } @@ -1241,7 +1241,7 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, auto iv = forOp.getInductionVar(); OpBuilder b(forOp); - forOp.setStep(b.create(forOp.getLoc(), originalStep, factor)); + forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor)); Loops innerLoops; for (auto t : targets) { @@ -1251,12 +1251,12 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, // Insert newForOp before the terminator of `t`. auto b = OpBuilder::atBlockTerminator((t.getBody())); - Value stepped = b.create(t.getLoc(), iv, forOp.getStep()); + Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep()); Value ub = - b.create(t.getLoc(), forOp.getUpperBound(), stepped); + arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped); // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses. 
- auto newForOp = b.create(t.getLoc(), iv, ub, originalStep); + auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep); newForOp.getBody()->getOperations().splice( newForOp.getBody()->getOperations().begin(), t.getBody()->getOperations(), begin, std::next(begin, nOps - 1)); @@ -1343,7 +1343,7 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, auto forOp = forOps[i]; OpBuilder builder(forOp); auto loc = forOp.getLoc(); - Value diff = builder.create(loc, forOp.getUpperBound(), + Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(), forOp.getLowerBound()); Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep()); Value iterationsPerBlock = @@ -1376,7 +1376,7 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, // Create a new scf.forall op after the source loop. rewriter.setInsertionPointAfter(source); - scf::ForallOp fusedLoop = rewriter.create( + scf::ForallOp fusedLoop = scf::ForallOp::create(rewriter, source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), source.getMixedStep(), fusedOuts, source.getMapping()); @@ -1429,7 +1429,7 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, // Create a new scf.for op after the source loop (with scf.yield terminator // (without arguments) only in case its init_args is empty). 
rewriter.setInsertionPointAfter(source); - scf::ForOp fusedLoop = rewriter.create( + scf::ForOp fusedLoop = scf::ForOp::create(rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(), source.getStep(), fusedInitArgs); @@ -1456,7 +1456,7 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, for (Value operand : source.getBody()->getTerminator()->getOperands()) yieldResults.push_back(mapping.lookupOrDefault(operand)); if (!yieldResults.empty()) - rewriter.create(source.getLoc(), yieldResults); + scf::YieldOp::create(rewriter, source.getLoc(), yieldResults); // Replace old loops by substituting their uses by results of the fused loop. rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); @@ -1487,7 +1487,7 @@ FailureOr mlir::normalizeForallOp(RewriterBase &rewriter, // Use the normalized builder since the lower bounds are always 0 and the // steps are always 1. - auto normalizedForallOp = rewriter.create( + auto normalizedForallOp = scf::ForallOp::create(rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); diff --git a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp index 66eed861b2bb7..48c0b1ed8e14f 100644 --- a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp +++ b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp @@ -30,14 +30,14 @@ Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value, if (auto attrValue = dyn_cast(value)) { assert(bvType == attrValue.getType() && "attribute and desired result types have to match"); - return builder.create(loc, attrValue); + return BVConstantOp::create(builder, loc, attrValue); } } // BoolType constants can materialize into smt.constant if (auto boolType = dyn_cast(type)) { if (auto attrValue = dyn_cast(value)) - return builder.create(loc, attrValue); + return BoolConstantOp::create(builder, loc, attrValue); } return nullptr; diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp 
b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp index 8977a3abc125d..6af87afd8b83e 100644 --- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp +++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SMT/IR/SMTOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/APSInt.h" @@ -405,7 +406,7 @@ static void buildQuantifier( SmallVector(boundVarTypes.size(), odsState.location)); Value returnVal = bodyBuilder(odsBuilder, odsState.location, block->getArguments()); - odsBuilder.create(odsState.location, returnVal); + smt::YieldOp::create(odsBuilder, odsState.location, returnVal); } if (patternBuilder) { Region *region = odsState.addRegion(); @@ -416,7 +417,7 @@ static void buildQuantifier( SmallVector(boundVarTypes.size(), odsState.location)); ValueRange returnVals = patternBuilder(odsBuilder, odsState.location, block->getArguments()); - odsBuilder.create(odsState.location, returnVals); + smt::YieldOp::create(odsBuilder, odsState.location, returnVals); } } diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index 371456552b5b5..cacc52d0feaf0 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -391,7 +391,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. - builder.create(getLoc()); + spirv::MergeOp::create(builder, getLoc()); } //===----------------------------------------------------------------------===// @@ -543,7 +543,7 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) { builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. 
- builder.create(getLoc()); + spirv::MergeOp::create(builder, getLoc()); } SelectionOp @@ -551,7 +551,7 @@ SelectionOp::createIfThen(Location loc, Value condition, function_ref thenBody, OpBuilder &builder) { auto selectionOp = - builder.create(loc, spirv::SelectionControl::None); + spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(builder); Block *mergeBlock = selectionOp.getMergeBlock(); @@ -562,14 +562,14 @@ SelectionOp::createIfThen(Location loc, Value condition, OpBuilder::InsertionGuard guard(builder); thenBlock = builder.createBlock(mergeBlock); thenBody(builder); - builder.create(loc, mergeBlock); + spirv::BranchOp::create(builder, loc, mergeBlock); } // Build the header block. { OpBuilder::InsertionGuard guard(builder); builder.createBlock(thenBlock); - builder.create( + spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock, /*trueArguments=*/ArrayRef(), mergeBlock, /*falseArguments=*/ArrayRef()); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 03af61c81ae6c..b1bf37092d2b6 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -179,16 +179,16 @@ struct IAddCarryFold final : OpRewritePattern { return failure(); Value addsVal = - rewriter.create(loc, constituentType, adds); + spirv::ConstantOp::create(rewriter, loc, constituentType, adds); Value carrysVal = - rewriter.create(loc, constituentType, carrys); + spirv::ConstantOp::create(rewriter, loc, constituentType, carrys); // Create empty struct - Value undef = rewriter.create(loc, op.getType()); + Value undef = spirv::UndefOp::create(rewriter, loc, op.getType()); // Fill in adds at id 0 Value intermediate = - rewriter.create(loc, addsVal, undef, 0); + spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0); // Fill in carrys at id 1 rewriter.replaceOpWithNewOp(op, carrysVal, 
intermediate, 1); @@ -261,16 +261,16 @@ struct MulExtendedFold final : OpRewritePattern { return failure(); Value lowBitsVal = - rewriter.create(loc, constituentType, lowBits); + spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits); Value highBitsVal = - rewriter.create(loc, constituentType, highBits); + spirv::ConstantOp::create(rewriter, loc, constituentType, highBits); // Create empty struct - Value undef = rewriter.create(loc, op.getType()); + Value undef = spirv::UndefOp::create(rewriter, loc, op.getType()); // Fill in lowBits at id 0 Value intermediate = - rewriter.create(loc, lowBitsVal, undef, 0); + spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0); // Fill in highBits at id 1 rewriter.replaceOpWithNewOp(op, highBitsVal, intermediate, 1); @@ -1310,10 +1310,10 @@ struct ConvertSelectionOpToSelect final : OpRewritePattern { auto storeOpAttributes = cast(trueBlock->front())->getAttrs(); - auto selectOp = rewriter.create( + auto selectOp = spirv::SelectOp::create(rewriter, selectionOp.getLoc(), trueValue.getType(), brConditionalOp.getCondition(), trueValue, falseValue); - rewriter.create(selectOp.getLoc(), ptrValue, + spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue, selectOp.getResult(), storeOpAttributes); // `spirv.mlir.selection` is not needed anymore. 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 88c7adf3dfcb3..e85ea3455248f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -944,12 +944,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); if (!spirv::ConstantOp::isBuildableWith(type)) return nullptr; - return builder.create(loc, type, value); + return spirv::ConstantOp::create(builder, loc, type, value); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index d8dfe164458e2..8ba2af581e00d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "SPIRVParsingUtils.h" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index eb2974d62fdd1..f1299c8a926cd 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -653,25 +653,25 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) - return builder.create(loc, type, + return spirv::ConstantOp::create(builder, loc, type, builder.getBoolAttr(false)); - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } if (auto floatType = llvm::dyn_cast(type)) { - return builder.create( + return 
spirv::ConstantOp::create(builder, loc, type, builder.getFloatAttr(floatType, 0.0)); } if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (llvm::isa(elemType)) { - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0).getValue())); } if (llvm::isa(elemType)) { - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 0.0).getValue())); @@ -686,25 +686,25 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) - return builder.create(loc, type, + return spirv::ConstantOp::create(builder, loc, type, builder.getBoolAttr(true)); - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } if (auto floatType = llvm::dyn_cast(type)) { - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, builder.getFloatAttr(floatType, 1.0)); } if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (llvm::isa(elemType)) { - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1).getValue())); } if (llvm::isa(elemType)) { - return builder.create( + return spirv::ConstantOp::create(builder, loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 1.0).getValue())); @@ -1886,7 +1886,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); - builder.create(wrappedOp->getLoc(), wrappedOp->getResult(0)); + spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0)); result.location = wrappedOp->getLoc(); 
result.addTypes(wrappedOp->getResult(0).getType()); diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp index 71122f8e20512..4e53cef31d6ae 100644 --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -109,7 +109,7 @@ OwningOpRef combine(ArrayRef inputModules, } } - auto combinedModule = combinedModuleBuilder.create( + auto combinedModule = spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(), addressingModel, memoryModel, vceTriple); combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 6fd20466e36e3..735af7ab80ef3 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -71,7 +71,7 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, varType = spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); - return builder.create( + return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(), abiInfo.getBinding()); } @@ -147,7 +147,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, return funcOp.emitRemark("lower entry point failure: could not select " "execution model based on 'spirv.target_env'"); - builder.create(funcOp.getLoc(), *executionModel, funcOp, + spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp, interfaceVars); // Specifies the spirv.ExecutionModeOp. 
@@ -155,7 +155,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, std::optional> caps = spirv::getCapabilities(spirv::ExecutionMode::LocalSize); if (!caps || targetEnv.allows(*caps)) { - builder.create(funcOp.getLoc(), funcOp, + spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, workgroupSizeAttr.asArrayRef()); // Erase workgroup size. @@ -168,7 +168,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, std::optional> caps = spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize); if (!caps || targetEnv.allows(*caps)) { - builder.create(funcOp.getLoc(), funcOp, + spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp, spirv::ExecutionMode::SubgroupSize, *subgroupSize); // Erase subgroup size. @@ -181,7 +181,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, std::optional> caps = spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve); if (!caps || targetEnv.allows(*caps)) { - builder.create( + spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp, spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth); // Erase target width. @@ -260,7 +260,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( // Insert spirv::AddressOf and spirv::AccessChain operations. Value replacement = - rewriter.create(funcOp.getLoc(), var); + spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var); // Check if the arg is a scalar or vector type. In that case, the value // needs to be loaded into registers. 
// TODO: This is loading value of the scalar into registers @@ -270,9 +270,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( if (cast(argType.value()).isScalarOrVector()) { auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); - auto loadPtr = rewriter.create( + auto loadPtr = spirv::AccessChainOp::create(rewriter, funcOp.getLoc(), replacement, zero.getConstant()); - replacement = rewriter.create(funcOp.getLoc(), loadPtr); + replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr); } signatureConverter.remapInput(argType.index(), replacement); } @@ -309,7 +309,7 @@ void LowerABIAttributesPass::runOnOperation() { ValueRange inputs, Location loc) { if (inputs.size() != 1 || !isa(inputs[0].getType())) return Value(); - return builder.create(loc, type, inputs[0]).getResult(); + return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult(); }); RewritePatternSet patterns(context); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp index 2e31172ab940b..637be91a82e33 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp @@ -66,7 +66,7 @@ void RewriteInsertsPass::runOnOperation() { operands.push_back(insertionOp.getObject()); OpBuilder builder(lastCompositeInsertOp); - auto compositeConstructOp = builder.create( + auto compositeConstructOp = spirv::CompositeConstructOp::create(builder, location, compositeType, operands); lastCompositeInsertOp.replaceAllUsesWith( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 1e7bb046d3752..4e398ca92be0e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -672,21 +672,21 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, Location loc) { // We can only 
cast one value in SPIR-V. if (inputs.size() != 1) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } Value input = inputs.front(); // Only support integer types for now. Floating point types to be implemented. if (!isa(type)) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } auto inputType = cast(input.getType()); auto scalarType = dyn_cast(type); if (!scalarType) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } @@ -694,14 +694,14 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, // truncating to go back so we don't need to worry about the signedness. // For extension, we cannot have enough signal here to decide which op to use. if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } // Boolean values would need to use different ops than normal integer values. if (type.isInteger(1)) { Value one = spirv::ConstantOp::getOne(inputType, loc, builder); - return builder.create(loc, input, one); + return spirv::IEqualOp::create(builder, loc, input, one); } // Check that the source integer type is supported by the environment. 
@@ -711,7 +711,7 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, scalarType.getCapabilities(caps); if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || failed(checkExtensionRequirements(type, targetEnv, exts))) { - auto castOp = builder.create(loc, type, inputs); + auto castOp = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return castOp.getResult(0); } @@ -719,9 +719,9 @@ static Value castToSourceType(const spirv::TargetEnv &targetEnv, // care about signedness here. Still try to use a corresponding op for better // consistency though. if (type.isSignedInteger()) { - return builder.create(loc, type, input); + return spirv::SConvertOp::create(builder, loc, type, input); } - return builder.create(loc, type, input); + return spirv::UConvertOp::create(builder, loc, type, input); } //===----------------------------------------------------------------------===// @@ -773,7 +773,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = - builder.create(loc, ptrType, name, builtin); + spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin); break; } case spirv::BuiltIn::SubgroupId: @@ -784,7 +784,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, spirv::PointerType::get(integerType, spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = - builder.create(loc, ptrType, name, builtin); + spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin); break; } default: @@ -845,7 +845,7 @@ getOrInsertPushConstantVariable(Location loc, Block &block, auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); auto type = getPushConstantStorageType(elementCount, builder, indexType); const char *name = "__push_constant_var__"; - return builder.create(loc, type, name, + return 
spirv::GlobalVariableOp::create(builder, loc, type, name, /*initializer=*/nullptr); } @@ -882,7 +882,7 @@ struct FuncOpConversion final : OpConversionPattern { } // Create the converted spirv.func op. - auto newFuncOp = rewriter.create( + auto newFuncOp = spirv::FuncOp::create(rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), resultType ? TypeRange(resultType) @@ -922,7 +922,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern { } // Create a new func op with the original type and copy the function body. - auto newFuncOp = rewriter.create(funcOp.getLoc(), + auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(), funcOp.getName(), fnType); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); @@ -957,7 +957,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern { auto origVecType = dyn_cast(origType); if (!origVecType) { // We need a placeholder for the old argument that will be erased later. - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, origType, rewriter.getZeroAttr(origType)); rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); tmpOps.insert({result.getDefiningOp(), newInputNo}); @@ -970,7 +970,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern { auto targetShape = getTargetShape(origVecType); if (!targetShape) { // We need a placeholder for the old argument that will be erased later. - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, origType, rewriter.getZeroAttr(origType)); rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); tmpOps.insert({result.getDefiningOp(), newInputNo}); @@ -985,11 +985,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern { llvm::to_vector_of(origVecType.getShape()); // Prepare the result vector. 
- Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType)); ++newOpCount; // Prepare the placeholder for the new arguments that will be added later. - Value dummy = rewriter.create( + Value dummy = arith::ConstantOp::create(rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType)); ++newOpCount; @@ -998,7 +998,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern { SmallVector newTypes; for (SmallVector offsets : StaticTileOffsetRange(originalShape, *targetShape)) { - result = rewriter.create( + result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy, result, offsets, strides); newTypes.push_back(unrolledType); unrolledInputNums.push_back(newInputNo); @@ -1112,12 +1112,12 @@ struct ReturnOpVectorUnroll final : OpRewritePattern { Value returnValue = returnOp.getOperand(origResultNo); for (SmallVector offsets : StaticTileOffsetRange(originalShape, *targetShape)) { - Value result = rewriter.create( + Value result = vector::ExtractStridedSliceOp::create(rewriter, loc, returnValue, offsets, extractShape, strides); if (originalShape.size() > 1) { SmallVector extractIndices(originalShape.size() - 1, 0); result = - rewriter.create(loc, result, extractIndices); + vector::ExtractOp::create(rewriter, loc, result, extractIndices); } newOperands.push_back(result); newTypes.push_back(unrolledType); @@ -1135,7 +1135,7 @@ struct ReturnOpVectorUnroll final : OpRewritePattern { // Replace the return op using the new operands. This will automatically // update the entry block as well. 
rewriter.replaceOp(returnOp, - rewriter.create(loc, newOperands)); + func::ReturnOp::create(rewriter, loc, newOperands)); return success(); } @@ -1160,8 +1160,8 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), builtin, integerType, builder, prefix, suffix); - Value ptr = builder.create(op->getLoc(), varOp); - return builder.create(op->getLoc(), ptr); + Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp); + return spirv::LoadOp::create(builder, op->getLoc(), ptr); } //===----------------------------------------------------------------------===// @@ -1182,12 +1182,12 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, loc, parent->getRegion(0).front(), elementCount, builder, integerType); Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); - Value offsetOp = builder.create( + Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType, builder.getI32IntegerAttr(offset)); - auto addrOp = builder.create(loc, varOp); - auto acOp = builder.create( + auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp); + auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp})); - return builder.create(loc, acOp); + return spirv::LoadOp::create(builder, loc, acOp); } //===----------------------------------------------------------------------===// @@ -1247,7 +1247,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, linearizedIndices.push_back( linearizeIndex(indices, strides, offset, indexType, loc, builder)); } - return builder.create(loc, basePtr, linearizedIndices); + return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices); } Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, @@ -1278,10 +1278,10 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter 
&typeConverter, cast(basePtr.getType()).getPointeeType(); if (isa(pointeeType)) { linearizedIndices.push_back(linearIndex); - return builder.create(loc, basePtr, + return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices); } - return builder.create(loc, basePtr, linearIndex, + return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex, linearizedIndices); } @@ -1468,7 +1468,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, }); addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - auto cast = builder.create(loc, type, inputs); + auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs); return cast.getResult(0); }); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index af1cf2a1373e3..17f79780afa20 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -64,16 +64,16 @@ static Value lowerExtendedMultiplication(Operation *mulOp, // and 4 additions after constant folding. // - With sign-extended arguments, we end up emitting 8 multiplications and // and 12 additions after CSE. 
- Value cstLowMask = rewriter.create( + Value cstLowMask = ConstantOp::create(rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) { - return rewriter.create(loc, val, cstLowMask); + return BitwiseAndOp::create(rewriter, loc, val, cstLowMask); }; - Value cst16 = rewriter.create(loc, lhs.getType(), + Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, 16)); auto getHighDigit = [&rewriter, loc, cst16](Value val) { - return rewriter.create(loc, val, cst16); + return ShiftRightLogicalOp::create(rewriter, loc, val, cst16); }; auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) { @@ -82,10 +82,10 @@ static Value lowerExtendedMultiplication(Operation *mulOp, // fine. We do not have to introduce an extra constant since any // value in [15, 32) would do. return getHighDigit( - rewriter.create(loc, val, cst16)); + ShiftRightArithmeticOp::create(rewriter, loc, val, cst16)); }; - Value cst0 = rewriter.create(loc, lhs.getType(), + Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, 0)); Value lhsLow = getLowDigit(lhs); @@ -108,7 +108,7 @@ static Value lowerExtendedMultiplication(Operation *mulOp, continue; Value &thisResDigit = resultDigits[i + j]; - Value mul = rewriter.create(loc, lhsDigit, rhsDigit); + Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit); Value current = rewriter.createOrFold(loc, thisResDigit, mul); thisResDigit = getLowDigit(current); @@ -122,13 +122,13 @@ static Value lowerExtendedMultiplication(Operation *mulOp, } auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { - Value highBits = rewriter.create(loc, high, cst16); - return rewriter.create(loc, low, highBits); + Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16); + return BitwiseOrOp::create(rewriter, loc, low, highBits); }; Value low = combineDigits(resultDigits[0], 
resultDigits[1]); Value high = combineDigits(resultDigits[2], resultDigits[3]); - return rewriter.create( + return CompositeConstructOp::create(rewriter, loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high})); } @@ -185,16 +185,16 @@ struct ExpandAddCarryPattern final : OpRewritePattern { llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); Value one = - rewriter.create(loc, argTy, getScalarOrSplatAttr(argTy, 1)); + ConstantOp::create(rewriter, loc, argTy, getScalarOrSplatAttr(argTy, 1)); Value zero = - rewriter.create(loc, argTy, getScalarOrSplatAttr(argTy, 0)); + ConstantOp::create(rewriter, loc, argTy, getScalarOrSplatAttr(argTy, 0)); // Calculate the carry by checking if the addition resulted in an overflow. - Value out = rewriter.create(loc, lhs, rhs); - Value cmp = rewriter.create(loc, out, lhs); - Value carry = rewriter.create(loc, cmp, one, zero); + Value out = IAddOp::create(rewriter, loc, lhs, rhs); + Value cmp = ULessThanOp::create(rewriter, loc, out, lhs); + Value carry = SelectOp::create(rewriter, loc, cmp, one, zero); - Value add = rewriter.create( + Value add = CompositeConstructOp::create(rewriter, loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry})); rewriter.replaceOp(op, add); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index 07cf26926a1df..7fe2bd79403f9 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -384,13 +384,13 @@ struct ConvertAccessChain : public ConvertAliasResource { Type indexType = oldIndex.getType(); int ratio = dstNumBytes / srcNumBytes; - auto ratioValue = rewriter.create( + auto ratioValue = spirv::ConstantOp::create(rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); indices.back() = - rewriter.create(loc, indexType, oldIndex, ratioValue); + spirv::SDivOp::create(rewriter, loc, 
indexType, oldIndex, ratioValue); indices.push_back( - rewriter.create(loc, indexType, oldIndex, ratioValue)); + spirv::SModOp::create(rewriter, loc, indexType, oldIndex, ratioValue)); rewriter.replaceOpWithNewOp( acOp, adaptor.getBasePtr(), indices); @@ -411,11 +411,11 @@ struct ConvertAccessChain : public ConvertAliasResource { Type indexType = oldIndex.getType(); int ratio = srcNumBytes / dstNumBytes; - auto ratioValue = rewriter.create( + auto ratioValue = spirv::ConstantOp::create(rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); indices.back() = - rewriter.create(loc, indexType, oldIndex, ratioValue); + spirv::IMulOp::create(rewriter, loc, indexType, oldIndex, ratioValue); rewriter.replaceOpWithNewOp( acOp, adaptor.getBasePtr(), indices); @@ -439,14 +439,14 @@ struct ConvertLoad : public ConvertAliasResource { auto dstElemType = cast(dstPtrType.getPointeeType()); Location loc = loadOp.getLoc(); - auto newLoadOp = rewriter.create(loc, adaptor.getPtr()); + auto newLoadOp = spirv::LoadOp::create(rewriter, loc, adaptor.getPtr()); if (srcElemType == dstElemType) { rewriter.replaceOp(loadOp, newLoadOp->getResults()); return success(); } if (areSameBitwidthScalarType(srcElemType, dstElemType)) { - auto castOp = rewriter.create(loc, srcElemType, + auto castOp = spirv::BitcastOp::create(rewriter, loc, srcElemType, newLoadOp.getValue()); rewriter.replaceOp(loadOp, castOp->getResults()); @@ -479,14 +479,14 @@ struct ConvertLoad : public ConvertAliasResource { auto indices = llvm::to_vector<4>(acOp.getIndices()); for (int i = 1; i < ratio; ++i) { // Load all subsequent components belonging to this element. 
- indices.back() = rewriter.create( + indices.back() = spirv::IAddOp::create(rewriter, loc, i32Type, indices.back(), oneValue); - auto componentAcOp = rewriter.create( + auto componentAcOp = spirv::AccessChainOp::create(rewriter, loc, acOp.getBasePtr(), indices); // Assuming little endian, this reads lower-ordered bits of the number // to lower-numbered components of the vector. components.push_back( - rewriter.create(loc, componentAcOp)); + spirv::LoadOp::create(rewriter, loc, componentAcOp)); } // Create a vector of the components and then cast back to the larger @@ -514,15 +514,15 @@ struct ConvertLoad : public ConvertAliasResource { castType = VectorType::get({count}, castType); for (Value &c : components) - c = rewriter.create(loc, castType, c); + c = spirv::BitcastOp::create(rewriter, loc, castType, c); } } - Value vectorValue = rewriter.create( + Value vectorValue = spirv::CompositeConstructOp::create(rewriter, loc, vectorType, components); if (!isa(srcElemType)) vectorValue = - rewriter.create(loc, srcElemType, vectorValue); + spirv::BitcastOp::create(rewriter, loc, srcElemType, vectorValue); rewriter.replaceOp(loadOp, vectorValue); return success(); } @@ -550,7 +550,7 @@ struct ConvertStore : public ConvertAliasResource { Location loc = storeOp.getLoc(); Value value = adaptor.getValue(); if (srcElemType != dstElemType) - value = rewriter.create(loc, dstElemType, value); + value = spirv::BitcastOp::create(rewriter, loc, dstElemType, value); rewriter.replaceOpWithNewOp(storeOp, adaptor.getPtr(), value, storeOp->getAttrs()); return success(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 5f395eef9d601..19dba20f43fa9 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include 
"mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -151,17 +152,17 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); if (llvm::isa(type) || isExtentTensorType(type)) - return builder.create( - loc, type, llvm::cast(value)); + return ConstShapeOp::create(builder, loc, type, + llvm::cast(value)); if (llvm::isa(type)) - return builder.create(loc, type, - llvm::cast(value)); + return ConstSizeOp::create(builder, loc, type, + llvm::cast(value)); if (llvm::isa(type)) - return builder.create(loc, type, - llvm::cast(value)); + return ConstWitnessOp::create(builder, loc, type, + llvm::cast(value)); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -316,8 +317,8 @@ struct AssumingOpRemoveUnusedResults : public OpRewritePattern { auto newYieldOp = rewriter.replaceOpWithNewOp(yieldOp, newYieldOperands); rewriter.setInsertionPoint(op); - auto newOp = rewriter.create( - op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); + auto newOp = AssumingOp::create( + rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); newOp.getDoRegion().takeBody(op.getDoRegion()); // Use the new results to replace the previously used ones. @@ -385,7 +386,7 @@ void AssumingOp::build( // Build body. 
SmallVector yieldValues = bodyBuilder(builder, result.location); - builder.create(result.location, yieldValues); + AssumingYieldOp::create(builder, result.location, yieldValues); SmallVector assumingTypes; for (Value v : yieldValues) @@ -736,13 +737,13 @@ struct BroadcastForwardSingleOperandPattern if (replacement.getType() != op.getType()) { auto loc = op.getLoc(); if (llvm::isa(op.getType())) { - replacement = rewriter.create(loc, replacement); + replacement = FromExtentTensorOp::create(rewriter, loc, replacement); } else { assert(!llvm::isa(op.getType()) && !llvm::isa(replacement.getType()) && "expect extent tensor cast"); replacement = - rewriter.create(loc, op.getType(), replacement); + tensor::CastOp::create(rewriter, loc, op.getType(), replacement); } } @@ -780,9 +781,9 @@ struct BroadcastFoldConstantOperandsPattern auto foldedConstantOperandsTy = RankedTensorType::get( {static_cast(foldedConstantShape.size())}, rewriter.getIndexType()); - newShapeOperands.push_back(rewriter.create( - op.getLoc(), foldedConstantOperandsTy, - rewriter.getIndexTensorAttr(foldedConstantShape))); + newShapeOperands.push_back( + ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy, + rewriter.getIndexTensorAttr(foldedConstantShape))); rewriter.replaceOpWithNewOp(op, op.getType(), newShapeOperands); return success(); @@ -845,9 +846,9 @@ struct BroadcastConcretizeResultTypePattern } } - auto newOp = rewriter.create( - op.getLoc(), getExtentTensorType(getContext(), maxRank), - op.getShapes()); + auto newOp = BroadcastOp::create(rewriter, op.getLoc(), + getExtentTensorType(getContext(), maxRank), + op.getShapes()); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); return success(); } @@ -1338,7 +1339,8 @@ std::optional GetExtentOp::getConstantDim() { } OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) { - auto elements = llvm::dyn_cast_if_present(adaptor.getShape()); + auto elements = + llvm::dyn_cast_if_present(adaptor.getShape()); if (!elements) return 
nullptr; std::optional dim = getConstantDim(); @@ -1354,11 +1356,11 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, auto loc = result.location; auto dimAttr = builder.getIndexAttr(dim); if (llvm::isa(shape.getType())) { - Value dim = builder.create(loc, dimAttr); + Value dim = ConstSizeOp::create(builder, loc, dimAttr); build(builder, result, builder.getType(), shape, dim); } else { - Value dim = - builder.create(loc, builder.getIndexType(), dimAttr); + Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(), + dimAttr); build(builder, result, builder.getIndexType(), shape, dim); } } @@ -1480,7 +1482,8 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { //===----------------------------------------------------------------------===// OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) { - auto shape = llvm::dyn_cast_if_present(adaptor.getShape()); + auto shape = + llvm::dyn_cast_if_present(adaptor.getShape()); if (!shape) return {}; int64_t rank = shape.getNumElements(); @@ -1708,8 +1711,8 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern { rewriter.getIndexTensorAttr(type.getShape())) .getResult(); if (constShape.getType() != op.getResult().getType()) - constShape = rewriter.create( - loc, op.getResult().getType(), constShape); + constShape = tensor::CastOp::create(rewriter, loc, + op.getResult().getType(), constShape); rewriter.replaceOp(op, constShape); return success(); } @@ -1717,13 +1720,13 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern { // Canonicalize // -// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor +// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor) -> +// tensor<*xf32> %1 = shape.shape_of %0 : tensor<*xf32> -> tensor // // to // -// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// %1 = %shape +// %0 = tensor.reshape 
%input(%shape) : (tensor<*xf32>, tensor) -> +// tensor<*xf32> %1 = %shape // struct ShapeOfFromReshape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1751,10 +1754,11 @@ struct ShapeOfFromReshape : public OpRewritePattern { if (opTensorTy != shapeTensorTy) { if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) - shape = rewriter.create(op.getLoc(), opTensorTy, shape); - else if (!isExtentTensorType(shapeTensorTy)) shape = - rewriter.create(op.getLoc(), opTensorTy, shape); + tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape); + else if (!isExtentTensorType(shapeTensorTy)) + shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy, + shape); } rewriter.replaceOp(op, shape); @@ -1895,8 +1899,9 @@ LogicalResult SplitAtOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { if (!adaptor.getOperand() || !adaptor.getIndex()) return failure(); - auto shapeVec = llvm::to_vector<6>( - llvm::cast(adaptor.getOperand()).getValues()); + auto shapeVec = + llvm::to_vector<6>(llvm::cast(adaptor.getOperand()) + .getValues()); auto shape = llvm::ArrayRef(shapeVec); auto splitPoint = llvm::cast(adaptor.getIndex()).getInt(); // Verify that the split point is in the correct range. 
@@ -1920,8 +1925,9 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) { if (!adaptor.getInput()) return OpFoldResult(); Builder builder(getContext()); - auto shape = llvm::to_vector<6>( - llvm::cast(adaptor.getInput()).getValues()); + auto shape = + llvm::to_vector<6>(llvm::cast(adaptor.getInput()) + .getValues()); auto type = RankedTensorType::get({static_cast(shape.size())}, builder.getIndexType()); return DenseIntElementsAttr::get(type, shape); diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index 8a471c12d21e4..5533a7e914bf0 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -57,7 +57,7 @@ struct AssumingOpInterface // Create new op and move over region. TypeRange newResultTypes(yieldOp.getOperands()); - auto newOp = rewriter.create( + auto newOp = shape::AssumingOp::create(rewriter, op->getLoc(), newResultTypes, assumingOp.getWitness()); newOp.getDoRegion().takeBody(assumingOp.getRegion()); @@ -66,7 +66,7 @@ struct AssumingOpInterface SmallVector newResults; for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { if (isa(it.value())) { - newResults.push_back(rewriter.create( + newResults.push_back(bufferization::ToTensorOp::create(rewriter, assumingOp.getLoc(), it.value(), newOp->getResult(it.index()))); } else { newResults.push_back(newOp->getResult(it.index())); diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp index ae06a34b65709..522347ff2b993 100644 --- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -69,7 +69,7 @@ createFuncFromCluster(OpBuilder &b, const SmallVector &cluster, cluster.empty() ? 
b.getFunctionType(shape.getType(), shape.getType()) : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType()); - shape::FuncOp fnOp = b.create(loc, fnName, fnType); + shape::FuncOp fnOp = shape::FuncOp::create(b, loc, fnName, fnType); Block *block = fnOp.addEntryBlock(); b.setInsertionPointToEnd(block); IRMapping bvm; @@ -85,7 +85,7 @@ createFuncFromCluster(OpBuilder &b, const SmallVector &cluster, llvm::SmallVector fnReturns; fnReturns.push_back(bvm.lookupOrDefault(shape)); - b.create(loc, fnReturns); + shape::ReturnOp::create(b, loc, fnReturns); fnOp.setPrivate(); return std::make_pair(fnOp, inputs); } @@ -187,7 +187,7 @@ class TensorDimOpRewriter : public OpRewritePattern { LogicalResult matchAndRewrite(tensor::DimOp op, PatternRewriter &rewriter) const override { auto shapeOf = - rewriter.create(op.getLoc(), op.getSource()); + shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource()); rewriter.replaceOpWithNewOp(op, op.getType(), shapeOf, op.getIndex()); return success(); diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp index 121e0cc133e19..b5f6230bc2362 100644 --- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -44,14 +44,14 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op, ->materializeConstant(rewriter, rewriter.getIndexAttr(1), valueType, loc) ->getResult(0); - ReduceOp reduce = rewriter.create(loc, op.getShape(), init); + ReduceOp reduce = ReduceOp::create(rewriter, loc, op.getShape(), init); // Generate reduce operator. 
Block *body = reduce.getBody(); OpBuilder b = OpBuilder::atBlockEnd(body); - Value product = b.create(loc, valueType, body->getArgument(1), + Value product = MulOp::create(b, loc, valueType, body->getArgument(1), body->getArgument(2)); - b.create(loc, product); + shape::YieldOp::create(b, loc, product); rewriter.replaceOp(op, reduce.getResult()); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index ea7918fa14f23..eb42335fa8c94 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -561,7 +562,8 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, SmallVector retType( dir == CrdTransDirectionKind::lvl2dim ? 
getDimRank() : getLvlRank(), builder.getIndexType()); - auto transOp = builder.create(loc, retType, crds, dir, *this); + auto transOp = + CrdTranslateOp::create(builder, loc, retType, crds, dir, *this); return transOp.getOutCrds(); } @@ -1483,7 +1485,7 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, int64_t index) { - Value val = builder.create(state.location, index); + Value val = arith::ConstantIndexOp::create(builder, state.location, index); return build(builder, state, source, val); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp index 9c84f4c25866f..74842cc1f310b 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp @@ -41,7 +41,7 @@ LogicalResult sparse_tensor::detail::stageWithSortImpl( // -> sort Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true); - Value dstCOO = rewriter.create( + Value dstCOO = ReorderCOOOp::create(rewriter, loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort); // -> dest. 
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp index bdec43825ddc2..33249368f7ed3 100644 --- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp +++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" using namespace mlir; using namespace mlir::sparse_tensor; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp index 5461987fb49d9..dce6e438d55d0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp @@ -90,13 +90,13 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types, } else if (directOut) { Value mem; if (kind == SparseTensorFieldKind::PosMemRef) - mem = builder.create(loc, inputs[0], + mem = sparse_tensor::ToPositionsOp::create(builder, loc, inputs[0], lv); else if (kind == SparseTensorFieldKind::CrdMemRef) - mem = builder.create(loc, inputs[0], + mem = sparse_tensor::ToCoordinatesOp::create(builder, loc, inputs[0], lv); else - mem = builder.create(loc, inputs[0]); + mem = sparse_tensor::ToValuesOp::create(builder, loc, inputs[0]); toVals.push_back(mem); } else { ShapedType rtp = cast(t); @@ -111,7 +111,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types, if (isIn) { // Assemble multiple inputs into a single sparse tensor. - auto a = builder.create(loc, rtp, inputs); + auto a = sparse_tensor::AssembleOp::create(builder, loc, rtp, inputs); toVals.push_back(a.getResult()); } else if (!directOut) { // Disassemble a single sparse input into multiple outputs. 
@@ -119,7 +119,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types, unsigned len = retTypes.size(); retTypes.append(cntTypes); auto d = - builder.create(loc, retTypes, inputs); + sparse_tensor::DisassembleOp::create(builder, loc, retTypes, inputs); for (unsigned i = 0; i < len; i++) toVals.push_back(d.getResult(i)); } @@ -201,7 +201,7 @@ struct SparseFuncAssembler : public OpRewritePattern { OpBuilder moduleBuilder(modOp.getBodyRegion()); unsigned extra = inputTypes.size(); inputTypes.append(extraTypes); - auto func = moduleBuilder.create( + auto func = func::FuncOp::create(moduleBuilder, loc, orgName, FunctionType::get(context, inputTypes, outputTypes)); func.setPublic(); @@ -218,14 +218,14 @@ struct SparseFuncAssembler : public OpRewritePattern { // Call the original, now private method. A subsequent inlining pass can // determine whether cloning the method body in place is worthwhile. auto org = SymbolRefAttr::get(context, wrapper); - auto call = rewriter.create(loc, funcOp.getResultTypes(), org, + auto call = func::CallOp::create(rewriter, loc, funcOp.getResultTypes(), org, inputs); // Convert outputs and return. SmallVector outputs; convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(), body->getArguments(), outputs, extra, /*isIn=*/false, directOut); - rewriter.create(loc, outputs); + func::ReturnOp::create(rewriter, loc, outputs); // Finally, migrate a potential c-interface property. 
if (funcOp->getAttrOfType( diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index 0c5912bb73772..12507ff1c5251 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -94,7 +94,7 @@ static FlatSymbolRefAttr getMangledSortHelperFunc( OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(insertPoint); Location loc = insertPoint.getLoc(); - func = builder.create( + func = func::FuncOp::create(builder, loc, nameOstream.str(), FunctionType::get(context, operands.getTypes(), resultTypes)); func.setPrivate(); @@ -111,13 +111,13 @@ static void forEachIJPairInXs( uint64_t ny, function_ref bodyBuilder) { Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny); - Value iOffset = builder.create(loc, args[0], cstep); - Value jOffset = builder.create(loc, args[1], cstep); + Value iOffset = arith::MulIOp::create(builder, loc, args[0], cstep); + Value jOffset = arith::MulIOp::create(builder, loc, args[1], cstep); for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { unsigned actualK = cast(xPerm.getResult(k)).getPosition(); Value ak = constantIndex(builder, loc, actualK); - Value i = builder.create(loc, ak, iOffset); - Value j = builder.create(loc, ak, jOffset); + Value i = arith::AddIOp::create(builder, loc, ak, iOffset); + Value j = arith::AddIOp::create(builder, loc, ak, jOffset); Value buffer = args[xStartIdx]; bodyBuilder(k, i, j, buffer); @@ -165,10 +165,10 @@ static void forEachIJPairInAllBuffers( static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny) { auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { - Value vi = builder.create(loc, buffer, i); - Value vj = builder.create(loc, buffer, j); - builder.create(loc, vj, buffer, i); - builder.create(loc, vi, buffer, 
j); + Value vi = memref::LoadOp::create(builder, loc, buffer, i); + Value vj = memref::LoadOp::create(builder, loc, buffer, j); + memref::StoreOp::create(builder, loc, vj, buffer, i); + memref::StoreOp::create(builder, loc, vi, buffer, j); }; forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair); @@ -193,7 +193,7 @@ static Value createInlinedCompareImplementation( OpBuilder::InsertionGuard insertionGuard(builder); auto ifOp = cast(val.getDefiningOp()); builder.setInsertionPointAfter(ifOp); - builder.create(loc, ifOp.getResult(0)); + scf::YieldOp::create(builder, loc, ifOp.getResult(0)); } }; @@ -207,25 +207,25 @@ static Value createInlinedCompareImplementation( /// result of the comparison. static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) { - Value vi = builder.create(loc, x, i); - Value vj = builder.create(loc, x, j); + Value vi = memref::LoadOp::create(builder, loc, x, i); + Value vj = memref::LoadOp::create(builder, loc, x, j); Value res; if (isLastDim) { - res = builder.create(loc, arith::CmpIPredicate::eq, vi, vj); + res = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, vi, vj); // For 1D, we create a compare without any control flow. Otherwise, we // create YieldOp to return the result in the nested if-stmt. if (!isFirstDim) - builder.create(loc, res); + scf::YieldOp::create(builder, loc, res); } else { Value ne = - builder.create(loc, arith::CmpIPredicate::ne, vi, vj); - scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIntegerType(1), ne, /*else=*/true); // If (x[i] != x[j]). builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); Value f = constantI1(builder, loc, false); - builder.create(loc, f); + scf::YieldOp::create(builder, loc, f); // If (x[i] == x[j]). 
Set up the insertion point for the nested if-stmt that // checks the remaining dimensions. @@ -261,26 +261,26 @@ static Value createInlinedEqCompare(OpBuilder &builder, Location loc, static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) { - Value vi = builder.create(loc, x, i); - Value vj = builder.create(loc, x, j); + Value vi = memref::LoadOp::create(builder, loc, x, i); + Value vj = memref::LoadOp::create(builder, loc, x, j); Value res; if (isLastDim) { - res = builder.create(loc, arith::CmpIPredicate::ult, vi, vj); + res = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj); // For 1D, we create a compare without any control flow. Otherwise, we // create YieldOp to return the result in the nested if-stmt. if (!isFirstDim) - builder.create(loc, res); + scf::YieldOp::create(builder, loc, res); } else { Value ne = - builder.create(loc, arith::CmpIPredicate::ne, vi, vj); - scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIntegerType(1), ne, /*else=*/true); // If (x[i] != x[j]). builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); Value lt = - builder.create(loc, arith::CmpIPredicate::ult, vi, vj); - builder.create(loc, lt); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj); + scf::YieldOp::create(builder, loc, lt); // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that // checks the remaining dimensions. @@ -337,17 +337,17 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; SmallVector types(2, p.getType()); // Only two types. 
- scf::WhileOp whileOp = builder.create( + scf::WhileOp whileOp = scf::WhileOp::create(builder, loc, types, SmallVector{args[loIdx], args[hiIdx]}); // The before-region of the WhileOp. Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(before); - Value cond1 = builder.create(loc, arith::CmpIPredicate::ult, + Value cond1 = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, before->getArgument(0), before->getArgument(1)); - builder.create(loc, cond1, before->getArguments()); + scf::ConditionOp::create(builder, loc, cond1, before->getArguments()); // The after-region of the WhileOp. Block *after = @@ -357,9 +357,9 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, Value hi = after->getArgument(1); // Compute mid = (lo + hi) >> 1. Value c1 = constantIndex(builder, loc, 1); - Value mid = builder.create( - loc, builder.create(loc, lo, hi), c1); - Value midp1 = builder.create(loc, mid, c1); + Value mid = arith::ShRUIOp::create(builder, + loc, arith::AddIOp::create(builder, loc, lo, hi), c1); + Value midp1 = arith::AddIOp::create(builder, loc, mid, c1); // Compare xs[p] < xs[mid]. 
SmallVector compareOperands{p, mid}; @@ -372,12 +372,12 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, // hi = mid; // else // lo = mid + 1; - Value newLo = builder.create(loc, cond2, lo, midp1); - Value newHi = builder.create(loc, cond2, mid, hi); - builder.create(loc, ValueRange{newLo, newHi}); + Value newLo = arith::SelectOp::create(builder, loc, cond2, lo, midp1); + Value newHi = arith::SelectOp::create(builder, loc, cond2, mid, hi); + scf::YieldOp::create(builder, loc, ValueRange{newLo, newHi}); builder.setInsertionPointAfter(whileOp); - builder.create(loc, whileOp.getResult(0)); + func::ReturnOp::create(builder, loc, whileOp.getResult(0)); } /// Creates code to advance i in a loop based on xs[p] as follows: @@ -393,7 +393,7 @@ static std::pair createScanLoop(OpBuilder &builder, uint64_t ny, int step) { Location loc = func.getLoc(); scf::WhileOp whileOp = - builder.create(loc, TypeRange{i.getType()}, ValueRange{i}); + scf::WhileOp::create(builder, loc, TypeRange{i.getType()}, ValueRange{i}); Block *before = builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); @@ -409,14 +409,14 @@ static std::pair createScanLoop(OpBuilder &builder, } compareOperands.append(xs.begin(), xs.end()); Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); - builder.create(loc, cond, before->getArguments()); + scf::ConditionOp::create(builder, loc, cond, before->getArguments()); Block *after = builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); builder.setInsertionPointToEnd(after); Value cs = constantIndex(builder, loc, step); - i = builder.create(loc, after->getArgument(0), cs); - builder.create(loc, ValueRange{i}); + i = arith::AddIOp::create(builder, loc, after->getArgument(0), cs); + scf::YieldOp::create(builder, loc, ValueRange{i}); i = whileOp.getResult(0); builder.setInsertionPointAfter(whileOp); @@ -440,7 +440,7 @@ static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, 
compareOperands[0] = b; compareOperands[1] = a; Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); - scf::IfOp ifOp = builder.create(loc, cond, /*else=*/false); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond, /*else=*/false); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); swapOperands[0] = b; swapOperands[1] = a; @@ -517,12 +517,12 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, swapOperands.append(args.begin() + xStartIdx, args.end()); Location loc = func.getLoc(); Value c1 = constantIndex(builder, loc, 1); - Value hiP1 = builder.create(loc, hi, c1); - Value len = builder.create(loc, hiP1, lo); + Value hiP1 = arith::AddIOp::create(builder, loc, hi, c1); + Value len = arith::SubIOp::create(builder, loc, hiP1, lo); Value lenThreshold = constantIndex(builder, loc, 1000); - Value lenCond = builder.create(loc, arith::CmpIPredicate::ult, + Value lenCond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, len, lenThreshold); - scf::IfOp lenIf = builder.create(loc, lenCond, /*else=*/true); + scf::IfOp lenIf = scf::IfOp::create(builder, loc, lenCond, /*else=*/true); // When len < 1000, choose pivot from median of 3 values. builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); @@ -531,13 +531,13 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, // When len >= 1000, choose pivot from median of 5 values. builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); - Value miP1 = builder.create(loc, hi, c1); - Value a = builder.create(loc, lo, miP1); + Value miP1 = arith::AddIOp::create(builder, loc, hi, c1); + Value a = arith::AddIOp::create(builder, loc, lo, miP1); // Value a is the middle between [loc, mi]. - a = builder.create(loc, a, c1); - Value b = builder.create(loc, mi, hiP1); + a = arith::ShRUIOp::create(builder, loc, a, c1); + Value b = arith::AddIOp::create(builder, loc, mi, hiP1); // Value b is the middle between [mi, hi]. 
- b = builder.create(loc, b, c1); + b = arith::ShRUIOp::create(builder, loc, b, c1); createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi, b, hi); @@ -589,24 +589,24 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module, ValueRange args = entryBlock->getArguments(); Value lo = args[loIdx]; Value hi = args[hiIdx]; - Value sum = builder.create(loc, lo, hi); + Value sum = arith::AddIOp::create(builder, loc, lo, hi); Value c1 = constantIndex(builder, loc, 1); - Value p = builder.create(loc, sum, c1); + Value p = arith::ShRUIOp::create(builder, loc, sum, c1); Value i = lo; - Value j = builder.create(loc, hi, c1); + Value j = arith::SubIOp::create(builder, loc, hi, c1); createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args); Value trueVal = constantI1(builder, loc, true); // The value for while (true) SmallVector operands{i, j, p, trueVal}; // Exactly four values. SmallVector types{i.getType(), j.getType(), p.getType(), trueVal.getType()}; - scf::WhileOp whileOp = builder.create(loc, types, operands); + scf::WhileOp whileOp = scf::WhileOp::create(builder, loc, types, operands); // The before-region of the WhileOp. Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc, loc}); builder.setInsertionPointToEnd(before); - builder.create(loc, before->getArgument(3), + scf::ConditionOp::create(builder, loc, before->getArgument(3), before->getArguments()); // The after-region of the WhileOp. 
@@ -629,70 +629,70 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module, // If i < j: Value cond = - builder.create(loc, arith::CmpIPredicate::ult, i, j); - scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, i, j); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector swapOperands{i, j}; swapOperands.append(args.begin() + xStartIdx, args.end()); createSwap(builder, loc, swapOperands, xPerm, ny); // If the pivot is moved, update p with the new pivot. Value icond = - builder.create(loc, arith::CmpIPredicate::eq, i, p); - scf::IfOp ifOpI = builder.create(loc, TypeRange{p.getType()}, + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, i, p); + scf::IfOp ifOpI = scf::IfOp::create(builder, loc, TypeRange{p.getType()}, icond, /*else=*/true); builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); - builder.create(loc, ValueRange{j}); + scf::YieldOp::create(builder, loc, ValueRange{j}); builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); Value jcond = - builder.create(loc, arith::CmpIPredicate::eq, j, p); - scf::IfOp ifOpJ = builder.create(loc, TypeRange{p.getType()}, + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, j, p); + scf::IfOp ifOpJ = scf::IfOp::create(builder, loc, TypeRange{p.getType()}, jcond, /*else=*/true); builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); - builder.create(loc, ValueRange{i}); + scf::YieldOp::create(builder, loc, ValueRange{i}); builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); - builder.create(loc, ValueRange{p}); + scf::YieldOp::create(builder, loc, ValueRange{p}); builder.setInsertionPointAfter(ifOpJ); - builder.create(loc, ifOpJ.getResults()); + scf::YieldOp::create(builder, loc, ifOpJ.getResults()); builder.setInsertionPointAfter(ifOpI); Value compareEqIJ = - 
builder.create(loc, iCompareEq, jCompareEq); - scf::IfOp ifOp2 = builder.create( + arith::AndIOp::create(builder, loc, iCompareEq, jCompareEq); + scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); - Value i2 = builder.create(loc, i, c1); - Value j2 = builder.create(loc, j, c1); - builder.create(loc, ValueRange{i2, j2}); + Value i2 = arith::AddIOp::create(builder, loc, i, c1); + Value j2 = arith::SubIOp::create(builder, loc, j, c1); + scf::YieldOp::create(builder, loc, ValueRange{i2, j2}); builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); - builder.create(loc, ValueRange{i, j}); + scf::YieldOp::create(builder, loc, ValueRange{i, j}); builder.setInsertionPointAfter(ifOp2); - builder.create( + scf::YieldOp::create(builder, loc, ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), /*cont=*/constantI1(builder, loc, true)}); // False branch for if i < j (i.e., i >= j): builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - p = builder.create(loc, j, + p = arith::AddIOp::create(builder, loc, j, constantOne(builder, loc, j.getType())); - builder.create( + scf::YieldOp::create(builder, loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); // Return for the whileOp. builder.setInsertionPointAfter(ifOp); - builder.create(loc, ifOp.getResults()); + scf::YieldOp::create(builder, loc, ifOp.getResults()); // Return for the function. builder.setInsertionPointAfter(whileOp); - builder.create(loc, whileOp.getResult(2)); + func::ReturnOp::create(builder, loc, whileOp.getResult(2)); } /// Computes (n-2)/n, assuming n has index type. 
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n) { Value i2 = constantIndex(builder, loc, 2); - Value res = builder.create(loc, n, i2); + Value res = arith::SubIOp::create(builder, loc, n, i2); Value i1 = constantIndex(builder, loc, 1); - return builder.create(loc, res, i1); + return arith::ShRUIOp::create(builder, loc, res, i1); } /// Creates a function to heapify the subtree with root `start` within the full @@ -743,16 +743,16 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, // If (n >= 2). Value c2 = constantIndex(builder, loc, 2); Value condN = - builder.create(loc, arith::CmpIPredicate::uge, n, c2); - scf::IfOp ifN = builder.create(loc, condN, /*else=*/false); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, n, c2); + scf::IfOp ifN = scf::IfOp::create(builder, loc, condN, /*else=*/false); builder.setInsertionPointToStart(&ifN.getThenRegion().front()); - Value child = builder.create(loc, start, first); + Value child = arith::SubIOp::create(builder, loc, start, first); // If ((n-2)/2 >= child). Value t = createSubTwoDividedByTwo(builder, loc, n); Value condNc = - builder.create(loc, arith::CmpIPredicate::uge, t, child); - scf::IfOp ifNc = builder.create(loc, condNc, /*else=*/false); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, t, child); + scf::IfOp ifNc = scf::IfOp::create(builder, loc, condNc, /*else=*/false); builder.setInsertionPointToStart(&ifNc.getThenRegion().front()); Value c1 = constantIndex(builder, loc, 1); @@ -768,32 +768,32 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, // if (child+1 < n && data[childIndex] < data[childIndex+1]) // childIndex ++; child ++ // Right child is bigger. 
auto getLargerChild = [&](Value r) -> std::pair { - Value lChild = builder.create(loc, r, c1); - lChild = builder.create(loc, lChild, c1); - Value lChildIdx = builder.create(loc, lChild, first); - Value rChild = builder.create(loc, lChild, c1); - Value cond1 = builder.create(loc, arith::CmpIPredicate::ult, + Value lChild = arith::ShLIOp::create(builder, loc, r, c1); + lChild = arith::AddIOp::create(builder, loc, lChild, c1); + Value lChildIdx = arith::AddIOp::create(builder, loc, lChild, first); + Value rChild = arith::AddIOp::create(builder, loc, lChild, c1); + Value cond1 = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, rChild, n); SmallVector ifTypes(2, r.getType()); scf::IfOp if1 = - builder.create(loc, ifTypes, cond1, /*else=*/true); + scf::IfOp::create(builder, loc, ifTypes, cond1, /*else=*/true); builder.setInsertionPointToStart(&if1.getThenRegion().front()); - Value rChildIdx = builder.create(loc, rChild, first); + Value rChildIdx = arith::AddIOp::create(builder, loc, rChild, first); // Compare data[left] < data[right]. 
compareOperands[0] = lChildIdx; compareOperands[1] = rChildIdx; Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); scf::IfOp if2 = - builder.create(loc, ifTypes, cond2, /*else=*/true); + scf::IfOp::create(builder, loc, ifTypes, cond2, /*else=*/true); builder.setInsertionPointToStart(&if2.getThenRegion().front()); - builder.create(loc, ValueRange{rChild, rChildIdx}); + scf::YieldOp::create(builder, loc, ValueRange{rChild, rChildIdx}); builder.setInsertionPointToStart(&if2.getElseRegion().front()); - builder.create(loc, ValueRange{lChild, lChildIdx}); + scf::YieldOp::create(builder, loc, ValueRange{lChild, lChildIdx}); builder.setInsertionPointAfter(if2); - builder.create(loc, if2.getResults()); + scf::YieldOp::create(builder, loc, if2.getResults()); builder.setInsertionPointToStart(&if1.getElseRegion().front()); - builder.create(loc, ValueRange{lChild, lChildIdx}); + scf::YieldOp::create(builder, loc, ValueRange{lChild, lChildIdx}); builder.setInsertionPointAfter(if1); return std::make_pair(if1.getResult(0), if1.getResult(1)); }; @@ -803,7 +803,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, // While (data[start] < data[childIndex]). SmallVector types(3, child.getType()); - scf::WhileOp whileOp = builder.create( + scf::WhileOp whileOp = scf::WhileOp::create(builder, loc, types, SmallVector{start, child, childIdx}); // The before-region of the WhileOp. @@ -815,7 +815,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, compareOperands[0] = start; compareOperands[1] = childIdx; Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); - builder.create(loc, cond, before->getArguments()); + scf::ConditionOp::create(builder, loc, cond, before->getArguments()); // The after-region of the WhileOp. 
Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); @@ -827,20 +827,20 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, createSwap(builder, loc, swapOperands, xPerm, ny); start = childIdx; Value cond2 = - builder.create(loc, arith::CmpIPredicate::uge, t, child); - scf::IfOp if2 = builder.create( + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, t, child); + scf::IfOp if2 = scf::IfOp::create(builder, loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true); builder.setInsertionPointToStart(&if2.getThenRegion().front()); auto [newChild, newChildIdx] = getLargerChild(child); - builder.create(loc, ValueRange{newChild, newChildIdx}); + scf::YieldOp::create(builder, loc, ValueRange{newChild, newChildIdx}); builder.setInsertionPointToStart(&if2.getElseRegion().front()); - builder.create(loc, ValueRange{child, childIdx}); + scf::YieldOp::create(builder, loc, ValueRange{child, childIdx}); builder.setInsertionPointAfter(if2); - builder.create( + scf::YieldOp::create(builder, loc, ValueRange{start, if2.getResult(0), if2.getResult(1)}); builder.setInsertionPointAfter(ifN); - builder.create(loc); + func::ReturnOp::create(builder, loc); } /// Creates a function to perform heap sort on the values in the range of index @@ -870,45 +870,45 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, ValueRange args = entryBlock->getArguments(); Value lo = args[loIdx]; Value hi = args[hiIdx]; - Value n = builder.create(loc, hi, lo); + Value n = arith::SubIOp::create(builder, loc, hi, lo); // For i = (n-2)/2 downto 0. 
Value c0 = constantIndex(builder, loc, 0); Value c1 = constantIndex(builder, loc, 1); Value s = createSubTwoDividedByTwo(builder, loc, n); - Value up = builder.create(loc, s, c1); - scf::ForOp forI = builder.create(loc, c0, up, c1); + Value up = arith::AddIOp::create(builder, loc, s, c1); + scf::ForOp forI = scf::ForOp::create(builder, loc, c0, up, c1); builder.setInsertionPointToStart(forI.getBody()); - Value i = builder.create(loc, s, forI.getInductionVar()); - Value lopi = builder.create(loc, lo, i); + Value i = arith::SubIOp::create(builder, loc, s, forI.getInductionVar()); + Value lopi = arith::AddIOp::create(builder, loc, lo, i); SmallVector shiftDownOperands = {lo, lopi}; shiftDownOperands.append(args.begin() + xStartIdx, args.end()); shiftDownOperands.push_back(n); FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); - builder.create(loc, shiftDownFunc, TypeRange(), + func::CallOp::create(builder, loc, shiftDownFunc, TypeRange(), shiftDownOperands); builder.setInsertionPointAfter(forI); // For l = n downto 2. 
- up = builder.create(loc, n, c1); - scf::ForOp forL = builder.create(loc, c0, up, c1); + up = arith::SubIOp::create(builder, loc, n, c1); + scf::ForOp forL = scf::ForOp::create(builder, loc, c0, up, c1); builder.setInsertionPointToStart(forL.getBody()); - Value l = builder.create(loc, n, forL.getInductionVar()); - Value loplm1 = builder.create(loc, lo, l); - loplm1 = builder.create(loc, loplm1, c1); + Value l = arith::SubIOp::create(builder, loc, n, forL.getInductionVar()); + Value loplm1 = arith::AddIOp::create(builder, loc, lo, l); + loplm1 = arith::SubIOp::create(builder, loc, loplm1, c1); SmallVector swapOperands{lo, loplm1}; swapOperands.append(args.begin() + xStartIdx, args.end()); createSwap(builder, loc, swapOperands, xPerm, ny); shiftDownOperands[1] = lo; shiftDownOperands[shiftDownOperands.size() - 1] = - builder.create(loc, l, c1); - builder.create(loc, shiftDownFunc, TypeRange(), + arith::SubIOp::create(builder, loc, l, c1); + func::CallOp::create(builder, loc, shiftDownFunc, TypeRange(), shiftDownOperands); builder.setInsertionPointAfter(forL); - builder.create(loc); + func::ReturnOp::create(builder, loc); } /// A helper for generating code to perform quick sort. 
It partitions [lo, hi), @@ -933,35 +933,35 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, args.drop_back(nTrailingP)) .getResult(0); - Value lenLow = builder.create(loc, p, lo); - Value lenHigh = builder.create(loc, hi, p); + Value lenLow = arith::SubIOp::create(builder, loc, p, lo); + Value lenHigh = arith::SubIOp::create(builder, loc, hi, p); // Partition already sorts array with len <= 2 Value c2 = constantIndex(builder, loc, 2); - Value len = builder.create(loc, hi, lo); + Value len = arith::SubIOp::create(builder, loc, hi, lo); Value lenGtTwo = - builder.create(loc, arith::CmpIPredicate::ugt, len, c2); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ugt, len, c2); scf::IfOp ifLenGtTwo = - builder.create(loc, types, lenGtTwo, /*else=*/true); + scf::IfOp::create(builder, loc, types, lenGtTwo, /*else=*/true); builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); // Returns an empty range to mark the entire region is fully sorted. - builder.create(loc, ValueRange{lo, lo}); + scf::YieldOp::create(builder, loc, ValueRange{lo, lo}); // Else len > 2, need recursion. 
builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); - Value cond = builder.create(loc, arith::CmpIPredicate::ule, + Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule, lenLow, lenHigh); Value c0 = constantIndex(builder, loc, 0); - scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true); auto mayRecursion = [&](Value low, Value high, Value len) { Value cond = - builder.create(loc, arith::CmpIPredicate::ne, len, c0); - scf::IfOp ifOp = builder.create(loc, cond, /*else=*/false); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, len, c0); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond, /*else=*/false); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector operands{low, high}; operands.append(args.begin() + xStartIdx, args.end()); - builder.create(loc, func, operands); + func::CallOp::create(builder, loc, func, operands); builder.setInsertionPointAfter(ifOp); }; @@ -969,14 +969,14 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, // the bigger partition to be processed by the enclosed while-loop. 
builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); mayRecursion(lo, p, lenLow); - builder.create(loc, ValueRange{p, hi}); + scf::YieldOp::create(builder, loc, ValueRange{p, hi}); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); mayRecursion(p, hi, lenHigh); - builder.create(loc, ValueRange{lo, p}); + scf::YieldOp::create(builder, loc, ValueRange{lo, p}); builder.setInsertionPointAfter(ifOp); - builder.create(loc, ifOp.getResults()); + scf::YieldOp::create(builder, loc, ifOp.getResults()); builder.setInsertionPointAfter(ifLenGtTwo); return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); @@ -1011,10 +1011,10 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module, Value c1 = constantIndex(builder, loc, 1); Value lo = args[loIdx]; Value hi = args[hiIdx]; - Value lop1 = builder.create(loc, lo, c1); + Value lop1 = arith::AddIOp::create(builder, loc, lo, c1); // Start the outer for-stmt with induction variable i. - scf::ForOp forOpI = builder.create(loc, lop1, hi, c1); + scf::ForOp forOpI = scf::ForOp::create(builder, loc, lop1, hi, c1); builder.setInsertionPointToStart(forOpI.getBody()); Value i = forOpI.getInductionVar(); @@ -1035,24 +1035,24 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module, forEachIJPairInAllBuffers( builder, loc, operands, xPerm, ny, [&](uint64_t unused, Value i, Value unused2, Value buffer) { - d.push_back(builder.create(loc, buffer, i)); + d.push_back(memref::LoadOp::create(builder, loc, buffer, i)); }); // Start the inner for-stmt with induction variable j, for moving data[p..i) // to data[p+1..i+1). 
- Value imp = builder.create(loc, i, p); + Value imp = arith::SubIOp::create(builder, loc, i, p); Value c0 = constantIndex(builder, loc, 0); - scf::ForOp forOpJ = builder.create(loc, c0, imp, c1); + scf::ForOp forOpJ = scf::ForOp::create(builder, loc, c0, imp, c1); builder.setInsertionPointToStart(forOpJ.getBody()); Value j = forOpJ.getInductionVar(); - Value imj = builder.create(loc, i, j); + Value imj = arith::SubIOp::create(builder, loc, i, j); operands[1] = imj; - operands[0] = builder.create(loc, imj, c1); + operands[0] = arith::SubIOp::create(builder, loc, imj, c1); forEachIJPairInAllBuffers( builder, loc, operands, xPerm, ny, [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { - Value t = builder.create(loc, buffer, imjm1); - builder.create(loc, t, buffer, imj); + Value t = memref::LoadOp::create(builder, loc, buffer, imjm1); + memref::StoreOp::create(builder, loc, t, buffer, imj); }); // Store the value at data[i] to data[p]. @@ -1061,11 +1061,11 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module, forEachIJPairInAllBuffers( builder, loc, operands, xPerm, ny, [&](uint64_t k, Value p, Value usused, Value buffer) { - builder.create(loc, d[k], buffer, p); + memref::StoreOp::create(builder, loc, d[k], buffer, p); }); builder.setInsertionPointAfter(forOpI); - builder.create(loc); + func::ReturnOp::create(builder, loc); } /// Creates a function to perform quick sort or a hybrid quick sort on the @@ -1127,7 +1127,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, Value hi = args[hiIdx]; SmallVector types(2, lo.getType()); // Only two types. scf::WhileOp whileOp = - builder.create(loc, types, SmallVector{lo, hi}); + scf::WhileOp::create(builder, loc, types, SmallVector{lo, hi}); // The before-region of the WhileOp. 
Block *before = @@ -1136,10 +1136,10 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, lo = before->getArgument(0); hi = before->getArgument(1); Value loP1 = - builder.create(loc, lo, constantIndex(builder, loc, 1)); + arith::AddIOp::create(builder, loc, lo, constantIndex(builder, loc, 1)); Value needSort = - builder.create(loc, arith::CmpIPredicate::ult, loP1, hi); - builder.create(loc, needSort, before->getArguments()); + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, loP1, hi); + scf::ConditionOp::create(builder, loc, needSort, before->getArguments()); // The after-region of the WhileOp. Block *after = @@ -1151,53 +1151,53 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, args[1] = hi; if (isHybrid) { - Value len = builder.create(loc, hi, lo); + Value len = arith::SubIOp::create(builder, loc, hi, lo); Value lenLimit = constantIndex(builder, loc, 30); - Value lenCond = builder.create( + Value lenCond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule, len, lenLimit); scf::IfOp lenIf = - builder.create(loc, types, lenCond, /*else=*/true); + scf::IfOp::create(builder, loc, types, lenCond, /*else=*/true); // When len <= limit. builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, ValueRange(args).drop_back(nTrailingP), createSortStableFunc); - builder.create(loc, insertionSortFunc, TypeRange(), + func::CallOp::create(builder, loc, insertionSortFunc, TypeRange(), ValueRange(args).drop_back(nTrailingP)); - builder.create(loc, ValueRange{lo, lo}); + scf::YieldOp::create(builder, loc, ValueRange{lo, lo}); // When len > limit. 
builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); Value depthLimit = args.back(); - depthLimit = builder.create(loc, depthLimit, + depthLimit = arith::SubIOp::create(builder, loc, depthLimit, constantI64(builder, loc, 1)); Value depthCond = - builder.create(loc, arith::CmpIPredicate::ule, + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule, depthLimit, constantI64(builder, loc, 0)); scf::IfOp depthIf = - builder.create(loc, types, depthCond, /*else=*/true); + scf::IfOp::create(builder, loc, types, depthCond, /*else=*/true); // When depth exceeds limit. builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); - builder.create(loc, heapSortFunc, TypeRange(), + func::CallOp::create(builder, loc, heapSortFunc, TypeRange(), ValueRange(args).drop_back(nTrailingP)); - builder.create(loc, ValueRange{lo, lo}); + scf::YieldOp::create(builder, loc, ValueRange{lo, lo}); // When depth doesn't exceed limit. builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); args.back() = depthLimit; std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); - builder.create(loc, ValueRange{lo, hi}); + scf::YieldOp::create(builder, loc, ValueRange{lo, hi}); builder.setInsertionPointAfter(depthIf); lo = depthIf.getResult(0); hi = depthIf.getResult(1); - builder.create(loc, ValueRange{lo, hi}); + scf::YieldOp::create(builder, loc, ValueRange{lo, hi}); builder.setInsertionPointAfter(lenIf); lo = lenIf.getResult(0); @@ -1208,11 +1208,11 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, } // New [lo, hi) for the next while-loop iteration. - builder.create(loc, ValueRange{lo, hi}); + scf::YieldOp::create(builder, loc, ValueRange{lo, hi}); // After the while-loop. 
builder.setInsertionPointAfter(whileOp); - builder.create(loc); + func::ReturnOp::create(builder, loc); } /// Implements the rewriting for operator sort and sort_coo. @@ -1228,7 +1228,7 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, if (!mtp.isDynamicDim(0)) { auto newMtp = MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); - v = rewriter.create(loc, newMtp, v); + v = memref::CastOp::create(rewriter, loc, newMtp, v); } operands.push_back(v); } @@ -1248,12 +1248,12 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, // As a heuristics, set depthLimit = 2 * log2(n). Value lo = operands[loIdx]; Value hi = operands[hiIdx]; - Value len = rewriter.create( + Value len = arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), - rewriter.create(loc, hi, lo)); - Value depthLimit = rewriter.create( + arith::SubIOp::create(rewriter, loc, hi, lo)); + Value depthLimit = arith::SubIOp::create(rewriter, loc, constantI64(rewriter, loc, 64), - rewriter.create(loc, len)); + math::CountLeadingZerosOp::create(rewriter, loc, len)); operands.push_back(depthLimit); break; } @@ -1307,33 +1307,33 @@ struct PushBackRewriter : OpRewritePattern { Location loc = op->getLoc(); Value c0 = constantIndex(rewriter, loc, 0); Value buffer = op.getInBuffer(); - Value capacity = rewriter.create(loc, buffer, c0); + Value capacity = memref::DimOp::create(rewriter, loc, buffer, c0); Value size = op.getCurSize(); Value value = op.getValue(); Value n = op.getN() ? 
op.getN() : constantIndex(rewriter, loc, 1); - Value newSize = rewriter.create(loc, size, n); + Value newSize = arith::AddIOp::create(rewriter, loc, size, n); auto nValue = dyn_cast_or_null(n.getDefiningOp()); bool nIsOne = (nValue && nValue.value() == 1); if (!op.getInbounds()) { - Value cond = rewriter.create( + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ugt, newSize, capacity); Value c2 = constantIndex(rewriter, loc, 2); auto bufferType = MemRefType::get({ShapedType::kDynamic}, value.getType()); - scf::IfOp ifOp = rewriter.create(loc, bufferType, cond, + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, bufferType, cond, /*else=*/true); // True branch. rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); if (nIsOne) { - capacity = rewriter.create(loc, capacity, c2); + capacity = arith::MulIOp::create(rewriter, loc, capacity, c2); } else { // Use a do-while loop to calculate the new capacity as follows: // do { new_capacity *= 2 } while (size > new_capacity) scf::WhileOp whileOp = - rewriter.create(loc, capacity.getType(), capacity); + scf::WhileOp::create(rewriter, loc, capacity.getType(), capacity); // The before-region of the WhileOp. Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, @@ -1341,36 +1341,36 @@ struct PushBackRewriter : OpRewritePattern { rewriter.setInsertionPointToEnd(before); capacity = - rewriter.create(loc, before->getArgument(0), c2); - cond = rewriter.create(loc, arith::CmpIPredicate::ugt, + arith::MulIOp::create(rewriter, loc, before->getArgument(0), c2); + cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ugt, newSize, capacity); - rewriter.create(loc, cond, ValueRange{capacity}); + scf::ConditionOp::create(rewriter, loc, cond, ValueRange{capacity}); // The after-region of the WhileOp. 
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, {capacity.getType()}, {loc}); rewriter.setInsertionPointToEnd(after); - rewriter.create(loc, after->getArguments()); + scf::YieldOp::create(rewriter, loc, after->getArguments()); rewriter.setInsertionPointAfter(whileOp); capacity = whileOp.getResult(0); } Value newBuffer = - rewriter.create(loc, bufferType, buffer, capacity); + memref::ReallocOp::create(rewriter, loc, bufferType, buffer, capacity); if (enableBufferInitialization) { - Value fillSize = rewriter.create(loc, capacity, newSize); + Value fillSize = arith::SubIOp::create(rewriter, loc, capacity, newSize); Value fillValue = constantZero(rewriter, loc, value.getType()); - Value subBuffer = rewriter.create( + Value subBuffer = memref::SubViewOp::create(rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize}, /*size=*/ValueRange{fillSize}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); - rewriter.create(loc, fillValue, subBuffer); + linalg::FillOp::create(rewriter, loc, fillValue, subBuffer); } - rewriter.create(loc, newBuffer); + scf::YieldOp::create(rewriter, loc, newBuffer); // False branch. rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - rewriter.create(loc, buffer); + scf::YieldOp::create(rewriter, loc, buffer); // Prepare for adding the value to the end of the buffer. rewriter.setInsertionPointAfter(ifOp); @@ -1379,12 +1379,12 @@ struct PushBackRewriter : OpRewritePattern { // Add the value to the end of the buffer. if (nIsOne) { - rewriter.create(loc, value, buffer, size); + memref::StoreOp::create(rewriter, loc, value, buffer, size); } else { - Value subBuffer = rewriter.create( + Value subBuffer = memref::SubViewOp::create(rewriter, loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); - rewriter.create(loc, value, subBuffer); + linalg::FillOp::create(rewriter, loc, value, subBuffer); } // Update the buffer size. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index e89b34d457ff8..001754b5531be 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -59,7 +59,7 @@ static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) { return op; // existing markAsGPUContainer(topModule); builder.setInsertionPointToStart(topModule.getBody()); - return builder.create(topModule->getLoc(), + return gpu::GPUModuleOp::create(builder, topModule->getLoc(), "sparse_kernels"); } @@ -81,7 +81,7 @@ static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule, argsTp.push_back(arg.getType()); FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {}); auto gpuFunc = - builder.create(gpuModule->getLoc(), kernelName, type); + gpu::GPUFuncOp::create(builder, gpuModule->getLoc(), kernelName, type); gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr()); return gpuFunc; @@ -115,28 +115,28 @@ static Value genHostRegisterMemref(OpBuilder &builder, Location loc, MemRefType memTp = cast(mem.getType()); UnrankedMemRefType resTp = UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); - Value cast = builder.create(loc, resTp, mem); - builder.create(loc, cast); + Value cast = memref::CastOp::create(builder, loc, resTp, mem); + gpu::HostRegisterOp::create(builder, loc, cast); return cast; } /// Unmaps the provided buffer, expecting the casted buffer. static void genHostUnregisterMemref(OpBuilder &builder, Location loc, Value cast) { - builder.create(loc, cast); + gpu::HostUnregisterOp::create(builder, loc, cast); } /// Generates first wait in an asynchronous chain. 
static Value genFirstWait(OpBuilder &builder, Location loc) { Type tokenType = builder.getType(); - return builder.create(loc, tokenType, ValueRange()) + return gpu::WaitOp::create(builder, loc, tokenType, ValueRange()) .getAsyncToken(); } /// Generates last, blocking wait in an asynchronous chain. static void genBlockingWait(OpBuilder &builder, Location loc, ValueRange operands) { - builder.create(loc, Type(), operands); + gpu::WaitOp::create(builder, loc, Type(), operands); } /// Allocates memory on the device. @@ -156,7 +156,7 @@ static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, dynamicSizes.push_back(dimOp); } } - return builder.create(loc, TypeRange({memTp, token.getType()}), + return gpu::AllocOp::create(builder, loc, TypeRange({memTp, token.getType()}), token, dynamicSizes, ValueRange()); } @@ -164,14 +164,14 @@ static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, Value size) { const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); - return builder.create(loc, memTp, size).getResult(); + return memref::AllocOp::create(builder, loc, memTp, size).getResult(); } // Allocates a typed buffer on the device with given size. static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, Value size, Value token) { const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); - return builder.create(loc, TypeRange({memTp, token.getType()}), + return gpu::AllocOp::create(builder, loc, TypeRange({memTp, token.getType()}), token, size, ValueRange()); } @@ -184,14 +184,14 @@ static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, /// Deallocates memory from the device. 
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, Value token) { - return builder.create(loc, token.getType(), token, mem) + return gpu::DeallocOp::create(builder, loc, token.getType(), token, mem) .getAsyncToken(); } /// Copies memory between host and device (direction is implicit). static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, Value src, Value token) { - return builder.create(loc, token.getType(), token, dst, src) + return gpu::MemcpyOp::create(builder, loc, token.getType(), token, dst, src) .getAsyncToken(); } @@ -212,7 +212,7 @@ static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, auto tensorType = llvm::cast(tensor.getType()); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - return rewriter.create(loc, memrefType, tensor); + return bufferization::ToBufferOp::create(rewriter, loc, memrefType, tensor); } /// Prepares the outlined arguments, passing scalars and buffers in. Here we @@ -293,13 +293,13 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, // so that: // row = blockIdx.x * blockDim.x + threadIdx.x // inc = blockDim.x * gridDim.x - Value bid = rewriter.create(loc, gpu::Dimension::x); - Value bsz = rewriter.create(loc, gpu::Dimension::x); - Value tid = rewriter.create(loc, gpu::Dimension::x); - Value gsz = rewriter.create(loc, gpu::Dimension::x); - Value mul = rewriter.create(loc, bid, bsz); - Value row = rewriter.create(loc, mul, tid); - Value inc = rewriter.create(loc, bsz, gsz); + Value bid = gpu::BlockIdOp::create(rewriter, loc, gpu::Dimension::x); + Value bsz = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x); + Value tid = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x); + Value gsz = gpu::GridDimOp::create(rewriter, loc, gpu::Dimension::x); + Value mul = arith::MulIOp::create(rewriter, loc, bid, bsz); + Value row = arith::AddIOp::create(rewriter, loc, mul, tid); + Value inc = 
arith::MulIOp::create(rewriter, loc, bsz, gsz); // Construct the iteration over the computational space that // accounts for the fact that the total number of threads and @@ -308,7 +308,7 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, // // } Value upper = irMap.lookup(forallOp.getUpperBound()[0]); - scf::ForOp forOp = rewriter.create(loc, row, upper, inc); + scf::ForOp forOp = scf::ForOp::create(rewriter, loc, row, upper, inc); // The scf.for builder creates an empty block. scf.for does not allow multiple // blocks in its region, so delete the block before `cloneRegionBefore` adds // an additional block. @@ -321,7 +321,7 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, // Done. rewriter.setInsertionPointAfter(forOp); - rewriter.create(gpuFunc->getLoc()); + gpu::ReturnOp::create(rewriter, gpuFunc->getLoc()); } //===----------------------------------------------------------------------===// @@ -496,11 +496,11 @@ static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, if (format == CuSparseFormat::kCOO) { // Library uses SoA COO, direct IR uses AoS COO. if (enableRT) - return builder.create(loc, a, 0); - return builder.create(loc, a); + return ToCoordinatesOp::create(builder, loc, a, 0); + return ToCoordinatesBufferOp::create(builder, loc, a); } // Formats CSR/CSC and BSR use positions at 1. - return builder.create(loc, a, 1); + return ToPositionsOp::create(builder, loc, a, 1); } /// Generates the second coordinates of a sparse matrix. @@ -510,7 +510,7 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, if (isCOO && !enableRT) return Value(); // nothing needed // Formats CSR/CSC and BSR use coordinates at 1. - return builder.create(loc, a, 1); + return ToCoordinatesOp::create(builder, loc, a, 1); } /// Generates the sparse matrix handle. @@ -523,12 +523,12 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, // Library uses SoA COO, direct IR uses AoS COO. 
if (enableRT) { assert(colA); - return builder.create(loc, handleTp, tokenTp, token, + return gpu::CreateCooOp::create(builder, loc, handleTp, tokenTp, token, sz1, sz2, nseA, rowA, colA, valA); } #ifdef CUSPARSE_COO_AOS assert(!colA); - return builder.create(loc, handleTp, tokenTp, token, + return gpu::CreateCooAoSOp::create(builder, loc, handleTp, tokenTp, token, sz1, sz2, nseA, rowA, valA); #else llvm_unreachable("gpu::CreateCooAoSOp is deprecated"); @@ -536,10 +536,10 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, } assert(colA); if (format == CuSparseFormat::kCSR) - return builder.create(loc, handleTp, tokenTp, token, sz1, + return gpu::CreateCsrOp::create(builder, loc, handleTp, tokenTp, token, sz1, sz2, nseA, rowA, colA, valA); if (format == CuSparseFormat::kCSC) - return builder.create(loc, handleTp, tokenTp, token, sz1, + return gpu::CreateCscOp::create(builder, loc, handleTp, tokenTp, token, sz1, sz2, nseA, rowA, colA, valA); // BSR requires a bit more work since we need to pass in the block size // and all others sizes in terms of blocks (#block-rows, #block-cols, @@ -549,11 +549,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, assert(dims.size() == 2 && dims[0] == dims[1]); uint64_t b = dims[0]; Value bSz = constantIndex(builder, loc, b); - Value bRows = builder.create(loc, sz1, bSz); - Value bCols = builder.create(loc, sz2, bSz); - Value bNum = builder.create( + Value bRows = arith::DivUIOp::create(builder, loc, sz1, bSz); + Value bCols = arith::DivUIOp::create(builder, loc, sz2, bSz); + Value bNum = arith::DivUIOp::create(builder, loc, nseA, constantIndex(builder, loc, b * b)); - return builder.create(loc, handleTp, tokenTp, token, bRows, + return gpu::CreateBsrOp::create(builder, loc, handleTp, tokenTp, token, bRows, bCols, bNum, bSz, bSz, rowA, colA, valA); } @@ -579,12 +579,12 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter, // a : memR/memC/memV -> rowA,colA,valA // x : memX -> vecX // y : memY -> 
vecY - Value nseA = rewriter.create(loc, a); + Value nseA = NumberOfEntriesOp::create(rewriter, loc, a); Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0); Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1); Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty - Value memV = rewriter.create(loc, a); + Value memV = ToValuesOp::create(rewriter, loc, a); Value rowA = genAllocCopy(rewriter, loc, memR, tokens); Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); Value valA = genAllocCopy(rewriter, loc, memV, tokens); @@ -606,18 +606,18 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter, nseA, rowA, colA, valA, format, enableRT); Value spMatA = spGenA->getResult(0); token = spGenA->getResult(1); - auto dvecX = rewriter.create( + auto dvecX = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp, token, vecX, szX); Value dnX = dvecX.getResult(0); token = dvecX.getAsyncToken(); - auto dvecY = rewriter.create( + auto dvecY = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp, token, vecY, szY); Value dnY = dvecY.getResult(0); token = dvecY.getAsyncToken(); auto dnYType = llvm::cast(y.getType()).getElementType(); // Precompute buffersize for SpMV. - auto bufferComp = rewriter.create( + auto bufferComp = gpu::SpMVBufferSizeOp::create(rewriter, loc, indexTp, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType); Value bufferSz = bufferComp.getResult(0); @@ -627,16 +627,16 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter, token = buf.getAsyncToken(); // Perform the SpMV. - auto spmvComp = rewriter.create( + auto spmvComp = gpu::SpMVOp::create(rewriter, loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer); token = spmvComp.getAsyncToken(); // Copy data back to host and free all the resoures. 
- token = rewriter.create(loc, tokenTp, token, spMatA) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnX) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnX) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnY) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnY) .getAsyncToken(); token = genDeallocMemRef(rewriter, loc, rowA, token); if (colA) @@ -676,13 +676,13 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter, // a : memR/memC/memV -> rowA,colA,valA // b : bufB -> matB // c : bufC -> matC - Value nseA = rewriter.create(loc, a); + Value nseA = NumberOfEntriesOp::create(rewriter, loc, a); Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty - Value memV = rewriter.create(loc, a); + Value memV = ToValuesOp::create(rewriter, loc, a); Value rowA = genAllocCopy(rewriter, loc, memR, tokens); Value colA = memC ? 
genAllocCopy(rewriter, loc, memC, tokens) : Value(); Value valA = genAllocCopy(rewriter, loc, memV, tokens); @@ -704,12 +704,12 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter, nseA, rowA, colA, valA, format, enableRT); Value spMatA = spGenA->getResult(0); token = spGenA->getResult(1); - auto dmatB = rewriter.create( + auto dmatB = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp, token, matB, SmallVector{szk, szn}); Value dnB = dmatB.getResult(0); token = dmatB.getAsyncToken(); - auto dmatC = rewriter.create( + auto dmatC = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp, token, matC, SmallVector{szm, szn}); Value dnC = dmatC.getResult(0); @@ -717,7 +717,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter, auto dmatCType = llvm::cast(c.getType()).getElementType(); // Precompute buffersize for SpMM. - auto bufferComp = rewriter.create( + auto bufferComp = gpu::SpMMBufferSizeOp::create(rewriter, loc, indexTp, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dmatCType); Value bufferSz = bufferComp.getResult(0); @@ -728,16 +728,16 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter, auto dnCType = llvm::cast(c.getType()).getElementType(); // Perform the SpMM. - auto spmmComp = rewriter.create( + auto spmmComp = gpu::SpMMOp::create(rewriter, loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer); token = spmmComp.getAsyncToken(); // Copy data back to host and free all the resoures. 
- token = rewriter.create(loc, tokenTp, token, spMatA) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnB) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnC) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC) .getAsyncToken(); token = genDeallocMemRef(rewriter, loc, rowA, token); if (colA) @@ -778,17 +778,17 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, // b : bmemR/bmemC/bmemV -> rowB,colB,valB // c : materializes auto dnCType = cTp.getElementType(); - Value nseA = rewriter.create(loc, a); - Value nseB = rewriter.create(loc, b); + Value nseA = NumberOfEntriesOp::create(rewriter, loc, a); + Value nseB = NumberOfEntriesOp::create(rewriter, loc, b); Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty - Value amemV = rewriter.create(loc, a); + Value amemV = ToValuesOp::create(rewriter, loc, a); Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT); Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty - Value bmemV = rewriter.create(loc, b); + Value bmemV = ToValuesOp::create(rewriter, loc, b); Value rowA = genAllocCopy(rewriter, loc, amemR, tokens); Value colA = genAllocCopy(rewriter, loc, amemC, tokens); Value valA = genAllocCopy(rewriter, loc, amemV, tokens); @@ -818,7 +818,7 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, // Sparse matrix C materializes (also assumes beta == 0). 
Value zero = constantIndex(rewriter, loc, 0); Value one = constantIndex(rewriter, loc, 1); - Value mplus1 = rewriter.create(loc, szm, one); + Value mplus1 = arith::AddIOp::create(rewriter, loc, szm, one); auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token); Value rowC = e1.getResult(0); token = e1.getAsyncToken(); @@ -836,10 +836,10 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, // Precompute buffersizes for SpGEMM. Operation *descOp = - rewriter.create(loc, descTp, tokenTp, token); + gpu::SpGEMMCreateDescrOp::create(rewriter, loc, descTp, tokenTp, token); Value desc = descOp->getResult(0); token = descOp->getResult(1); - Operation *work1 = rewriter.create( + Operation *work1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(rewriter, loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); @@ -848,7 +848,7 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); Value buffer1 = buf1.getResult(0); token = buf1.getAsyncToken(); - Operation *work2 = rewriter.create( + Operation *work2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(rewriter, loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, bufferSz1, buffer1, @@ -856,7 +856,7 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, token = work2->getResult(1); // Compute step. 
- Operation *compute1 = rewriter.create( + Operation *compute1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(rewriter, loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); @@ -865,14 +865,14 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); Value buffer2 = buf2.getResult(0); token = buf2.getAsyncToken(); - Operation *compute2 = rewriter.create( + Operation *compute2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(rewriter, loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); token = compute2->getResult(1); // Get sizes. - Operation *sizes = rewriter.create( + Operation *sizes = gpu::SpMatGetSizeOp::create(rewriter, loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC); Value nnz = sizes->getResult(2); token = sizes->getResult(3); @@ -884,10 +884,10 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, token = a3.getAsyncToken(); // Update C with new pointers and copy final product back into C. - Operation *update = rewriter.create( + Operation *update = gpu::SetCsrPointersOp::create(rewriter, loc, tokenTp, token, spMatC, rowC, colC, valC); token = update->getResult(0); - Operation *copy = rewriter.create( + Operation *copy = gpu::SpGEMMCopyOp::create(rewriter, loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType); token = copy->getResult(0); @@ -898,13 +898,13 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, Value valH = genHostBuffer(rewriter, loc, dnCType, nnz); // Copy data back to host and free all the resoures. 
- token = rewriter.create(loc, tokenTp, token, desc) + token = gpu::SpGEMMDestroyDescrOp::create(rewriter, loc, tokenTp, token, desc) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, spMatA) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, spMatB) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatB) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, spMatC) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC) .getAsyncToken(); token = genCopyMemRef(rewriter, loc, rowH, rowC, token); token = genCopyMemRef(rewriter, loc, colH, colC, token); @@ -925,11 +925,11 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, tokens.clear(); // Done. - Value vt = rewriter.create( + Value vt = bufferization::ToTensorOp::create(rewriter, loc, memref::getTensorTypeFromMemRefType(valH.getType()), valH); - Value rt = rewriter.create( + Value rt = bufferization::ToTensorOp::create(rewriter, loc, memref::getTensorTypeFromMemRefType(rowH.getType()), rowH); - Value ct = rewriter.create( + Value ct = bufferization::ToTensorOp::create(rewriter, loc, memref::getTensorTypeFromMemRefType(colH.getType()), colH); rewriter.replaceOpWithNewOp(op, c.getType(), ValueRange{rt, ct}, vt); @@ -980,17 +980,17 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, Type spMatHandleTp = rewriter.getType(); Type tokenTp = rewriter.getType(); Value token = genFirstWait(rewriter, loc); - Operation *spGenA = rewriter.create( + Operation *spGenA = gpu::Create2To4SpMatOp::create(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA); Value spMatA = spGenA->getResult(0); token = spGenA->getResult(1); - auto dmatB = rewriter.create( + auto dmatB = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp, token, matB, SmallVector{szk, szn}); Value dnB = 
dmatB.getResult(0); token = dmatB.getAsyncToken(); - auto dmatC = rewriter.create( + auto dmatC = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp, token, matC, SmallVector{szm, szn}); Value dnC = dmatC.getResult(0); @@ -1000,7 +1000,7 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, // Precompute buffersize for SpMM. SmallVector bufferTypes_{indexTp, indexTp, indexTp}; TypeRange bufferTypes(bufferTypes_); - auto bufferComp = rewriter.create( + auto bufferComp = gpu::SpMMBufferSizeOp::create(rewriter, loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, /*computeType=*/dmatCType); @@ -1022,17 +1022,17 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, // Perform the SpMM. auto dnCType = llvm::cast(matC.getType()).getElementType(); - auto spmmComp = rewriter.create( + auto spmmComp = gpu::SpMMOp::create(rewriter, loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, SmallVector{buffer1, buffer2, buffer3}); token = spmmComp.getAsyncToken(); // Copy data back to host and free all the resources. 
- token = rewriter.create(loc, tokenTp, token, spMatA) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnB) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnC) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC) .getAsyncToken(); token = genDeallocMemRef(rewriter, loc, buffer1, token); token = genDeallocMemRef(rewriter, loc, buffer2, token); @@ -1073,7 +1073,7 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, // a : bufA -> matA // b : bufB -> matB // c : memR/memC/memV -> rowC,colC,valC - Value nseC = rewriter.create(loc, c); + Value nseC = NumberOfEntriesOp::create(rewriter, loc, c); Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); @@ -1083,7 +1083,7 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, Value matB = genAllocCopy(rewriter, loc, bufB, tokens); Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT); Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty - Value memV = rewriter.create(loc, c); + Value memV = ToValuesOp::create(rewriter, loc, c); Value rowC = genAllocCopy(rewriter, loc, memR, tokens); Value colC = memC ? 
genAllocCopy(rewriter, loc, memC, tokens) : Value(); Value valC = genAllocCopy(rewriter, loc, memV, tokens); @@ -1096,11 +1096,11 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, Type spMatHandleTp = rewriter.getType(); Type tokenTp = rewriter.getType(); Value token = genFirstWait(rewriter, loc); - auto dmatA = rewriter.create( + auto dmatA = gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp, token, matA, SmallVector{szm, szk}); Value dnA = dmatA.getResult(0); token = dmatA.getAsyncToken(); - auto dmatB = rewriter.create( + auto dmatB = gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp, token, matB, SmallVector{szk, szn}); Value dnB = dmatB.getResult(0); token = dmatB.getAsyncToken(); @@ -1112,7 +1112,7 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, auto dnCType = llvm::cast(c.getType()).getElementType(); // Precompute buffersize for SDDMM. - auto bufferComp = rewriter.create( + auto bufferComp = gpu::SDDMMBufferSizeOp::create(rewriter, loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType); Value bufferSz = bufferComp.getResult(0); token = bufferComp.getAsyncToken(); @@ -1121,16 +1121,16 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, token = buf.getAsyncToken(); // Perform the SDDMM. - auto sddmmComp = rewriter.create(loc, tokenTp, token, dnA, dnB, + auto sddmmComp = gpu::SDDMMOp::create(rewriter, loc, tokenTp, token, dnA, dnB, spMatC, dnCType, buffer); token = sddmmComp.getAsyncToken(); // Copy data back to host and free all the resoures. 
- token = rewriter.create(loc, tokenTp, token, dnA) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnA) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, dnB) + token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB) .getAsyncToken(); - token = rewriter.create(loc, tokenTp, token, spMatC) + token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC) .getAsyncToken(); token = genDeallocMemRef(rewriter, loc, buffer, token); token = genDeallocMemRef(rewriter, loc, matA, token); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index 2f68008e68b5f..0986a0f2ea0f0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -67,11 +67,11 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber()); for (unsigned i : caseBits.bits()) { SparseIterator *it = iters[i].get(); - Value pred = rewriter.create(loc, arith::CmpIPredicate::eq, + Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd); - casePred = rewriter.create(loc, casePred, pred); + casePred = arith::AndIOp::create(rewriter, loc, casePred, pred); } - scf::IfOp ifOp = rewriter.create( + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); @@ -103,7 +103,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, ValueRange yields = spY.getResults(); rewriter.eraseOp(spY); rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front()); - rewriter.create(loc, yields); + scf::YieldOp::create(rewriter, loc, yields); // Generates remaining case recursively. 
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); @@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, newBlocks.drop_front(), oldBlocks.drop_front(), userReduc); if (!res.empty()) - rewriter.create(loc, res); + scf::YieldOp::create(rewriter, loc, res); rewriter.setInsertionPointAfter(ifOp); return ifOp.getResults(); @@ -127,7 +127,7 @@ static ValueRange genLoopWithIterator( if (it->iteratableByFor()) { auto [lo, hi] = it->genForCond(rewriter, loc); Value step = constantIndex(rewriter, loc, 1); - scf::ForOp forOp = rewriter.create( + scf::ForOp forOp = scf::ForOp::create(rewriter, loc, lo, hi, step, reduc, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { // Empty builder function to ensure that no terminator is created. @@ -140,7 +140,7 @@ static ValueRange genLoopWithIterator( it, forOp.getRegionIterArgs()); rewriter.setInsertionPointToEnd(forOp.getBody()); - rewriter.create(loc, ret); + scf::YieldOp::create(rewriter, loc, ret); } return forOp.getResults(); } @@ -149,7 +149,7 @@ static ValueRange genLoopWithIterator( llvm::append_range(ivs, it->getCursor()); TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = rewriter.create(loc, types, ivs); + auto whileOp = scf::WhileOp::create(rewriter, loc, types, ivs); { OpBuilder::InsertionGuard guard(rewriter); // Generates loop conditions. @@ -158,7 +158,7 @@ static ValueRange genLoopWithIterator( rewriter.setInsertionPointToStart(before); ValueRange bArgs = before->getArguments(); auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); - rewriter.create(loc, whileCond, before->getArguments()); + scf::ConditionOp::create(rewriter, loc, whileCond, before->getArguments()); // Delegates loop body generation. 
Region &dstRegion = whileOp.getAfter(); @@ -175,7 +175,7 @@ static ValueRange genLoopWithIterator( SmallVector yields; llvm::append_range(yields, ret); llvm::append_range(yields, it->forward(rewriter, loc)); - rewriter.create(loc, yields); + scf::YieldOp::create(rewriter, loc, yields); } return whileOp.getResults().drop_front(it->getCursor().size()); } @@ -212,7 +212,7 @@ class ExtractValOpConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value pos = adaptor.getIterator().back(); - Value valBuf = rewriter.create( + Value valBuf = ToValuesOp::create(rewriter, loc, llvm::getSingleElement(adaptor.getTensor())); rewriter.replaceOpWithNewOp(op, valBuf, pos); return success(); @@ -385,12 +385,12 @@ class SparseCoIterateOpConverter : public OpConversionPattern { SmallVector nextIterYields(res); // 2nd. foward the loop. for (SparseIterator *it : validIters) { - Value cmp = rewriter.create( + Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd); it->forwardIf(rewriter, loc, cmp); llvm::append_range(nextIterYields, it->getCursor()); } - rewriter.create(loc, nextIterYields); + scf::YieldOp::create(rewriter, loc, nextIterYields); // Exit the loop, relink the iterator SSA value. 
rewriter.setInsertionPointAfter(loop); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index 0473e058646b9..fd1bbb300b6a6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -44,7 +44,7 @@ struct DemapInsRewriter : public OpRewritePattern { SmallVector deMappedIns(op->getOperands()); for (Value &in : deMappedIns) { if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) { - in = rewriter.create(loc, stt->getDemappedType(), in); + in = ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in); changed = true; } } @@ -338,14 +338,14 @@ translateMap(linalg::GenericOp op, PatternRewriter &rewriter) { // Generates a "de"mapping reinterpretation of the map. static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) { - return builder.create(val.getLoc(), enc.withoutDimToLvl(), + return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(), val); } // Generates a "re"mapping reinterpretation of the map. 
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) { - return builder.create(val.getLoc(), enc, val); + return ReinterpretMapOp::create(builder, val.getLoc(), enc, val); } static SmallVector remapValueRange(OpBuilder &rewriter, TypeRange types, @@ -354,7 +354,7 @@ static SmallVector remapValueRange(OpBuilder &rewriter, TypeRange types, assert(outs.size() == types.size()); for (auto [r, t] : llvm::zip(ret, types)) if (r.getType() != t) - r = rewriter.create(r.getLoc(), t, r); + r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r); return ret; } @@ -567,7 +567,7 @@ struct GenericOpScheduler : public OpRewritePattern { // Inserting the transpose rewriter.setInsertionPoint(linalgOp); RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType(); - Value dst = rewriter.create(tval.getLoc(), dstTp, tval); + Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval); rewriter.modifyOpInPlace(linalgOp, [&]() { linalgOp->setOperand(t->getOperandNumber(), dst); }); @@ -575,7 +575,7 @@ struct GenericOpScheduler : public OpRewritePattern { // Release the transposed form afterwards. // TODO: CSE when used in more than one following op? 
rewriter.setInsertionPointAfter(linalgOp); - rewriter.create(dst.getLoc(), dst); + bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst); return success(); } @@ -605,7 +605,7 @@ struct TensorAllocDemapper : public OpRewritePattern { ValueRange dynSz = op.getDynamicSizes(); for (int64_t dimSz : stt.getDimShape()) { if (ShapedType::isDynamic(dimSz)) { - Value maxCrd = rewriter.create( + Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(), constantIndex(rewriter, loc, 1)); maxDimCrds.push_back(maxCrd); dynSz = dynSz.drop_front(); @@ -620,7 +620,7 @@ struct TensorAllocDemapper : public OpRewritePattern { SmallVector dynLvlSzs; for (unsigned i = 0, e = lvlShape.size(); i < e; i++) { if (ShapedType::isDynamic(lvlShape[i])) { - Value sz = rewriter.create( + Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1)); dynLvlSzs.push_back(sz); } @@ -651,7 +651,7 @@ struct TensorInsertDemapper auto stt = getSparseTensorType(op.getResult()); ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(), CrdTransDirectionKind::dim2lvl); - auto insertOp = rewriter.create( + auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(), adaptor.getDest(), lvlCrd); Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult()); @@ -766,7 +766,7 @@ struct ForeachOpDemapper stt && !stt->isIdentity()) { Value y = genDemap(rewriter, stt->getEncoding(), yield.getSingleResult()); - rewriter.create(loc, y); + YieldOp::create(rewriter, loc, y); rewriter.eraseOp(yield); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp index f85c4761a8d52..7b662c2fb0db0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp @@ -126,7 +126,7 @@ void collapseSparseSpace(MutableArrayRef toCollapse) { OpBuilder builder(root); // 
Construct the collapsed iteration space. - auto collapsedSpace = builder.create( + auto collapsedSpace = ExtractIterSpaceOp::create(builder, loc, root.getTensor(), root.getParentIter(), root.getLoLvl(), leaf.getHiLvl()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp index 01028f71c20bb..a1b9035f3b6e6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -69,13 +69,13 @@ class SpecifierStructBuilder : public StructBuilder { Value extractField(OpBuilder &builder, Location loc, ArrayRef indices) const { return genCast(builder, loc, - builder.create(loc, value, indices), + LLVM::ExtractValueOp::create(builder, loc, value, indices), builder.getIndexType()); } void insertField(OpBuilder &builder, Location loc, ArrayRef indices, Value v) { - value = builder.create( + value = LLVM::InsertValueOp::create(builder, loc, value, genCast(builder, loc, v, builder.getIntegerType(64)), indices); } @@ -110,7 +110,7 @@ class SpecifierStructBuilder : public StructBuilder { Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, Type structType, Value source) { - Value metaData = builder.create(loc, structType); + Value metaData = LLVM::PoisonOp::create(builder, loc, structType); SpecifierStructBuilder md(metaData); if (!source) { auto memSizeArrayType = @@ -204,14 +204,14 @@ void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, /// Builds IR extracting the memory size array from the descriptor. Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, Location loc) const { - return builder.create(loc, value, + return LLVM::ExtractValueOp::create(builder, loc, value, kMemSizePosInSpecifier); } /// Builds IR inserting the memory size array into the descriptor. 
void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, Value array) { - value = builder.create(loc, value, array, + value = LLVM::InsertValueOp::create(builder, loc, value, array, kMemSizePosInSpecifier); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 001ea62b07360..0450e1d6e8775 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -50,7 +50,7 @@ static SmallVector flattenValues(ArrayRef values) { /// Generates a load with proper `index` typing. static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) { idx = genCast(builder, loc, idx, builder.getIndexType()); - return builder.create(loc, mem, idx); + return memref::LoadOp::create(builder, loc, mem, idx); } /// Generates a store with proper `index` typing and proper value. @@ -59,7 +59,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, idx = genCast(builder, loc, idx, builder.getIndexType()); val = genCast(builder, loc, val, cast(mem.getType()).getElementType()); - builder.create(loc, val, mem, idx); + memref::StoreOp::create(builder, loc, val, mem, idx); } /// Creates a straightforward counting for-loop. 
@@ -70,7 +70,7 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, if (!lower) lower = constantZero(builder, loc, indexType); Value one = constantOne(builder, loc, indexType); - scf::ForOp forOp = builder.create(loc, lower, upper, one, fields); + scf::ForOp forOp = scf::ForOp::create(builder, loc, lower, upper, one, fields); for (unsigned i = 0, e = fields.size(); i < e; i++) fields[i] = forOp.getRegionIterArg(i); builder.setInsertionPointToStart(forOp.getBody()); @@ -86,7 +86,7 @@ static void createPushback(OpBuilder &builder, Location loc, Value field = desc.getMemRefField(kind, lvl); StorageSpecifierKind specFieldKind = toSpecifierKind(kind); - auto pushBackOp = builder.create( + auto pushBackOp = PushBackOp::create(builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field, genCast(builder, loc, value, etp), repeat); @@ -112,7 +112,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc, Value posZero = constantZero(builder, loc, stt.getPosType()); if (isLooseCompressedLT(lt)) { Value two = constantIndex(builder, loc, 2); - linear = builder.create(loc, linear, two); + linear = arith::MulIOp::create(builder, loc, linear, two); } createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl, /*value=*/posZero, /*repeat=*/linear); @@ -125,7 +125,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc, // otherwise the values array for the from-here "all-dense" case. assert(isDenseLT(lt)); Value size = desc.getLvlSize(builder, loc, lvl); - linear = builder.create(loc, linear, size); + linear = arith::MulIOp::create(builder, loc, linear, size); } // Reached values array so prepare for an insertion. 
Value valZero = constantZero(builder, loc, stt.getElementType()); @@ -137,11 +137,11 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc, static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit) { - Value buffer = builder.create(loc, memRefType, sz); + Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz); Type elemType = memRefType.getElementType(); if (enableInit) { Value fillValue = constantZero(builder, loc, elemType); - builder.create(loc, fillValue, buffer); + linalg::FillOp::create(builder, loc, fillValue, buffer); } return buffer; } @@ -179,14 +179,14 @@ static void createAllocFields(OpBuilder &builder, Location loc, valHeuristic = lvlSizesValues[0]; for (Level lvl = 1; lvl < lvlRank; lvl++) valHeuristic = - builder.create(loc, valHeuristic, lvlSizesValues[lvl]); + arith::MulIOp::create(builder, loc, valHeuristic, lvlSizesValues[lvl]); } else if (sizeHint) { if (stt.getAoSCOOStart() == 0) { posHeuristic = constantIndex(builder, loc, 2); - crdHeuristic = builder.create( + crdHeuristic = arith::MulIOp::create(builder, loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) { - posHeuristic = builder.create( + posHeuristic = arith::AddIOp::create(builder, loc, sizeHint, constantIndex(builder, loc, 1)); crdHeuristic = sizeHint; } else { @@ -280,7 +280,7 @@ static Value genCompressed(OpBuilder &builder, Location loc, unsigned crdStride; std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl); const Value one = constantIndex(builder, loc, 1); - const Value pp1 = builder.create(loc, parentPos, one); + const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one); const Value positionsAtLvl = desc.getPosMemRef(lvl); const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos); const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1); @@ -288,29 +288,29 @@ static Value 
genCompressed(OpBuilder &builder, Location loc, const Value crdStrideC = crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value(); const Value msz = - crdStrideC ? builder.create(loc, crdMsz, crdStrideC) + crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC) : crdMsz; - const Value plast = builder.create( + const Value plast = arith::SubIOp::create(builder, loc, genCast(builder, loc, pstop, indexType), one); // Conditional expression. - Value lt = builder.create(loc, arith::CmpIPredicate::ult, + Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, pstart, pstop); types.push_back(boolType); - scf::IfOp ifOp1 = builder.create(loc, types, lt, /*else*/ true); + scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt, /*else*/ true); types.pop_back(); builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); Value crd = genLoad(builder, loc, desc.getMemRefField(crdFidx), - crdStrideC ? builder.create(loc, plast, crdStrideC) + crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC) : plast); - Value eq = builder.create( + Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType), lvlCoords[lvl]); - builder.create(loc, eq); + scf::YieldOp::create(builder, loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (lvl > 0) genStore(builder, loc, msz, positionsAtLvl, parentPos); - builder.create(loc, constantI1(builder, loc, false)); + scf::YieldOp::create(builder, loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); // If present construct. Note that for a non-unique dimension level, we // simply set the condition to false and rely on CSE/DCE to clean up the IR. @@ -322,19 +322,19 @@ static Value genCompressed(OpBuilder &builder, Location loc, types.push_back(indexType); const Value p = stt.isUniqueLvl(lvl) ? 
ifOp1.getResult(0) : constantI1(builder, loc, false); - scf::IfOp ifOp2 = builder.create(loc, types, p, /*else*/ true); + scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p, /*else*/ true); // If present (fields unaffected, update pnext to plast). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); // FIXME: This does not looks like a clean way, but probably the most // efficient way. desc.getFields().push_back(plast); - builder.create(loc, desc.getFields()); + scf::YieldOp::create(builder, loc, desc.getFields()); desc.getFields().pop_back(); // If !present (changes fields, update pnext). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); - Value mszp1 = builder.create(loc, msz, one); + Value mszp1 = arith::AddIOp::create(builder, loc, msz, one); genStore(builder, loc, mszp1, positionsAtLvl, pp1); createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl, /*value=*/lvlCoords[lvl]); @@ -343,7 +343,7 @@ static Value genCompressed(OpBuilder &builder, Location loc, allocSchemeForRank(builder, loc, desc, lvl + 1); desc.getFields().push_back(msz); - builder.create(loc, desc.getFields()); + scf::YieldOp::create(builder, loc, desc.getFields()); desc.getFields().pop_back(); // Update fields and return next pos. 
@@ -381,17 +381,17 @@ static void genEndInsert(OpBuilder &builder, Location loc, Value oldv = loop.getRegionIterArg(0); Value newv = genLoad(builder, loc, posMemRef, i); Value posZero = constantZero(builder, loc, posType); - Value cond = builder.create( + Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, newv, posZero); - scf::IfOp ifOp = builder.create(loc, TypeRange(posType), + scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType), cond, /*else*/ true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); genStore(builder, loc, oldv, posMemRef, i); - builder.create(loc, oldv); + scf::YieldOp::create(builder, loc, oldv); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, newv); + scf::YieldOp::create(builder, loc, newv); builder.setInsertionPointAfter(ifOp); - builder.create(loc, ifOp.getResult(0)); + scf::YieldOp::create(builder, loc, ifOp.getResult(0)); builder.setInsertionPointAfter(loop); } } else { @@ -484,7 +484,7 @@ class SparseInsertGenerator // if (isLooseCompressedLT(lt)) { Value two = constantIndex(builder, loc, 2); - parentPos = builder.create(loc, parentPos, two); + parentPos = arith::MulIOp::create(builder, loc, parentPos, two); } parentPos = genCompressed(builder, loc, desc, coords, value, parentPos, lvl); @@ -501,8 +501,8 @@ class SparseInsertGenerator // positions[lvl] = size * positions[lvl-1] + coords[lvl] // Value size = desc.getLvlSize(builder, loc, lvl); - Value mult = builder.create(loc, size, parentPos); - parentPos = builder.create(loc, mult, coords[lvl]); + Value mult = arith::MulIOp::create(builder, loc, size, parentPos); + parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]); } } // Reached the actual value append/insert. @@ -582,7 +582,7 @@ class SparseCallConverter : public OpConversionPattern { return failure(); // (1) Generates new call with flattened return value. 
- auto newCall = rewriter.create( + auto newCall = func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands())); // (2) Gather sparse tensor returns. SmallVector> packedResultVals; @@ -671,7 +671,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern { auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx); - rewriter.create(loc, nnz, crd, ValueRange{val}, id, + SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id, rewriter.getIndexAttr(0), op.getAlgorithm()); // Since we do in-place sorting, the destinate tensor will have the same set @@ -757,10 +757,10 @@ class SparseTensorAllocConverter // Memcpy on memref fields. for (auto field : desc.getMemRefFields()) { auto memrefTp = cast(field.getType()); - auto size = rewriter.create(loc, field, 0); + auto size = memref::DimOp::create(rewriter, loc, field, 0); auto copied = - rewriter.create(loc, memrefTp, ValueRange{size}); - rewriter.create(loc, field, copied); + memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size}); + memref::CopyOp::create(rewriter, loc, field, copied); fields.push_back(copied); } // Reuses specifier. @@ -863,7 +863,7 @@ class SparseTensorDeallocConverter cast(op.getTensor().getType())); for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. - rewriter.create(loc, input); + memref::DeallocOp::create(rewriter, loc, input); } rewriter.eraseOp(op); return success(); @@ -917,7 +917,7 @@ class SparseExpandConverter : public OpConversionPattern { // Generate a memref for `sz` elements of type `t`. const auto genAlloc = [&](Type t) { const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); - return rewriter.create(loc, memTp, ValueRange{sz}); + return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz}); }; // Allocate temporary buffers for values/filled-switch and added. 
// We do not use stack buffers for this, since the expanded size may @@ -931,10 +931,10 @@ class SparseExpandConverter : public OpConversionPattern { // operation is amortized over the innermost loops for the access // pattern expansion. As noted in the operation doc, we would like // to amortize this setup cost even between kernels. - rewriter.create( + linalg::FillOp::create(rewriter, loc, ValueRange{constantZero(rewriter, loc, eltType)}, ValueRange{values}); - rewriter.create( + linalg::FillOp::create(rewriter, loc, ValueRange{constantZero(rewriter, loc, boolType)}, ValueRange{filled}); // Replace expansion op with these buffers and initial coordinate. @@ -965,7 +965,7 @@ class SparseCompressConverter : public OpConversionPattern { // If the innermost level is ordered, we need to sort the coordinates // in the "added" array prior to applying the compression. if (dstType.isOrderedLvl(dstType.getLvlRank() - 1)) - rewriter.create( + SortOp::create(rewriter, loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1), rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort); // While performing the insertions, we also need to reset the elements @@ -1000,15 +1000,15 @@ class SparseCompressConverter : public OpConversionPattern { SmallVector insertRet = insertGen.genCallOrInline(rewriter, loc); genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd); genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd); - rewriter.create(loc, insertRet); + scf::YieldOp::create(rewriter, loc, insertRet); rewriter.setInsertionPointAfter(loop); // Deallocate the buffers on exit of the full loop nest. 
Operation *parent = getTop(op); rewriter.setInsertionPointAfter(parent); - rewriter.create(loc, values); - rewriter.create(loc, filled); - rewriter.create(loc, added); + memref::DeallocOp::create(rewriter, loc, values); + memref::DeallocOp::create(rewriter, loc, filled); + memref::DeallocOp::create(rewriter, loc, added); // Replace operation with resulting memrefs. rewriter.replaceOpWithMultiple(op, {loop->getResults()}); return success(); @@ -1192,7 +1192,7 @@ class SparseConvertConverter : public OpConversionPattern { // would require a subViewOp to avoid overflow when copying // values. Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0); - auto dstMem = rewriter.create( + auto dstMem = memref::AllocOp::create(rewriter, loc, cast(fTp), sz); if (fTp != srcMem.getType()) { // Converts elements type. @@ -1201,16 +1201,16 @@ class SparseConvertConverter : public OpConversionPattern { constantIndex(rewriter, loc, 1), [srcMem, &dstMem](OpBuilder &builder, Location loc, ValueRange ivs) { - Value v = builder.create(loc, srcMem, ivs); + Value v = memref::LoadOp::create(builder, loc, srcMem, ivs); Value casted = genCast(builder, loc, v, dstMem.getType().getElementType()); - builder.create(loc, casted, dstMem, ivs); + memref::StoreOp::create(builder, loc, casted, dstMem, ivs); }); } else { // TODO: We can even reuse the same memref for the new tensor, // but that requires a `ref-counting` based memory management // for shared memrefs between multiple sparse tensors. 
- rewriter.create(loc, srcMem, dstMem); + memref::CopyOp::create(rewriter, loc, srcMem, dstMem); } fields.push_back(dstMem); } @@ -1242,7 +1242,7 @@ class SparseExtractSliceConverter auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields, op.getSource().getType()); - auto newSpec = rewriter.create( + auto newSpec = StorageSpecifierInitOp::create(rewriter, loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier()); desc.setSpecifier(newSpec); @@ -1326,11 +1326,11 @@ struct SparseAssembleOpConverter : public OpConversionPattern { // Flattens the buffer to batchLvlRank. auto reassoc = getReassociationForFlattening( mem.getType(), stt.getBatchLvlRank()); - mem = rewriter.create( + mem = memref::CastOp::create(rewriter, loc, fType, - rewriter.create(loc, mem, reassoc)); + memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc)); } else { - mem = rewriter.create(loc, fType, mem); + mem = memref::CastOp::create(rewriter, loc, fType, mem); } fields.push_back(mem); } @@ -1362,8 +1362,8 @@ struct SparseAssembleOpConverter : public OpConversionPattern { LevelType lt = stt.getLvlType(lvl); // Simply forwards the position index when this is a dense level. 
if (lt.isa()) { - memSize = rewriter.create(loc, lvlSize, memSize); - posBack = rewriter.create(loc, memSize, c1); + memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize); + posBack = arith::SubIOp::create(rewriter, loc, memSize, c1); continue; } if (lt.isa()) { @@ -1376,12 +1376,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern { if (isWithPosLT(lt)) { assert(isCompressedLT(lt) || isLooseCompressedLT(lt)); if (isLooseCompressedLT(lt)) { - memSize = rewriter.create(loc, memSize, c2); - posBack = rewriter.create(loc, memSize, c1); + memSize = arith::MulIOp::create(rewriter, loc, memSize, c2); + posBack = arith::SubIOp::create(rewriter, loc, memSize, c1); } else { assert(isCompressedLT(lt)); posBack = memSize; - memSize = rewriter.create(loc, memSize, c1); + memSize = arith::AddIOp::create(rewriter, loc, memSize, c1); } desc.setPosMemSize(rewriter, loc, lvl, memSize); // The last value in position array is the memory size for next level. @@ -1391,12 +1391,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern { constantIndex(rewriter, loc, 0)); batched.push_back(posBack); memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched); - posBack = rewriter.create(loc, posBack, c1); + posBack = arith::SubIOp::create(rewriter, loc, posBack, c1); } assert(isWithCrdLT(lt) && lvl <= trailCOOStart); // FIXME: This seems to be unnecessarily complex, can we simplify it? 
if (lvl == trailCOOStart) { - Value cooSz = rewriter.create( + Value cooSz = arith::MulIOp::create(rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank)); desc.setCrdMemSize(rewriter, loc, lvl, cooSz); } else { @@ -1460,18 +1460,18 @@ struct SparseDisassembleOpConverter if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) { auto reassoc = getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank()); - flatOut = rewriter.create(loc, dst, reassoc); + flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc); } Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz); Value srcMem = genSliceToSize(rewriter, loc, src, sz); - rewriter.create(loc, srcMem, dstMem); + memref::CopyOp::create(rewriter, loc, srcMem, dstMem); return true; }); // Converts MemRefs back to Tensors. SmallVector retValues = llvm::to_vector( llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value { - return rewriter.create( + return bufferization::ToTensorOp::create(rewriter, loc, memref::getTensorTypeFromMemRefType(v.getType()), v); })); // Appends the actual memory length used in each buffer returned. 
@@ -1549,13 +1549,13 @@ struct SparseNewConverter : public OpConversionPattern { const Level lvlRank = dstTp.getLvlRank(); if (dstTp.isOrderedLvl(lvlRank - 1)) { Value kFalse = constantI1(rewriter, loc, false); - Value notSorted = rewriter.create( + Value notSorted = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse); scf::IfOp ifOp = - rewriter.create(loc, notSorted, /*else*/ false); + scf::IfOp::create(rewriter, loc, notSorted, /*else*/ false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank); - rewriter.create(loc, nse, xs, ValueRange{ys}, xPerm, + SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm, rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort); rewriter.setInsertionPointAfter(ifOp); @@ -1566,10 +1566,10 @@ struct SparseNewConverter : public OpConversionPattern { const Value posMemref0 = desc.getPosMemRef(0); const Type posTp = dstTp.getPosType(); const Value posNse = genCast(rewriter, loc, nse, posTp); - rewriter.create(loc, posNse, posMemref0, c1); + memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1); // Update storage specifier. - Value coordinatesSize = rewriter.create( + Value coordinatesSize = arith::MulIOp::create(rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank)); desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0, coordinatesSize); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 50ccb43d432b6..54ea1843b457f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -137,7 +137,7 @@ static SmallVector getDimSizes(OpBuilder &builder, Location loc, /// this buffer must be explicitly deallocated by client. 
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { auto memTp = MemRefType::get({ShapedType::kDynamic}, tp); - return rewriter.create(loc, memTp, ValueRange{sz}); + return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz}); } /// Generates a temporary buffer for the level-types of the given encoding. @@ -154,7 +154,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc, Value tensor) { auto buf = genToMemref(builder, loc, tensor); - return builder.create(loc, buf); + return memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, buf); } /// Generates a temporary buffer for the level-types of the given encoding. @@ -168,11 +168,11 @@ static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc, // Passing in value buffer pointers. lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor)); - Value idxPtr = builder.create( + Value idxPtr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, allocaBuffer(builder, loc, lvlBarePtrs)); Value idxCast = - builder.create(loc, builder.getI64Type(), idxPtr); - return builder.create(loc, getOpaquePointerType(builder), + arith::IndexCastOp::create(builder, loc, builder.getI64Type(), idxPtr); + return LLVM::IntToPtrOp::create(builder, loc, getOpaquePointerType(builder), idxCast); } @@ -227,7 +227,7 @@ class NewCallParams final { assert(isInitialized() && "Must initialize before genNewCall"); StringRef name = "newSparseTensor"; params[kParamAction] = constantAction(builder, loc, action); - params[kParamPtr] = ptr ? ptr : builder.create(loc, pTp); + params[kParamPtr] = ptr ? 
ptr : LLVM::ZeroOp::create(builder, loc, pTp); return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On) .getResult(0); } @@ -539,7 +539,7 @@ class SparseTensorToCoordinatesConverter // Cast the MemRef type to the type expected by the users, though these // two types should be compatible at runtime. if (op.getType() != crds.getType()) - crds = rewriter.create(loc, op.getType(), crds); + crds = memref::CastOp::create(rewriter, loc, op.getType(), crds); rewriter.replaceOp(op, crds); return success(); } @@ -560,7 +560,7 @@ class SparseToCoordinatesBufferConverter // Cast the MemRef type to the type expected by the users, though these // two types should be compatible at runtime. if (op.getType() != crds.getType()) - crds = rewriter.create(loc, op.getType(), crds); + crds = memref::CastOp::create(rewriter, loc, op.getType(), crds); rewriter.replaceOp(op, crds); return success(); } @@ -652,7 +652,7 @@ class SparseTensorInsertConverter vref = genAllocaScalar(rewriter, loc, elemTp); } storeAll(rewriter, loc, lvlCoords, adaptor.getIndices()); - rewriter.create(loc, adaptor.getScalar(), vref); + memref::StoreOp::create(rewriter, loc, adaptor.getScalar(), vref); SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On); @@ -690,10 +690,10 @@ class SparseTensorExpandConverter : public OpConversionPattern { // operation is amortized over the innermost loops for the access // pattern expansion. As noted in the operation doc, we would like // to amortize this setup cost even between kernels. - rewriter.create( + linalg::FillOp::create(rewriter, loc, ValueRange{constantZero(rewriter, loc, eltType)}, ValueRange{values}); - rewriter.create( + linalg::FillOp::create(rewriter, loc, ValueRange{constantZero(rewriter, loc, boolType)}, ValueRange{filled}); // Replace expansion op with these buffers and initial coordinate. 
@@ -733,9 +733,9 @@ class SparseTensorCompressConverter : public OpConversionPattern { rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. rewriter.setInsertionPointAfter(parent); - rewriter.create(loc, values); - rewriter.create(loc, filled); - rewriter.create(loc, added); + memref::DeallocOp::create(rewriter, loc, values); + memref::DeallocOp::create(rewriter, loc, filled); + memref::DeallocOp::create(rewriter, loc, added); return success(); } }; @@ -837,21 +837,21 @@ class SparseTensorDisassembleConverter cooStartLvl + 1); auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0); auto two = constantIndex(rewriter, loc, 2); - auto bufLen = rewriter.create(loc, crdLen, two); + auto bufLen = arith::MulIOp::create(rewriter, loc, crdLen, two); Type indexType = rewriter.getIndexType(); auto zero = constantZero(rewriter, loc, indexType); auto one = constantOne(rewriter, loc, indexType); - scf::ForOp forOp = rewriter.create(loc, zero, crdLen, one); + scf::ForOp forOp = scf::ForOp::create(rewriter, loc, zero, crdLen, one); auto idx = forOp.getInductionVar(); rewriter.setInsertionPointToStart(forOp.getBody()); - auto c0 = rewriter.create(loc, crds0, idx); - auto c1 = rewriter.create(loc, crds1, idx); + auto c0 = memref::LoadOp::create(rewriter, loc, crds0, idx); + auto c1 = memref::LoadOp::create(rewriter, loc, crds1, idx); SmallVector args; args.push_back(idx); args.push_back(zero); - rewriter.create(loc, c0, buf, args); + memref::StoreOp::create(rewriter, loc, c0, buf, args); args[1] = one; - rewriter.create(loc, c1, buf, args); + memref::StoreOp::create(rewriter, loc, c1, buf, args); rewriter.setInsertionPointAfter(forOp); auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()]; retVal.push_back(buf); @@ -867,11 +867,11 @@ class SparseTensorDisassembleConverter // Converts MemRefs back to Tensors. 
assert(retVal.size() + retLen.size() == op.getNumResults()); for (unsigned i = 0, sz = retVal.size(); i < sz; i++) { - auto tensor = rewriter.create( + auto tensor = bufferization::ToTensorOp::create(rewriter, loc, memref::getTensorTypeFromMemRefType(retVal[i].getType()), retVal[i]); retVal[i] = - rewriter.create(loc, op.getResultTypes()[i], tensor); + tensor::CastOp::create(rewriter, loc, op.getResultTypes()[i], tensor); } // Appends the actual memory length used in each buffer returned. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index d4a02bf7a70b6..ee963ab5d1ffb 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -127,7 +127,7 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl &sizes, for (const auto &d : enumerate(stp.getShape())) { Value dim; if (d.value() == ShapedType::kDynamic) - dim = builder.create(loc, tensor, d.index()); + dim = tensor::DimOp::create(builder, loc, tensor, d.index()); else dim = constantIndex(builder, loc, d.value()); sizes.push_back(dim); @@ -198,7 +198,7 @@ static void concatSizesFromInputs(OpBuilder &builder, for (const auto &src : srcs.drop_front()) { Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim); // Sum up all the sizes. - sizes[dim] = builder.create(loc, sizes[dim], srcSz); + sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz); } } } @@ -405,7 +405,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { inputOps.push_back(op.getDpsInputOperand(1 - other)->get()); fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other // Fuse producer and consumer into a new generic op. 
- auto fusedOp = rewriter.create( + auto fusedOp = GenericOp::create(rewriter, loc, op.getResult(0).getType(), inputOps, outputOps, rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(), /*doc=*/nullptr, /*library_call=*/nullptr); @@ -430,7 +430,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0)); mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0)); last = rewriter.clone(*acc, mapper)->getResult(0); - rewriter.create(loc, last); + linalg::YieldOp::create(rewriter, loc, last); // Force initial value on merged allocation for dense outputs. // TODO: deal with non alloc tensor here one day if (!getSparseTensorEncoding(op.getResult(0).getType())) { @@ -534,7 +534,7 @@ struct GenSemiRingSelect : public OpRewritePattern { assert(t.getType() == f.getType()); auto selTp = t.getType(); auto c0 = constantZero(rewriter, loc, selTp); - auto binOp = rewriter.create(loc, selTp, t, f); + auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f); // Initializes all the blocks. rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp}, {t.getLoc(), f.getLoc()}); @@ -564,7 +564,7 @@ struct GenSemiRingSelect : public OpRewritePattern { irMap.map(f, b->getArgument(1)); } auto y = rewriter.clone(inst, irMap)->getResult(0); - rewriter.create(loc, y); + sparse_tensor::YieldOp::create(rewriter, loc, y); } // We successfully rewrited a operation. We can not do replacement here @@ -674,28 +674,28 @@ struct GenSemiRingReduction : public OpRewritePattern { // Identity. Location loc = op.getLoc(); Value identity = - rewriter.create(loc, init->get(), ValueRange()); + tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange()); // Unary { // present -> value // absent -> zero. 
// } Type rtp = s0.getType(); rewriter.setInsertionPointToStart(&op.getRegion().front()); - auto semiring = rewriter.create(loc, rtp, s0); + auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0); Block *present = rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc); rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front()); - rewriter.create(loc, present->getArgument(0)); + sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0)); rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {}); rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front()); auto zero = - rewriter.create(loc, rewriter.getZeroAttr(rtp)); - rewriter.create(loc, zero); + arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp)); + sparse_tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(semiring); // CustomReduce { // x = x REDUC y, identity // } - auto custom = rewriter.create( + auto custom = sparse_tensor::ReduceOp::create(rewriter, loc, rtp, semiring.getResult(), s1, identity); Block *region = rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc}); @@ -704,7 +704,7 @@ struct GenSemiRingReduction : public OpRewritePattern { irMap.map(red->getOperand(0), region->getArgument(0)); irMap.map(red->getOperand(1), region->getArgument(1)); auto *cloned = rewriter.clone(*red, irMap); - rewriter.create(loc, cloned->getResult(0)); + sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0)); rewriter.setInsertionPointAfter(custom); rewriter.replaceOp(red, custom.getResult()); return success(); @@ -723,14 +723,14 @@ struct PrintRewriter : public OpRewritePattern { auto tensor = op.getTensor(); auto stt = getSparseTensorType(tensor); // Header with NSE. 
- auto nse = rewriter.create(loc, tensor); - rewriter.create( + auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor); + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = ")); - rewriter.create(loc, nse); + vector::PrintOp::create(rewriter, loc, nse); // Print run-time contents for dim/lvl sizes. - rewriter.create(loc, rewriter.getStringAttr("dim = ")); + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = ")); printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true); - rewriter.create(loc, rewriter.getStringAttr("lvl = ")); + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = ")); printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false); // Use the "codegen" foreach loop construct to iterate over // all typical sparse tensor components for printing. @@ -744,42 +744,42 @@ struct PrintRewriter : public OpRewritePattern { } case SparseTensorFieldKind::PosMemRef: { auto lvl = constantIndex(rewriter, loc, l); - rewriter.create(loc, rewriter.getStringAttr("pos[")); - rewriter.create( + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("pos[")); + vector::PrintOp::create(rewriter, loc, lvl, vector::PrintPunctuation::NoPunctuation); - rewriter.create(loc, rewriter.getStringAttr("] : ")); - auto pos = rewriter.create(loc, tensor, l); + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : ")); + auto pos = ToPositionsOp::create(rewriter, loc, tensor, l); printContents(rewriter, loc, pos); break; } case SparseTensorFieldKind::CrdMemRef: { auto lvl = constantIndex(rewriter, loc, l); - rewriter.create(loc, rewriter.getStringAttr("crd[")); - rewriter.create( + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("crd[")); + vector::PrintOp::create(rewriter, loc, lvl, vector::PrintPunctuation::NoPunctuation); - rewriter.create(loc, rewriter.getStringAttr("] : ")); + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : ")); 
Value crd = nullptr; // For COO AoS storage, we want to print a single, linear view of // the full coordinate storage at this level. For any other storage, // we show the coordinate storage for every indivual level. if (stt.getAoSCOOStart() == l) - crd = rewriter.create(loc, tensor); + crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor); else - crd = rewriter.create(loc, tensor, l); + crd = ToCoordinatesOp::create(rewriter, loc, tensor, l); printContents(rewriter, loc, crd); break; } case SparseTensorFieldKind::ValMemRef: { - rewriter.create(loc, + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("values : ")); - auto val = rewriter.create(loc, tensor); + auto val = ToValuesOp::create(rewriter, loc, tensor); printContents(rewriter, loc, val); break; } } return true; }); - rewriter.create(loc, rewriter.getStringAttr("----\n")); + vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n")); rewriter.eraseOp(op); return success(); } @@ -797,7 +797,7 @@ struct PrintRewriter : public OpRewritePattern { auto shape = cast(vec.getType()).getShape(); SmallVector idxs; printContentsLevel(rewriter, loc, vec, 0, shape, idxs); - rewriter.create(loc, vector::PrintPunctuation::NewLine); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine); } // Helper to the helper. @@ -805,13 +805,13 @@ struct PrintRewriter : public OpRewritePattern { Value vec, unsigned i, ArrayRef shape, SmallVectorImpl &idxs) { // Open bracket. - rewriter.create(loc, vector::PrintPunctuation::Open); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); // Generate for loop. 
auto zero = constantIndex(rewriter, loc, 0); auto index = constantIndex(rewriter, loc, i); - auto size = rewriter.create(loc, vec, index); + auto size = memref::DimOp::create(rewriter, loc, vec, index); auto step = constantIndex(rewriter, loc, 1); - auto forOp = rewriter.create(loc, zero, size, step); + auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step); idxs.push_back(forOp.getInductionVar()); rewriter.setInsertionPointToStart(forOp.getBody()); if (i < shape.size() - 1) { @@ -819,56 +819,56 @@ struct PrintRewriter : public OpRewritePattern { printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs); } else { // Actual contents printing. - auto val = rewriter.create(loc, vec, idxs); + auto val = memref::LoadOp::create(rewriter, loc, vec, idxs); if (llvm::isa(val.getType())) { // Since the vector dialect does not support complex types in any op, // we split those into (real, imag) pairs here. - Value real = rewriter.create(loc, val); - Value imag = rewriter.create(loc, val); - rewriter.create(loc, vector::PrintPunctuation::Open); - rewriter.create(loc, real, + Value real = complex::ReOp::create(rewriter, loc, val); + Value imag = complex::ImOp::create(rewriter, loc, val); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); + vector::PrintOp::create(rewriter, loc, real, vector::PrintPunctuation::Comma); - rewriter.create(loc, imag, + vector::PrintOp::create(rewriter, loc, imag, vector::PrintPunctuation::Close); } else { - rewriter.create( + vector::PrintOp::create(rewriter, loc, val, vector::PrintPunctuation::NoPunctuation); } // Terminating comma (except at end). 
- auto bound = rewriter.create(loc, idxs.back(), step); - Value cond = rewriter.create(loc, arith::CmpIPredicate::ne, + auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step); + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, bound, size); - scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ false); + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond, /*else*/ false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - rewriter.create(loc, vector::PrintPunctuation::Comma); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma); } idxs.pop_back(); rewriter.setInsertionPointAfter(forOp); // Close bracket. - rewriter.create(loc, vector::PrintPunctuation::Close); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close); } // Helper method to print run-time lvl/dim sizes. static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor, unsigned size, bool isDim) { // Open bracket. - rewriter.create(loc, vector::PrintPunctuation::Open); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); // Print unrolled contents (dimop requires constant value). for (unsigned i = 0; i < size; i++) { auto idx = constantIndex(rewriter, loc, i); Value val; if (isDim) - val = rewriter.create(loc, tensor, idx); + val = tensor::DimOp::create(rewriter, loc, tensor, idx); else - val = rewriter.create(loc, tensor, idx); - rewriter.create( + val = LvlOp::create(rewriter, loc, tensor, idx); + vector::PrintOp::create(rewriter, loc, val, i != size - 1 ? vector::PrintPunctuation::Comma : vector::PrintPunctuation::NoPunctuation); } // Close bracket and end of line. 
- rewriter.create(loc, vector::PrintPunctuation::Close); - rewriter.create(loc, vector::PrintPunctuation::NewLine); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine); } }; @@ -896,7 +896,7 @@ struct TensorReshapeRewriter : public OpRewritePattern { for (Dimension d : dstTp->getDimShape()) dstSizes.push_back(constantIndex(rewriter, loc, d)); - Value nnz = rewriter.create(loc, srcTensor); + Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor); // Only need an unordered COO buffer if input and output are not sorted // in the same way. Type bufferTp = getBufferType( @@ -920,7 +920,7 @@ struct TensorReshapeRewriter : public OpRewritePattern { // %t = sparse_tensor.cast %tmp // depending on whether the input/output are sorted in the same way. const auto encSrc = srcTp->getEncoding(); - ForeachOp foreachOp = rewriter.create( + ForeachOp foreachOp = ForeachOp::create(rewriter, loc, srcTensor, buffer, [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, ValueRange reduc) { @@ -935,7 +935,7 @@ struct TensorReshapeRewriter : public OpRewritePattern { Value collapseSize = constantIndex(builder, loc, 1); for (Dimension d = 0; d < srcRank; d++) collapseSize = - builder.create(loc, collapseSize, srcSizes[d]); + arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]); SmallVector collapsedSizes = {collapseSize}; ReassociationIndices collapseIdx; @@ -955,15 +955,15 @@ struct TensorReshapeRewriter : public OpRewritePattern { dstSizes, dstDcvs); auto t = - builder.create(loc, v, reduc.front(), dstDcvs); - builder.create(loc, t); + tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs); + sparse_tensor::YieldOp::create(builder, loc, t); }); - Value t = rewriter.create(loc, foreachOp.getResult(0), true); + Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true); if (bufferTp != *dstTp) { auto dstRTT = dstTp->getRankedTensorType(); 
- Value converted = rewriter.create(loc, dstRTT, t).getResult(); - rewriter.create(loc, t); + Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult(); + DeallocTensorOp::create(rewriter, loc, t); t = converted; } rewriter.replaceOp(op, t); @@ -1004,7 +1004,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern { dstDynSizes.push_back(dstSizes[idx]); } } - Value nnz = rewriter.create(loc, srcTensor); + Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor); // Only need a unordered COO buffer if input and output are not sorted // in the same way. Type bufferTp = getBufferType( @@ -1025,7 +1025,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern { // %t = sparse_tensor.cast %tmp // depending on whether the input/output are sorted in the same way. const auto encSrc = srcTp.getEncoding(); - ForeachOp foreachOp = rewriter.create( + ForeachOp foreachOp = ForeachOp::create(rewriter, loc, srcTensor, buffer, [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, ValueRange reduc) { @@ -1040,15 +1040,15 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern { reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes, srcDcvs, dstSizes, dstDcvs); auto t = - builder.create(loc, v, reduc.front(), dstDcvs); - builder.create(loc, t); + tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs); + sparse_tensor::YieldOp::create(builder, loc, t); }); - Value t = rewriter.create(loc, foreachOp.getResult(0), true); + Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true); if (bufferTp != dstTp) { auto dstRTT = dstTp.getRankedTensorType(); - Value converted = rewriter.create(loc, dstRTT, t).getResult(); - rewriter.create(loc, t); + Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult(); + DeallocTensorOp::create(rewriter, loc, t); t = converted; } rewriter.replaceOp(op, t); @@ -1079,7 +1079,7 @@ struct ReshapeRewriter : public OpRewritePattern { auto rtp = 
getRankedTensorType(op.getSrc()); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); - auto convert = rewriter.create(loc, denseTp, op.getSrc()); + auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc()); rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); }); return success(); } @@ -1089,14 +1089,14 @@ struct ReshapeRewriter : public OpRewritePattern { RankedTensorType::get(rtp.getShape(), rtp.getElementType()); ReshapeOp reshape; if constexpr (std::is_same::value) { - reshape = rewriter.create( + reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(), op.getReassociation(), op.getOutputShape(), op.getStaticOutputShape()); } else { - reshape = rewriter.create(loc, denseTp, op.getSrc(), + reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(), op.getReassociation()); } - Value convert = rewriter.create(loc, rtp, reshape); + Value convert = ConvertOp::create(rewriter, loc, rtp, reshape); rewriter.replaceOp(op, convert); return success(); } @@ -1112,20 +1112,20 @@ struct TensorLike { SmallVector dynSzs; getDynamicSizes(rtt, sizes, dynSzs); - val = builder.create(loc, rtt, dynSzs); + val = AllocTensorOp::create(builder, loc, rtt, dynSzs); if (!isSparse()) { Value c0 = constantZero(builder, loc, rtt.getElementType()); - val = builder.create(loc, c0, val).getResult(0); + val = linalg::FillOp::create(builder, loc, c0, val).getResult(0); } } void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) { - val = builder.create(loc, v, val, crds); + val = tensor::InsertOp::create(builder, loc, v, val, crds); } Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const { if (isSparse()) - return builder.create(loc, val, true); + return LoadOp::create(builder, loc, val, true); return val; } @@ -1160,18 +1160,18 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern { Location loc = op.getLoc(); SmallVector maxLvlCrds; for (Level l = 0; l < stt->getLvlRank(); l++) { - 
Value lvlSz = rewriter.create(loc, op.getSource(), l); - Value maxLvlCrd = rewriter.create( + Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l); + Value maxLvlCrd = arith::SubIOp::create(rewriter, loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType())); maxLvlCrds.push_back(maxLvlCrd); } AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim); - Value maxDimCrd = rewriter.create( + Value maxDimCrd = affine::AffineApplyOp::create(rewriter, op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp), maxLvlCrds); - Value dimSz = rewriter.create( + Value dimSz = arith::AddIOp::create(rewriter, loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType())); rewriter.replaceOp(op, dimSz); return success(); @@ -1212,26 +1212,26 @@ struct ConcatenateRewriter : public OpRewritePattern { for (Value input : op.getInputs()) { // Builds a for op for each input tensor to append new values into the // output tensor. - foreachOp = rewriter.create( + foreachOp = ForeachOp::create(rewriter, loc, input, iterArg, [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, ValueRange reduc) { SmallVector offDimCrd(dcvs); offDimCrd[conDim] = - builder.create(loc, offDimCrd[conDim], offset); + arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset); // Enters foreach, updates the SSA chain. dstBuf.val = reduc.front(); if (!dstTp.isAllDense()) { Value cond = genIsNonzero(builder, loc, v); - auto ifOp = builder.create(loc, reduc.getTypes(), cond, + auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond, /*else*/ true); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, dstBuf.val); + scf::YieldOp::create(builder, loc, dstBuf.val); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); dstBuf.insert(builder, loc, v, offDimCrd); - builder.create(loc, dstBuf.val); + scf::YieldOp::create(builder, loc, dstBuf.val); // Exits the ifOp, update the sparse tensor SSA value. 
builder.setInsertionPointAfter(ifOp); @@ -1239,14 +1239,14 @@ struct ConcatenateRewriter : public OpRewritePattern { } else { dstBuf.insert(builder, loc, v, offDimCrd); } - builder.create(loc, dstBuf.val); + sparse_tensor::YieldOp::create(builder, loc, dstBuf.val); }); // Accumulates the offset. Note that only static-shaped inputs are allowed // by concatenate op verifier, which saves us from computing the offset // dynamically. const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim); assert(ShapedType::isStatic(sz)); - offset = rewriter.create(loc, offset, + offset = arith::AddIOp::create(rewriter, loc, offset, constantIndex(rewriter, loc, sz)); iterArg = foreachOp.getResult(0); dstBuf.val = iterArg; @@ -1299,7 +1299,7 @@ struct DirectConvertRewriter : public OpRewritePattern { ValueRange vs; TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes); - auto foreachOp = rewriter.create( + auto foreachOp = ForeachOp::create(rewriter, loc, src, dstBuf.val, foreachOrder, [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, ValueRange reduc) { @@ -1307,14 +1307,14 @@ struct DirectConvertRewriter : public OpRewritePattern { dstBuf.val = reduc.front(); if (!skipZeroCheck) { Value cond = genIsNonzero(builder, loc, v); - auto ifOp = builder.create(loc, reduc.getTypes(), cond, + auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond, /*else*/ true); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, dstBuf.val); + scf::YieldOp::create(builder, loc, dstBuf.val); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); dstBuf.insert(builder, loc, v, dcvs); - builder.create(loc, dstBuf.val); + scf::YieldOp::create(builder, loc, dstBuf.val); // Exits the ifOp, update the sparse tensor SSA value. 
builder.setInsertionPointAfter(ifOp); @@ -1322,7 +1322,7 @@ struct DirectConvertRewriter : public OpRewritePattern { } else { dstBuf.insert(builder, loc, v, dcvs); } - builder.create(loc, dstBuf.val); + sparse_tensor::YieldOp::create(builder, loc, dstBuf.val); }); rewriter.setInsertionPointAfter(foreachOp); @@ -1349,7 +1349,7 @@ struct CrdTranslateRewriter : public OpRewritePattern { // TODO: we should probably expand the affine map to IR using our own // rules, since affine.apply assume signed value, while the cooridinates // we provided must always be signless. - Value trans = rewriter.create( + Value trans = affine::AffineApplyOp::create(rewriter, op.getLoc(), AffineMap::get(map.getNumDims(), 0, result), op.getInCrds()); outCrds.push_back(trans); @@ -1412,8 +1412,8 @@ struct ForeachRewriter : public OpRewritePattern { SmallVector pos = loopEmitter.getValPosits(0); // Loads the value from sparse tensor using position-index; // loads the value from dense tensor using coords. - Value val = enc ? rewriter.create(loc, vals, pos) - : rewriter.create(loc, vals, lcvs); + Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos) + : memref::LoadOp::create(rewriter, loc, vals, lcvs); // 2. Inline the block in the foreach operator. Block *srcBlock = op.getBody(); @@ -1472,22 +1472,22 @@ struct NewRewriter : public OpRewritePattern { // with enveloping reinterpreted_map ops for non-permutations. 
RankedTensorType dstTp = stt.getRankedTensorType(); RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true); - Value cooTensor = rewriter.create(loc, cooTp, op.getSource()); + Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource()); Value convert = cooTensor; auto enc = stt.getEncoding(); if (!stt.isPermutation()) { // demap coo, demap dstTp auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl(); - convert = rewriter.create(loc, coo, convert); + convert = ReinterpretMapOp::create(rewriter, loc, coo, convert); dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl()); } - convert = rewriter.create(loc, dstTp, convert); + convert = ConvertOp::create(rewriter, loc, dstTp, convert); if (!stt.isPermutation()) // remap to original enc - convert = rewriter.create(loc, enc, convert); + convert = ReinterpretMapOp::create(rewriter, loc, enc, convert); rewriter.replaceOp(op, convert); // Release the temporary ordered COO tensor. rewriter.setInsertionPointAfterValue(convert); - rewriter.create(loc, cooTensor); + DeallocTensorOp::create(rewriter, loc, cooTensor); return success(); } @@ -1501,7 +1501,7 @@ struct OutRewriter : public OpRewritePattern { Location loc = op.getLoc(); // Calculate NNZ. Value src = op.getTensor(); - Value nnz = rewriter.create(loc, src); + Value nnz = NumberOfEntriesOp::create(rewriter, loc, src); // Allocate a temporary buffer for storing dimension-sizes/coordinates. const auto srcTp = getSparseTensorType(src); @@ -1514,7 +1514,7 @@ struct OutRewriter : public OpRewritePattern { SmallVector dims; sizesForTensor(rewriter, dims, loc, srcTp, src); for (Dimension d = 0; d < dimRank; d++) { - rewriter.create(loc, dims[d], dimSizes, + memref::StoreOp::create(rewriter, loc, dims[d], dimSizes, constantIndex(rewriter, loc, d)); } @@ -1536,20 +1536,20 @@ struct OutRewriter : public OpRewritePattern { ModuleOp module = op->getParentOfType(); // For each element in the source tensor, output the element. 
- rewriter.create( + ForeachOp::create(rewriter, loc, src, ValueRange(), [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, ValueRange reduc) { for (Dimension d = 0; d < dimRank; d++) { - rewriter.create(loc, dcvs[d], dimCoords, + memref::StoreOp::create(rewriter, loc, dcvs[d], dimCoords, constantIndex(builder, loc, d)); } - rewriter.create(loc, v, value); + memref::StoreOp::create(rewriter, loc, v, value); SmallVector operands{writer, rankValue, dimCoords, value}; FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands, EmitCInterface::On); - builder.create(loc, TypeRange(), fn, operands); - builder.create(loc); + func::CallOp::create(builder, loc, TypeRange(), fn, operands); + sparse_tensor::YieldOp::create(builder, loc); }); // Release the writer. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 359590f2434dc..5e1f872d5b185 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -79,7 +79,7 @@ static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl, matchPattern(step, m_Constant(&stepInt))) { if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) { Value trueVal = constantI1(rewriter, loc, true); - return rewriter.create(loc, mtp, trueVal); + return vector::BroadcastOp::create(rewriter, loc, mtp, trueVal); } } // Otherwise, generate a vector mask that avoids overrunning the upperbound @@ -93,7 +93,7 @@ static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl, rewriter.getContext()); Value end = rewriter.createOrFold( loc, min, ValueRange{hi, iv, step}); - return rewriter.create(loc, mtp, end); + return vector::CreateMaskOp::create(rewriter, loc, mtp, end); } /// Generates a vectorized invariant. 
Here we rely on subsequent loop @@ -101,7 +101,7 @@ static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl, static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl, Value val) { VectorType vtp = vectorType(vl, val.getType()); - return rewriter.create(val.getLoc(), vtp, val); + return vector::BroadcastOp::create(rewriter, val.getLoc(), vtp, val); } /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi], @@ -116,10 +116,10 @@ static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl, SmallVector scalarArgs(idxs); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); - return rewriter.create(loc, vtp, mem, scalarArgs, + return vector::GatherOp::create(rewriter, loc, vtp, mem, scalarArgs, indexVec, vmask, pass); } - return rewriter.create(loc, vtp, mem, idxs, vmask, + return vector::MaskedLoadOp::create(rewriter, loc, vtp, mem, idxs, vmask, pass); } @@ -133,11 +133,11 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, SmallVector scalarArgs(idxs); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); - rewriter.create(loc, mem, scalarArgs, indexVec, vmask, + vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask, rhs); return; } - rewriter.create(loc, mem, idxs, vmask, rhs); + vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs); } /// Detects a vectorizable reduction operations and returns the @@ -198,18 +198,18 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc, case vector::CombiningKind::ADD: case vector::CombiningKind::XOR: // Initialize reduction vector to: | 0 | .. | 0 | r | - return rewriter.create(loc, r, + return vector::InsertOp::create(rewriter, loc, r, constantZero(rewriter, loc, vtp), constantIndex(rewriter, loc, 0)); case vector::CombiningKind::MUL: // Initialize reduction vector to: | 1 | .. 
| 1 | r | - return rewriter.create(loc, r, + return vector::InsertOp::create(rewriter, loc, r, constantOne(rewriter, loc, vtp), constantIndex(rewriter, loc, 0)); case vector::CombiningKind::AND: case vector::CombiningKind::OR: // Initialize reduction vector to: | r | .. | r | r | - return rewriter.create(loc, vtp, r); + return vector::BroadcastOp::create(rewriter, loc, vtp, r); default: break; } @@ -301,10 +301,10 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, Type etp = llvm::cast(vload.getType()).getElementType(); if (!llvm::isa(etp)) { if (etp.getIntOrFloatBitWidth() < 32) - vload = rewriter.create( + vload = arith::ExtUIOp::create(rewriter, loc, vectorType(vl, rewriter.getI32Type()), vload); else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32) - vload = rewriter.create( + vload = arith::ExtUIOp::create(rewriter, loc, vectorType(vl, rewriter.getI64Type()), vload); } idxs.push_back(vload); @@ -329,7 +329,7 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, return false; if (codegen) idxs.push_back( - rewriter.create(forOp.getLoc(), inv, idx)); + arith::AddIOp::create(rewriter, forOp.getLoc(), inv, idx)); continue; // success so far } } @@ -342,7 +342,7 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, #define UNAOP(xxx) \ if (isa(def)) { \ if (codegen) \ - vexp = rewriter.create(loc, vx); \ + vexp = xxx::create(rewriter, loc, vx); \ return true; \ } @@ -350,7 +350,7 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, if (auto x = dyn_cast(def)) { \ if (codegen) { \ VectorType vtp = vectorType(vl, x.getType()); \ - vexp = rewriter.create(loc, vtp, vx); \ + vexp = xxx::create(rewriter, loc, vtp, vx); \ } \ return true; \ } @@ -358,7 +358,7 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, #define BINOP(xxx) \ if (isa(def)) { \ if (codegen) \ - vexp = rewriter.create(loc, vx, vy); \ + vexp = 
xxx::create(rewriter, loc, vx, vy); \ return true; \ } @@ -381,9 +381,9 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, // such as a[i] = i, which must convert to [i, i+1, ...]. if (codegen) { VectorType vtp = vectorType(vl, arg.getType()); - Value veci = rewriter.create(loc, vtp, arg); - Value incr = rewriter.create(loc, vtp); - vexp = rewriter.create(loc, veci, incr); + Value veci = vector::BroadcastOp::create(rewriter, loc, vtp, arg); + Value incr = vector::StepOp::create(rewriter, loc, vtp); + vexp = arith::AddIOp::create(rewriter, loc, veci, incr); } return true; } @@ -526,15 +526,15 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, Value step = constantIndex(rewriter, loc, vl.vectorLength); if (vl.enableVLAVectorization) { Value vscale = - rewriter.create(loc, rewriter.getIndexType()); - step = rewriter.create(loc, vscale, step); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + step = arith::MulIOp::create(rewriter, loc, vscale, step); } if (!yield.getResults().empty()) { Value init = forOp.getInitArgs()[0]; VectorType vtp = vectorType(vl, init.getType()); Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0), forOp.getRegionIterArg(0), init, vtp); - forOpNew = rewriter.create( + forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit); forOpNew->setAttr( LoopEmitter::getLoopEmitterLoopAttrName(), @@ -563,10 +563,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, if (codegen) { Value partial = forOpNew.getResult(0); Value vpass = genVectorInvariantValue(rewriter, vl, iter); - Value vred = rewriter.create(loc, vmask, vrhs, vpass); - rewriter.create(loc, vred); + Value vred = arith::SelectOp::create(rewriter, loc, vmask, vrhs, vpass); + scf::YieldOp::create(rewriter, loc, vred); rewriter.setInsertionPointAfter(forOpNew); - Value vres = rewriter.create(loc, kind, partial); + Value 
vres = vector::ReductionOp::create(rewriter, loc, kind, partial); // Now do some relinking (last one is not completely type safe // but all bad ones are removed right away). This also folds away // nop broadcast operations. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 2d604ed7a8ffc..0cabe5498d7f4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -322,7 +322,7 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) { if (!isInit) { Value zero = constantZero(builder, loc, getElementTypeOrSelf(tensor.getType())); - builder.create(loc, ValueRange{zero}, + linalg::FillOp::create(builder, loc, ValueRange{zero}, ValueRange{init}); } return init; @@ -385,7 +385,7 @@ static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, } // Load from expanded access pattern. Value index = genIndex(env, t); - return builder.create(loc, env.getExpandValues(), index); + return memref::LoadOp::create(builder, loc, env.getExpandValues(), index); } /// Generates insertion code to implement dynamic tensor load for reduction. 
@@ -401,22 +401,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, Value values = env.getExpandValues(); Value filled = env.getExpandFilled(); Value index = genIndex(env, t); - Value isFilled = builder.create(loc, filled, index); - Value valAtIndex = builder.create(loc, values, index); - return builder.create(loc, isFilled, valAtIndex, identity); + Value isFilled = memref::LoadOp::create(builder, loc, filled, index); + Value valAtIndex = memref::LoadOp::create(builder, loc, values, index); + return arith::SelectOp::create(builder, loc, isFilled, valAtIndex, identity); } static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, Value sparseOut, ValueRange ivs, Value v) { scf::IfOp condInsert = - builder.create(loc, sparseOut.getType(), cond, true); + scf::IfOp::create(builder, loc, sparseOut.getType(), cond, true); // True branch. builder.setInsertionPointToStart(condInsert.thenBlock()); - Value res = builder.create(loc, v, sparseOut, ivs); - builder.create(loc, res); + Value res = tensor::InsertOp::create(builder, loc, v, sparseOut, ivs); + scf::YieldOp::create(builder, loc, res); // False branch. builder.setInsertionPointToStart(condInsert.elseBlock()); - builder.create(loc, sparseOut); + scf::YieldOp::create(builder, loc, sparseOut); // Value assignment. builder.setInsertionPointAfter(condInsert); return condInsert.getResult(0); @@ -453,7 +453,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value nz = genIsNonzero(builder, loc, rhs); sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs); } else { - sparseOut = builder.create(loc, rhs, chain, ivs); + sparseOut = tensor::InsertOp::create(builder, loc, rhs, chain, ivs); } // Generates regular insertion chain. 
env.updateInsertionChain(sparseOut); @@ -474,25 +474,25 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value fval = constantI1(builder, loc, false); Value tval = constantI1(builder, loc, true); // If statement. - Value isFilled = builder.create(loc, filled, index); - Value cond = builder.create(loc, arith::CmpIPredicate::eq, + Value isFilled = memref::LoadOp::create(builder, loc, filled, index); + Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, isFilled, fval); - scf::IfOp ifOp = builder.create(loc, builder.getIndexType(), cond, + scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIndexType(), cond, /*else=*/true); // True branch. builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - builder.create(loc, tval, filled, index); - builder.create(loc, index, added, count); + memref::StoreOp::create(builder, loc, tval, filled, index); + memref::StoreOp::create(builder, loc, index, added, count); Value one = constantIndex(builder, loc, 1); - Value add = builder.create(loc, count, one); - builder.create(loc, add); + Value add = arith::AddIOp::create(builder, loc, count, one); + scf::YieldOp::create(builder, loc, add); // False branch. builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, count); + scf::YieldOp::create(builder, loc, count); builder.setInsertionPointAfter(ifOp); // Value assignment. env.updateExpandCount(ifOp.getResult(0)); - builder.create(loc, rhs, values, index); + memref::StoreOp::create(builder, loc, rhs, values, index); } /// Generates a load on a dense or sparse tensor. 
@@ -522,9 +522,9 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { if (llvm::isa(ptr.getType())) { assert(env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator); - return builder.create(loc, ptr, llvm::getSingleElement(args)); + return ExtractValOp::create(builder, loc, ptr, llvm::getSingleElement(args)); } - return builder.create(loc, ptr, args); + return memref::LoadOp::create(builder, loc, ptr, args); } /// Generates a store on a dense or sparse tensor. @@ -551,7 +551,7 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, if (!env.isSparseOutput(t)) { SmallVector args; Value ptr = genSubscript(env, builder, t, args); - builder.create(loc, rhs, ptr, args); + memref::StoreOp::create(builder, loc, rhs, ptr, args); return; } // Store during sparse insertion. @@ -562,7 +562,7 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, // Select operation insertion. Value chain = env.getInsertionChain(); scf::IfOp ifOp = - builder.create(loc, chain.getType(), rhs, /*else=*/true); + scf::IfOp::create(builder, loc, chain.getType(), rhs, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); // Existing value was preserved to be used here. assert(env.exp(exp).val); @@ -571,10 +571,10 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, env.merger().clearExprValue(exp); // Yield modified insertion chain along true branch. Value mchain = env.getInsertionChain(); - builder.create(op.getLoc(), mchain); + scf::YieldOp::create(builder, op.getLoc(), mchain); // Yield original insertion chain along false branch. builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, chain); + scf::YieldOp::create(builder, loc, chain); // Done with if statement. 
env.updateInsertionChain(ifOp->getResult(0)); builder.setInsertionPointAfter(ifOp); @@ -603,7 +603,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, assert(!getSparseTensorType(t->get()).hasEncoding()); // dense! SmallVector args; Value ptr = genSubscript(env, rewriter, t, args); - return rewriter.create(op.getLoc(), ptr, args); + return memref::LoadOp::create(rewriter, op.getLoc(), ptr, args); } } else if (Operation *def = e.getDefiningOp()) { // Handle index computation. @@ -774,7 +774,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, Type t2 = MemRefType::get(dynShape, builder.getI1Type()); Type t3 = MemRefType::get(dynShape, builder.getIndexType()); Type t4 = builder.getIndexType(); - auto r = builder.create(loc, TypeRange({t1, t2, t3, t4}), tensor); + auto r = ExpandOp::create(builder, loc, TypeRange({t1, t2, t3, t4}), tensor); assert(r.getNumResults() == 4); env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2), r.getResult(3)); @@ -787,7 +787,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, Value added = env.getExpandAdded(); Value count = env.getExpandCount(); Value chain = env.getInsertionChain(); - Value compress = builder.create(loc, values, filled, added, + Value compress = CompressOp::create(builder, loc, values, filled, added, count, chain, indices); env.updateInsertionChain(compress); env.endExpand(); @@ -895,7 +895,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, env.updateInsertionChain(ifOp->getResult(y++)); } assert(y == yields.size()); - builder.create(loc, yields); + scf::YieldOp::create(builder, loc, yields); builder.setInsertionPointAfter(ifOp); } } @@ -948,13 +948,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, assert(lvl.has_value()); const Value crd = env.emitter().getCoord(tid, *lvl); const Value lvar = env.getLoopVar(curr); - clause = builder.create(loc, arith::CmpIPredicate::eq, + clause 
= arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, crd, lvar); } else { assert(lt.hasDenseSemantic() || isUndefLT(lt)); clause = constantI1(builder, loc, true); } - cond = cond ? builder.create(loc, cond, clause) : clause; + cond = cond ? arith::AndIOp::create(builder, loc, cond, clause) : clause; }); if (env.isReduc()) { types.push_back(env.getReduc().getType()); @@ -965,7 +965,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, types.push_back(builder.getIndexType()); if (env.getInsertionChain()) types.push_back(env.getInsertionChain().getType()); - scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); + scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); return ifOp; } @@ -993,7 +993,7 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, env.updateInsertionChain(insInput); } if (!operands.empty()) - builder.create(env.op().getLoc(), operands); + scf::YieldOp::create(builder, env.op().getLoc(), operands); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); } @@ -1307,7 +1307,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, genStmt(env, rewriter, ej, curr + 1); // TODO: handle yield values. 
assert(reduc.empty() && "Not Implemented"); - rewriter.create(env.op().getLoc()); + sparse_tensor::YieldOp::create(rewriter, env.op().getLoc()); return std::nullopt; }); // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp index c370d104e0985..79c8d6c171249 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp @@ -42,7 +42,7 @@ struct GuardSparseAlloc // operation that leaves the underlying storage in a proper state // before the tensor escapes across the method boundary. rewriter.setInsertionPointAfter(op); - auto load = rewriter.create(op.getLoc(), op.getResult(), true); + auto load = LoadOp::create(rewriter, op.getLoc(), op.getResult(), true); rewriter.replaceAllUsesExcept(op, load, load); return success(); } @@ -61,7 +61,7 @@ struct StageUnorderedSparseOps : public OpRewritePattern { // Deallocate tmpBuf. // TODO: Delegate to buffer deallocation pass in the future. if (succeeded(stageResult) && tmpBuf) - rewriter.create(loc, tmpBuf); + bufferization::DeallocTensorOp::create(rewriter, loc, tmpBuf); return stageResult; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index 1bd9563b3db07..7a735474b6d57 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -156,7 +156,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, // int <=> index if (isa(srcTp) || isa(dstTp)) - return builder.create(loc, dstTp, value); + return arith::IndexCastOp::create(builder, loc, dstTp, value); const auto srcIntTp = dyn_cast_or_null(srcTp); const bool isUnsignedCast = srcIntTp ? 
srcIntTp.isUnsigned() : false; @@ -169,19 +169,19 @@ Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc, // Scalars can only be converted to 0-ranked tensors. assert(rtp.getRank() == 0); elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType()); - return builder.create(loc, rtp, elem); + return tensor::FromElementsOp::create(builder, loc, rtp, elem); } return sparse_tensor::genCast(builder, loc, elem, dstTp); } Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s) { - Value load = builder.create(loc, mem, s); + Value load = memref::LoadOp::create(builder, loc, mem, s); if (!isa(load.getType())) { if (load.getType().getIntOrFloatBitWidth() < 64) - load = builder.create(loc, builder.getI64Type(), load); + load = arith::ExtUIOp::create(builder, loc, builder.getI64Type(), load); load = - builder.create(loc, builder.getIndexType(), load); + arith::IndexCastOp::create(builder, loc, builder.getIndexType(), load); } return load; } @@ -206,13 +206,13 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, Type tp = v.getType(); Value zero = constantZero(builder, loc, tp); if (isa(tp)) - return builder.create(loc, arith::CmpFPredicate::UNE, v, + return arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::UNE, v, zero); if (tp.isIntOrIndex()) - return builder.create(loc, arith::CmpIPredicate::ne, v, + return arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, v, zero); if (isa(tp)) - return builder.create(loc, v, zero); + return complex::NotEqualOp::create(builder, loc, v, zero); llvm_unreachable("Non-numeric type"); } @@ -226,7 +226,7 @@ void mlir::sparse_tensor::genReshapeDstShape( for (const auto &map : llvm::enumerate(reassociation)) { auto dstDim = constantIndex(builder, loc, 1); for (unsigned i = start; i < start + map.value().size(); i++) { - dstDim = builder.create(loc, dstDim, srcShape[i]); + dstDim = arith::MulIOp::create(builder, loc, dstDim, 
srcShape[i]); } dstShape.push_back(dstDim); start = start + map.value().size(); @@ -260,7 +260,7 @@ void mlir::sparse_tensor::genReshapeDstShape( // Compute the dynamic dimension size. Value productVal = constantIndex(builder, loc, product); Value dynamicSize = - builder.create(loc, srcDim, productVal); + arith::DivUIOp::create(builder, loc, srcDim, productVal); dstShape.push_back(dynamicSize); } else { // The expanded dimension is statically known. @@ -289,7 +289,7 @@ void mlir::sparse_tensor::reshapeCvs( // Prepare strides information in dimension slice. Value linear = constantIndex(builder, loc, 1); for (unsigned j = start, end = start + map.value().size(); j < end; j++) { - linear = builder.create(loc, linear, sizes[j]); + linear = arith::MulIOp::create(builder, loc, linear, sizes[j]); } // Start expansion. Value val; @@ -297,16 +297,16 @@ void mlir::sparse_tensor::reshapeCvs( val = srcCvs[i]; // Iterate over dimension slice. for (unsigned j = start, end = start + map.value().size(); j < end; j++) { - linear = builder.create(loc, linear, sizes[j]); + linear = arith::DivUIOp::create(builder, loc, linear, sizes[j]); if (isCollapse) { - const Value mul = builder.create(loc, srcCvs[j], linear); - val = val ? builder.create(loc, val, mul) : mul; + const Value mul = arith::MulIOp::create(builder, loc, srcCvs[j], linear); + val = val ? arith::AddIOp::create(builder, loc, val, mul) : mul; } else { const Value old = val; - val = builder.create(loc, val, linear); + val = arith::DivUIOp::create(builder, loc, val, linear); assert(dstCvs.size() == j); dstCvs.push_back(val); - val = builder.create(loc, old, linear); + val = arith::RemUIOp::create(builder, loc, old, linear); } } // Finalize collapse. 
@@ -329,7 +329,7 @@ FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name, auto func = module.lookupSymbol(result.getAttr()); if (!func) { OpBuilder moduleBuilder(module.getBodyRegion()); - func = moduleBuilder.create( + func = func::FuncOp::create(moduleBuilder, module.getLoc(), name, FunctionType::get(context, operands.getTypes(), resultType)); func.setPrivate(); @@ -346,7 +346,7 @@ func::CallOp mlir::sparse_tensor::createFuncCall( auto module = builder.getBlock()->getParentOp()->getParentOfType(); FlatSymbolRefAttr fn = getFunc(module, name, resultType, operands, emitCInterface); - return builder.create(loc, resultType, fn, operands); + return func::CallOp::create(builder, loc, resultType, fn, operands); } Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) { @@ -361,7 +361,7 @@ Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp, bool staticShape) { if (staticShape) { auto memTp = MemRefType::get({sz}, tp); - return builder.create(loc, memTp); + return memref::AllocaOp::create(builder, loc, memTp); } return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp); } @@ -369,12 +369,12 @@ Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) { auto memTp = MemRefType::get({ShapedType::kDynamic}, tp); - return builder.create(loc, memTp, ValueRange{sz}); + return memref::AllocaOp::create(builder, loc, memTp, ValueRange{sz}); } Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc, Type tp) { - return builder.create(loc, MemRefType::get({}, tp)); + return memref::AllocaOp::create(builder, loc, MemRefType::get({}, tp)); } Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc, @@ -384,7 +384,7 @@ Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc, Value buffer = genAlloca(builder, loc, sz, values[0].getType()); for 
(unsigned i = 0; i < sz; i++) { Value idx = constantIndex(builder, loc, i); - builder.create(loc, values[i], buffer, idx); + memref::StoreOp::create(builder, loc, values[i], buffer, idx); } return buffer; } @@ -400,15 +400,15 @@ Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc, if (shape[i] == ShapedType::kDynamic) dynamicSizes.push_back(sizes[i]); } - Value mem = builder.create(loc, memTp, dynamicSizes); + Value mem = memref::AllocOp::create(builder, loc, memTp, dynamicSizes); Value zero = constantZero(builder, loc, elemTp); - builder.create(loc, ValueRange{zero}, ValueRange{mem}); + linalg::FillOp::create(builder, loc, ValueRange{zero}, ValueRange{mem}); return mem; } void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) { - builder.create(loc, buffer); + memref::DeallocOp::create(builder, loc, buffer); } void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder, @@ -486,17 +486,17 @@ void sparse_tensor::foreachInSparseConstant( cvs.clear(); for (Dimension d = 0; d < dimRank; d++) { auto crd = elems[i].first[d].getInt(); - cvs.push_back(builder.create(loc, crd)); + cvs.push_back(arith::ConstantIndexOp::create(builder, loc, crd)); } // Remap value. 
Value val; if (isa(attr.getElementType())) { auto valAttr = cast(elems[i].second); - val = builder.create(loc, attr.getElementType(), + val = complex::ConstantOp::create(builder, loc, attr.getElementType(), valAttr); } else { auto valAttr = cast(elems[i].second); - val = builder.create(loc, valAttr); + val = arith::ConstantOp::create(builder, loc, valAttr); } assert(val); callback(cvs, val); @@ -516,10 +516,10 @@ SmallVector sparse_tensor::loadAll(OpBuilder &builder, Location loc, SmallVector vs; vs.reserve(size); for (unsigned i = 0; i < size; i++) { - Value v = builder.create(loc, mem, + Value v = memref::LoadOp::create(builder, loc, mem, constantIndex(builder, loc, i)); if (i == offsetIdx && offsetVal) - v = builder.create(loc, v, offsetVal); + v = arith::AddIOp::create(builder, loc, v, offsetVal); vs.push_back(v); } return vs; @@ -538,9 +538,9 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem, for (const auto &v : llvm::enumerate(vs)) { const Value w = (offsetIdx == v.index() && offsetVal) - ? builder.create(loc, v.value(), offsetVal) + ? 
arith::AddIOp::create(builder, loc, v.value(), offsetVal) : v.value(); - builder.create(loc, w, mem, + memref::StoreOp::create(builder, loc, w, mem, constantIndex(builder, loc, v.index())); } } @@ -550,7 +550,7 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { auto tTp = llvm::cast(tensor.getType()); auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType()); return cast>( - builder.create(loc, mTp, tensor).getResult()); + bufferization::ToBufferOp::create(builder, loc, mTp, tensor).getResult()); } Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, @@ -560,7 +560,7 @@ Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, std::optional offset = enc.getStaticDimSliceOffset(dim); if (offset.has_value()) return constantIndex(builder, loc, *offset); - return builder.create(loc, tensor, APInt(64, dim)); + return ToSliceOffsetOp::create(builder, loc, tensor, APInt(64, dim)); } Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, @@ -570,7 +570,7 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, std::optional stride = enc.getStaticDimSliceStride(dim); if (stride.has_value()) return constantIndex(builder, loc, *stride); - return builder.create(loc, tensor, APInt(64, dim)); + return ToSliceStrideOp::create(builder, loc, tensor, APInt(64, dim)); } Value sparse_tensor::genReader(OpBuilder &builder, Location loc, @@ -612,7 +612,7 @@ Value sparse_tensor::genReader(OpBuilder &builder, Location loc, // subsequent clients need the values (DCE will remove unused). 
for (Dimension d = 0; d < dimRank; d++) { if (stt.isDynamicDim(d)) - dimSizesValues[d] = builder.create( + dimSizesValues[d] = memref::LoadOp::create(builder, loc, dimSizesBuffer, constantIndex(builder, loc, d)); } } @@ -689,7 +689,7 @@ Value sparse_tensor::genMapBuffers( if (cm == 0) { lvlSz = dimSizesValues[d]; if (cf != 0) - lvlSz = builder.create(loc, lvlSz, + lvlSz = arith::DivUIOp::create(builder, loc, lvlSz, constantIndex(builder, loc, cf)); } else { lvlSz = constantIndex(builder, loc, cm); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h index dc017e6baa6dc..400be54aeb86f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h @@ -97,7 +97,7 @@ class FuncCallOrInlineGenerator { // Create the function if not already exist. OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(getParentOpOf(builder)); - func = builder.create( + func = func::FuncOp::create(builder, loc, funcName, FunctionType::get(context, params.getTypes(), retTypes)); func.setPrivate(); @@ -108,10 +108,10 @@ class FuncCallOrInlineGenerator { // Delegates to user to generate the actually implementation. SmallVector result = genImplementation(retTypes, args, builder, loc); - builder.create(loc, result); + func::ReturnOp::create(builder, loc, result); } // Returns the CallOp result. 
- func::CallOp call = builder.create(loc, func, params); + func::CallOp call = func::CallOp::create(builder, loc, func, params); return call.getResults(); } @@ -310,9 +310,9 @@ inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { if (auto ctp = dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto zeroa = builder.getArrayAttr({zeroe, zeroe}); - return builder.create(loc, tp, zeroa); + return complex::ConstantOp::create(builder, loc, tp, zeroa); } - return builder.create(loc, tp, builder.getZeroAttr(tp)); + return arith::ConstantOp::create(builder, loc, tp, builder.getZeroAttr(tp)); } /// Generates a 1-valued constant of the given type. This supports all @@ -322,39 +322,39 @@ inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto onee = getOneAttr(builder, ctp.getElementType()); auto zeroa = builder.getArrayAttr({onee, zeroe}); - return builder.create(loc, tp, zeroa); + return complex::ConstantOp::create(builder, loc, tp, zeroa); } - return builder.create(loc, tp, getOneAttr(builder, tp)); + return arith::ConstantOp::create(builder, loc, tp, getOneAttr(builder, tp)); } /// Generates a constant of `index` type. inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { - return builder.create(loc, i); + return arith::ConstantIndexOp::create(builder, loc, i); } /// Generates a constant of `i64` type. inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) { - return builder.create(loc, i, 64); + return arith::ConstantIntOp::create(builder, loc, i, 64); } /// Generates a constant of `i32` type. inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { - return builder.create(loc, i, 32); + return arith::ConstantIntOp::create(builder, loc, i, 32); } /// Generates a constant of `i16` type. 
inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { - return builder.create(loc, i, 16); + return arith::ConstantIntOp::create(builder, loc, i, 16); } /// Generates a constant of `i8` type. inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { - return builder.create(loc, i, 8); + return arith::ConstantIntOp::create(builder, loc, i, 8); } /// Generates a constant of `i1` type. inline Value constantI1(OpBuilder &builder, Location loc, bool b) { - return builder.create(loc, b, 1); + return arith::ConstantIntOp::create(builder, loc, b, 1); } /// Generates a constant of the given `Action`. @@ -400,12 +400,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) { if (auto complexAttr = dyn_cast(attr)) { Type tp = cast(complexAttr.getType()).getElementType(); - return builder.create( + return complex::ConstantOp::create(builder, loc, complexAttr.getType(), builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()), FloatAttr::get(tp, complexAttr.getImag())})); } - return builder.create(loc, cast(attr)); + return arith::ConstantOp::create(builder, loc, cast(attr)); } // TODO: is this at the right place? 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index 9e41c8ec19bcd..6d795f91b1ab2 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -27,18 +27,18 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// #define CMPI(p, l, r) \ - (builder.create(loc, arith::CmpIPredicate::p, (l), (r)) \ + (arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::p, (l), (r)) \ .getResult()) #define C_IDX(v) (constantIndex(builder, loc, (v))) -#define YIELD(vs) (builder.create(loc, (vs))) -#define ADDI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define ANDI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define SUBI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define MULI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define REMUI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define DIVUI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define SELECT(c, l, r) (builder.create(loc, (c), (l), (r))) +#define YIELD(vs) (scf::YieldOp::create(builder, loc, (vs))) +#define ADDI(lhs, rhs) (arith::AddIOp::create(builder, loc, (lhs), (rhs))) +#define ANDI(lhs, rhs) (arith::AndIOp::create(builder, loc, (lhs), (rhs))) +#define SUBI(lhs, rhs) (arith::SubIOp::create(builder, loc, (lhs), (rhs))) +#define MULI(lhs, rhs) (arith::MulIOp::create(builder, loc, (lhs), (rhs))) +#define REMUI(lhs, rhs) (arith::RemUIOp::create(builder, loc, (lhs), (rhs))) +#define DIVUI(lhs, rhs) (arith::DivUIOp::create(builder, loc, (lhs), (rhs))) +#define SELECT(c, l, r) (arith::SelectOp::create(builder, loc, (c), (l), (r))) //===----------------------------------------------------------------------===// // Debugging utils @@ -47,7 +47,7 @@ using namespace mlir::sparse_tensor; #ifndef NDEBUG LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder, 
Location loc, Value memref) { - memref = builder.create( + memref = memref::CastOp::create(builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref); createFuncCall(builder, loc, "printMemrefInd", TypeRange{}, ValueRange{memref}, EmitCInterface::On); @@ -263,7 +263,7 @@ void LoopEmitter::initializeLoopEmit( denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp); Value denseVal = - builder.create(loc, denseTp, tensor); + bufferization::ToBufferOp::create(builder, loc, denseTp, tensor); // Dense outputs need special handling. if (isOutput && updater) denseVal = updater(builder, loc, denseVal, tensor); @@ -273,7 +273,7 @@ void LoopEmitter::initializeLoopEmit( // Annotated sparse tensors. // We also need the value buffer for all-dense annotated "sparse" // tensors. - valBuffer[t] = builder.create(loc, tensor); + valBuffer[t] = ToValuesOp::create(builder, loc, tensor); } } @@ -481,7 +481,7 @@ std::pair LoopEmitter::emitForLoopOverTensorAtLvl( Value iv; if (isParallel) { scf::ParallelOp parOp = - builder.create(loc, lo, hi, step, reduc); + scf::ParallelOp::create(builder, loc, lo, hi, step, reduc); builder.setInsertionPointToStart(parOp.getBody()); assert(parOp.getNumReductions() == reduc.size()); iv = parOp.getInductionVars()[0]; @@ -497,7 +497,7 @@ std::pair LoopEmitter::emitForLoopOverTensorAtLvl( reduc[i] = parOp.getInitVals()[i]; loop = parOp; } else { - scf::ForOp forOp = builder.create(loc, lo, hi, step, reduc); + scf::ForOp forOp = scf::ForOp::create(builder, loc, lo, hi, step, reduc); builder.setInsertionPointToStart(forOp.getBody()); iv = forOp.getInductionVar(); @@ -605,11 +605,11 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( // Extract and iterate over the iteration space. ExtractIterSpaceOp extractSpaceOp = - lvl == 0 ? builder.create(loc, t) - : builder.create( + lvl == 0 ? 
ExtractIterSpaceOp::create(builder, loc, t) + : ExtractIterSpaceOp::create(builder, loc, t, spIterVals[tid][lvl - 1], lvl); - IterateOp iterOp = builder.create( + IterateOp iterOp = IterateOp::create(builder, loc, extractSpaceOp.getExtractedSpace(), reduc); spIterVals[tid][lvl] = iterOp.getIterator(); @@ -627,12 +627,12 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { Value t = tensors[tid]; ExtractIterSpaceOp extractSpaceOp = - lvl == 0 ? builder.create(loc, t) - : builder.create( + lvl == 0 ? ExtractIterSpaceOp::create(builder, loc, t) + : ExtractIterSpaceOp::create(builder, loc, t, spIterVals[tid][lvl - 1], lvl); spaces.push_back(extractSpaceOp.getExtractedSpace()); } - auto coIterOp = builder.create(loc, spaces, reduc, numCases); + auto coIterOp = CoIterateOp::create(builder, loc, spaces, reduc, numCases); // The CoIterationOp does not have insertion block nor induction variable. // TODO: the `struct LoopInfo` should be simplied after full migration. loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr, @@ -730,7 +730,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, if (emitStrategy == SparseEmitStrategy::kSparseIterator) { auto iterateOp = llvm::cast(loopInfo.loop); assert(reduc.size() == iterateOp.getNumResults()); - rewriter.create(loc, reduc); + sparse_tensor::YieldOp::create(rewriter, loc, reduc); // Exit the loop. rewriter.setInsertionPointAfter(iterateOp); // In-place update reduction variables. @@ -740,7 +740,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, if (auto forOp = llvm::dyn_cast(loopInfo.loop)) { if (!reduc.empty()) { assert(reduc.size() == forOp.getNumResults()); - rewriter.create(loc, reduc); + scf::YieldOp::create(rewriter, loc, reduc); } // Exit the loop. 
rewriter.setInsertionPointAfter(forOp); @@ -779,7 +779,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, #endif // NDEBUG rewriter.setInsertionPointAfter(redExp); - auto redOp = rewriter.create(loc, curVal); + auto redOp = scf::ReduceOp::create(rewriter, loc, curVal); // Attach to the reduction op. Block *redBlock = &redOp.getReductions().front().front(); rewriter.setInsertionPointToEnd(redBlock); @@ -791,7 +791,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, // Erases the out-dated reduction expression. rewriter.eraseOp(redExp); rewriter.setInsertionPointToEnd(redBlock); - rewriter.create(loc, newRed->getResult(0)); + scf::ReduceReturnOp::create(rewriter, loc, newRed->getResult(0)); } rewriter.setInsertionPointAfter(parOp); // In-place update reduction variables. @@ -865,7 +865,7 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc, if (emitStrategy == SparseEmitStrategy::kSparseIterator) { Operation *p = loopInfo.loop; if (isa(p)) - rewriter.create(loc, reduc); + sparse_tensor::YieldOp::create(rewriter, loc, reduc); // Exit the loop. rewriter.setInsertionPointAfter(p); @@ -931,7 +931,7 @@ std::pair sparse_tensor::genCoIteration( // Ensures all operands are valid. assert(!llvm::is_contained(ivs, nullptr)); TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = builder.create(loc, types, ivs); + auto whileOp = scf::WhileOp::create(builder, loc, types, ivs); SmallVector locs(types.size(), loc); Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); @@ -950,7 +950,7 @@ std::pair sparse_tensor::genCoIteration( // The remaining block arguments are user-provided reduction values and an // optional universal index. Make sure their sizes match. assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0)); - builder.create(loc, whileCond, before->getArguments()); + scf::ConditionOp::create(builder, loc, whileCond, before->getArguments()); // Generates loop body. 
builder.setInsertionPointToStart(after); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index cf99117065c5f..3b6976f8b7341 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -79,14 +79,14 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt) { - return builder.create( + return StorageSpecifierInitOp::create(builder, loc, StorageSpecifierType::get(stt.getEncoding())); } Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional lvl) { - return builder.create( + return GetStorageSpecifierOp::create(builder, loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl)); } @@ -96,7 +96,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, std::optional lvl) { // TODO: make `v` have type `TypedValue` instead. 
assert(v.getType().isIndex()); - specifier = builder.create( + specifier = SetStorageSpecifierOp::create(builder, loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v); } @@ -112,8 +112,8 @@ Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView( Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart); Value size = getCrdMemSize(builder, loc, cooStart); - size = builder.create(loc, size, stride); - return builder.create( + size = arith::DivUIOp::create(builder, loc, size, stride); + return memref::SubViewOp::create(builder, loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart), /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)}, /*size=*/ValueRange{size}, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h index 869c7864d7535..45d142a807c36 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h @@ -231,7 +231,7 @@ class MutSparseTensorDescriptor /// Packs the given values as a "tuple" value. inline Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values) { - return builder.create(loc, TypeRange(tp), values) + return UnrealizedConversionCastOp::create(builder, loc, TypeRange(tp), values) .getResult(0); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp index aad5e97ed14ab..212ee99647167 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp @@ -22,23 +22,23 @@ using ValueTuple = std::tuple; // File local helper functions/macros. 
//===----------------------------------------------------------------------===// #define CMPI(p, lhs, rhs) \ - (b.create(l, arith::CmpIPredicate::p, (lhs), (rhs)) \ + (arith::CmpIOp::create(b, l, arith::CmpIPredicate::p, (lhs), (rhs)) \ .getResult()) #define C_FALSE (constantI1(b, l, false)) #define C_TRUE (constantI1(b, l, true)) #define C_IDX(v) (constantIndex(b, l, (v))) -#define YIELD(vs) (b.create(l, (vs))) -#define ADDI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define ORI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define ANDI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define SUBI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define MULI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define MINUI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define REMUI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) -#define DIVUI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define YIELD(vs) (scf::YieldOp::create(b, l, (vs))) +#define ADDI(lhs, rhs) (arith::AddIOp::create(b, l, (lhs), (rhs)).getResult()) +#define ORI(lhs, rhs) (arith::OrIOp::create(b, l, (lhs), (rhs)).getResult()) +#define ANDI(lhs, rhs) (arith::AndIOp::create(b, l, (lhs), (rhs)).getResult()) +#define SUBI(lhs, rhs) (arith::SubIOp::create(b, l, (lhs), (rhs)).getResult()) +#define MULI(lhs, rhs) (arith::MulIOp::create(b, l, (lhs), (rhs)).getResult()) +#define MINUI(lhs, rhs) (arith::MinUIOp::create(b, l, (lhs), (rhs)).getResult()) +#define REMUI(lhs, rhs) (arith::RemUIOp::create(b, l, (lhs), (rhs)).getResult()) +#define DIVUI(lhs, rhs) (arith::DivUIOp::create(b, l, (lhs), (rhs)).getResult()) #define SELECT(c, lhs, rhs) \ - (b.create(l, (c), (lhs), (rhs)).getResult()) + (arith::SelectOp::create(b, l, (c), (lhs), (rhs)).getResult()) //===----------------------------------------------------------------------===// // SparseTensorLevel derived classes. 
@@ -150,19 +150,19 @@ class CompressedLevel : public SparseLevel { return loadRange(); SmallVector types{b.getIndexType(), b.getIndexType()}; - scf::IfOp posRangeIf = b.create(l, types, inPadZone, true); + scf::IfOp posRangeIf = scf::IfOp::create(b, l, types, inPadZone, true); // True branch, returns a "fake" empty range [0, 0) if parent // iterator is in pad zone. b.setInsertionPointToStart(posRangeIf.thenBlock()); SmallVector emptyRange{C_IDX(0), C_IDX(0)}; - b.create(l, emptyRange); + scf::YieldOp::create(b, l, emptyRange); // False branch, returns the actual range. b.setInsertionPointToStart(posRangeIf.elseBlock()); auto [pLo, pHi] = loadRange(); SmallVector loadedRange{pLo, pHi}; - b.create(l, loadedRange); + scf::YieldOp::create(b, l, loadedRange); b.setInsertionPointAfter(posRangeIf); ValueRange posRange = posRangeIf.getResults(); @@ -248,7 +248,7 @@ static scf::ValueVector genWhenInBound( llvm::function_ref builder) { TypeRange ifRetTypes = elseRet.getTypes(); - auto ifOp = b.create(l, ifRetTypes, it.genNotEnd(b, l), true); + auto ifOp = scf::IfOp::create(b, l, ifRetTypes, it.genNotEnd(b, l), true); b.setInsertionPointToStart(ifOp.thenBlock()); Value crd = it.deref(b, l); @@ -732,7 +732,7 @@ class NonEmptySubSectIterator : public SparseIterator { // [itVal0, itVal1, ..., pNx0], // ...] 
Value allocSubSectPosBuf(OpBuilder &b, Location l) { - return b.create( + return memref::AllocaOp::create(b, l, MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()), maxTupleCnt); @@ -740,12 +740,12 @@ class NonEmptySubSectIterator : public SparseIterator { void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId, Value start) const { - b.create(l, start, subSectPosBuf, + memref::StoreOp::create(b, l, start, subSectPosBuf, ValueRange{tupleId, C_IDX(tupleSz)}); } Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const { - return b.create(l, subSectPosBuf, + return memref::LoadOp::create(b, l, subSectPosBuf, ValueRange{tupleId, C_IDX(tupleSz)}); } @@ -753,7 +753,7 @@ class NonEmptySubSectIterator : public SparseIterator { ValueRange itVals) const { assert(itVals.size() == tupleSz); for (unsigned i = 0; i < tupleSz; i++) { - b.create(l, itVals[i], subSectPosBuf, + memref::StoreOp::create(b, l, itVals[i], subSectPosBuf, ValueRange{tupleId, C_IDX(i)}); } } @@ -762,7 +762,7 @@ class NonEmptySubSectIterator : public SparseIterator { Value tupleId) const { SmallVector ret; for (unsigned i = 0; i < tupleSz; i++) { - Value v = b.create(l, subSectPosBuf, + Value v = memref::LoadOp::create(b, l, subSectPosBuf, ValueRange{tupleId, C_IDX(i)}); ret.push_back(v); } @@ -1043,7 +1043,7 @@ ValueRange SparseIterator::forward(OpBuilder &b, Location l) { } ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) { - auto ifOp = b.create(l, getCursor().getTypes(), cond, true); + auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), cond, true); // Generate else branch first, otherwise iterator values will be updated by // `forward()`. 
b.setInsertionPointToStart(ifOp.elseBlock()); @@ -1058,12 +1058,12 @@ ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) { } Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) { - auto whileOp = b.create( + auto whileOp = scf::WhileOp::create(b, l, pos.getType(), pos, /*beforeBuilder=*/ [this, pos](OpBuilder &b, Location l, ValueRange ivs) { Value inBound = CMPI(ult, ivs.front(), posHi); - auto ifInBound = b.create(l, b.getI1Type(), inBound, true); + auto ifInBound = scf::IfOp::create(b, l, b.getI1Type(), inBound, true); { OpBuilder::InsertionGuard guard(b); // If in bound, load the next coordinates and check duplication. @@ -1076,7 +1076,7 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) { b.setInsertionPointToStart(ifInBound.elseBlock()); YIELD(constantI1(b, l, false)); } - b.create(l, ifInBound.getResults()[0], ivs); + scf::ConditionOp::create(b, l, ifInBound.getResults()[0], ivs); }, /*afterBuilder=*/ [](OpBuilder &b, Location l, ValueRange ivs) { @@ -1137,7 +1137,7 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) { SmallVector whileArgs(getCursor().begin(), getCursor().end()); whileArgs.push_back(isFirst); - auto whileOp = b.create( + auto whileOp = scf::WhileOp::create(b, l, ValueRange(whileArgs).getTypes(), whileArgs, /*beforeBuilder=*/ [this](OpBuilder &b, Location l, ValueRange ivs) { @@ -1154,7 +1154,7 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) { ret = ORI(ret, llvm::getSingleElement(isFirst)); return {ret}; }); - b.create(l, cont.front(), ivs); + scf::ConditionOp::create(b, l, cont.front(), ivs); }, /*afterBuilder=*/ [this](OpBuilder &b, Location l, ValueRange ivs) { @@ -1219,7 +1219,7 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree( SmallVector iterArgs; iterArgs.push_back(C_IDX(0)); iterArgs.append(reduc.begin(), reduc.end()); - auto forEachLeaf = b.create( + auto forEachLeaf = scf::ForOp::create(b, l, /*lb=*/C_IDX(0), 
/*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs, [&helper, &builder](OpBuilder &b, Location l, Value tupleId, ValueRange iterArgs) { @@ -1235,12 +1235,12 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree( SmallVector whileArgs(helper.wrap.getCursor()); whileArgs.append(iterArgs.begin(), iterArgs.end()); - auto whileOp = b.create( + auto whileOp = scf::WhileOp::create(b, l, ValueRange(whileArgs).getTypes(), whileArgs, /*beforeBuilder=*/ [&helper](OpBuilder &b, Location l, ValueRange ivs) { helper.wrap.linkNewScope(ivs); - b.create(l, helper.genNotEnd(b, l), ivs); + scf::ConditionOp::create(b, l, helper.genNotEnd(b, l), ivs); }, /*afterBuilder=*/ [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) { @@ -1267,7 +1267,7 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree( ValueRange reduc) { assert(!parent || parent->lvl + 1 == lvl); delegate->genInit(b, l, parent); - auto forOp = b.create( + auto forOp = scf::ForOp::create(b, l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc, [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) { helper.locate(b, l, crd); @@ -1411,7 +1411,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) { // if (offset + size > parents.size) // isNonEmpty = false; Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff()); - auto ifOp = b.create(l, getCursor().getTypes(), fastPathP, true); + auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), fastPathP, true); { OpBuilder::InsertionGuard guard(b); // Take the fast path @@ -1448,7 +1448,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) { Value isMin = CMPI(eq, crd, getMinCrd()); delegate->forwardIf(b, l, isMin); // Update the forwarded iterator values if needed. 
- auto ifIsMin = b.create(l, isMin, false); + auto ifIsMin = scf::IfOp::create(b, l, isMin, false); b.setInsertionPointToStart(&ifIsMin.getThenRegion().front()); storeCursorVals(b, l, tupleId, delegate->serialize()); b.setInsertionPointAfter(ifIsMin); @@ -1458,7 +1458,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) { return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs, [nxMin](OpBuilder &b, Location l, Value crd) -> scf::ValueVector { - Value nx = b.create( + Value nx = arith::MinUIOp::create(b, l, crd, nxMin); return {nx, C_TRUE}; }); @@ -1480,7 +1480,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) { // We should at least forward the offset by one. Value minAbsOff = ADDI(getAbsOff(), c1); - nxAbsOff = b.create(l, minAbsOff, nxAbsOff); + nxAbsOff = arith::MaxUIOp::create(b, l, minAbsOff, nxAbsOff); seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}); // The coordinate should not exceeds the space upper bound. @@ -1581,16 +1581,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, auto stt = getSparseTensorType(t); LevelType lt = stt.getLvlType(lvl); - Value sz = stt.hasEncoding() ? b.create(l, t, lvl).getResult() - : b.create(l, t, lvl).getResult(); + Value sz = stt.hasEncoding() ? 
LvlOp::create(b, l, t, lvl).getResult() + : tensor::DimOp::create(b, l, t, lvl).getResult(); SmallVector buffers; if (lt.isWithPosLT()) { - Value pos = b.create(l, t, lvl); + Value pos = ToPositionsOp::create(b, l, t, lvl); buffers.push_back(pos); } if (lt.isWithCrdLT()) { - Value pos = b.create(l, t, lvl); + Value pos = ToCoordinatesOp::create(b, l, t, lvl); buffers.push_back(pos); } return makeSparseTensorLevel(lt, sz, buffers, tid, lvl); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 0258f797143cb..fb68620c378e5 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -1563,7 +1563,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, Block &clonedBlock = tmpRegion.front(); YieldOp clonedYield = cast(clonedBlock.getTerminator()); // Merge cloned block and return yield value. - Operation *placeholder = rewriter.create(loc, 0); + Operation *placeholder = arith::ConstantIndexOp::create(rewriter, loc, 0); rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals); Value val = clonedYield.getSingleResult(); rewriter.eraseOp(clonedYield); @@ -1603,16 +1603,16 @@ static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0, Attribute attr) { Type tp = v0.getType(); auto zero = - rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); + arith::ConstantOp::create(rewriter, loc, tp, rewriter.getZeroAttr(tp)); Value cmp; if (isa(tp)) { auto pred = llvm::cast(attr); - cmp = rewriter.create(loc, pred, v0, zero); + cmp = arith::CmpFOp::create(rewriter, loc, pred, v0, zero); } else { auto pred = llvm::cast(attr); - cmp = rewriter.create(loc, pred, v0, zero); + cmp = arith::CmpIOp::create(rewriter, loc, pred, v0, zero); } - return rewriter.create(loc, cmp, v0, zero); + return arith::SelectOp::create(rewriter, loc, cmp, v0, zero); } Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, 
@@ -1627,128 +1627,128 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, llvm_unreachable("unexpected non-op"); // Unary operations. case TensorExp::Kind::kAbsF: - return rewriter.create(loc, v0); + return math::AbsFOp::create(rewriter, loc, v0); case TensorExp::Kind::kAbsC: { auto type = cast(v0.getType()); auto eltType = cast(type.getElementType()); - return rewriter.create(loc, eltType, v0); + return complex::AbsOp::create(rewriter, loc, eltType, v0); } case TensorExp::Kind::kAbsI: - return rewriter.create(loc, v0); + return math::AbsIOp::create(rewriter, loc, v0); case TensorExp::Kind::kCeilF: - return rewriter.create(loc, v0); + return math::CeilOp::create(rewriter, loc, v0); case TensorExp::Kind::kFloorF: - return rewriter.create(loc, v0); + return math::FloorOp::create(rewriter, loc, v0); case TensorExp::Kind::kSqrtF: - return rewriter.create(loc, v0); + return math::SqrtOp::create(rewriter, loc, v0); case TensorExp::Kind::kSqrtC: - return rewriter.create(loc, v0); + return complex::SqrtOp::create(rewriter, loc, v0); case TensorExp::Kind::kExpm1F: - return rewriter.create(loc, v0); + return math::ExpM1Op::create(rewriter, loc, v0); case TensorExp::Kind::kExpm1C: - return rewriter.create(loc, v0); + return complex::Expm1Op::create(rewriter, loc, v0); case TensorExp::Kind::kLog1pF: - return rewriter.create(loc, v0); + return math::Log1pOp::create(rewriter, loc, v0); case TensorExp::Kind::kLog1pC: - return rewriter.create(loc, v0); + return complex::Log1pOp::create(rewriter, loc, v0); case TensorExp::Kind::kRelu: return buildRelu(rewriter, loc, v0, expr.attr); case TensorExp::Kind::kSinF: - return rewriter.create(loc, v0); + return math::SinOp::create(rewriter, loc, v0); case TensorExp::Kind::kSinC: - return rewriter.create(loc, v0); + return complex::SinOp::create(rewriter, loc, v0); case TensorExp::Kind::kTanhF: - return rewriter.create(loc, v0); + return math::TanhOp::create(rewriter, loc, v0); case TensorExp::Kind::kTanhC: 
- return rewriter.create(loc, v0); + return complex::TanhOp::create(rewriter, loc, v0); case TensorExp::Kind::kNegF: - return rewriter.create(loc, v0); + return arith::NegFOp::create(rewriter, loc, v0); case TensorExp::Kind::kNegC: - return rewriter.create(loc, v0); + return complex::NegOp::create(rewriter, loc, v0); case TensorExp::Kind::kNegI: // no negi in std - return rewriter.create( + return arith::SubIOp::create(rewriter, loc, - rewriter.create(loc, v0.getType(), + arith::ConstantOp::create(rewriter, loc, v0.getType(), rewriter.getZeroAttr(v0.getType())), v0); case TensorExp::Kind::kTruncF: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::TruncFOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kExtF: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::ExtFOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastFS: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::FPToSIOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastFU: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::FPToUIOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastSF: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::SIToFPOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastUF: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::UIToFPOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastS: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::ExtSIOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastU: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::ExtUIOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCastIdx: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::IndexCastOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kTruncI: - 
return rewriter.create(loc, inferType(e, v0), v0); + return arith::TruncIOp::create(rewriter, loc, inferType(e, v0), v0); case TensorExp::Kind::kCIm: { auto type = cast(v0.getType()); auto eltType = cast(type.getElementType()); - return rewriter.create(loc, eltType, v0); + return complex::ImOp::create(rewriter, loc, eltType, v0); } case TensorExp::Kind::kCRe: { auto type = cast(v0.getType()); auto eltType = cast(type.getElementType()); - return rewriter.create(loc, eltType, v0); + return complex::ReOp::create(rewriter, loc, eltType, v0); } case TensorExp::Kind::kBitCast: - return rewriter.create(loc, inferType(e, v0), v0); + return arith::BitcastOp::create(rewriter, loc, inferType(e, v0), v0); // Binary operations. case TensorExp::Kind::kMulF: - return rewriter.create(loc, v0, v1); + return arith::MulFOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kMulC: - return rewriter.create(loc, v0, v1); + return complex::MulOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kMulI: - return rewriter.create(loc, v0, v1); + return arith::MulIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kDivF: - return rewriter.create(loc, v0, v1); + return arith::DivFOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kDivC: - return rewriter.create(loc, v0, v1); + return complex::DivOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kDivS: - return rewriter.create(loc, v0, v1); + return arith::DivSIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kDivU: - return rewriter.create(loc, v0, v1); + return arith::DivUIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kAddF: - return rewriter.create(loc, v0, v1); + return arith::AddFOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kAddC: - return rewriter.create(loc, v0, v1); + return complex::AddOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kAddI: - return rewriter.create(loc, v0, v1); + return arith::AddIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kSubF: - return 
rewriter.create(loc, v0, v1); + return arith::SubFOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kSubC: - return rewriter.create(loc, v0, v1); + return complex::SubOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kSubI: - return rewriter.create(loc, v0, v1); + return arith::SubIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kAndI: - return rewriter.create(loc, v0, v1); + return arith::AndIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kOrI: - return rewriter.create(loc, v0, v1); + return arith::OrIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kXorI: - return rewriter.create(loc, v0, v1); + return arith::XOrIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kShrS: - return rewriter.create(loc, v0, v1); + return arith::ShRSIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kShrU: - return rewriter.create(loc, v0, v1); + return arith::ShRUIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kShlI: - return rewriter.create(loc, v0, v1); + return arith::ShLIOp::create(rewriter, loc, v0, v1); case TensorExp::Kind::kCmpI: { auto predicate = llvm::cast(expr.attr); - return rewriter.create(loc, predicate, v0, v1); + return arith::CmpIOp::create(rewriter, loc, predicate, v0, v1); } case TensorExp::Kind::kCmpF: { auto predicate = llvm::cast(expr.attr); - return rewriter.create(loc, predicate, v0, v1); + return arith::CmpFOp::create(rewriter, loc, predicate, v0, v1); } case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. 
return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(), diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp index fc93f1c1c9220..5e6e70987f6bd 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp @@ -78,11 +78,11 @@ struct CreatorOpShardingInterface if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) { if (!newSharding) { newSharding = - builder.create(op->getLoc(), resultShardings[0]); + ShardingOp::create(builder, op->getLoc(), resultShardings[0]); device = - builder.create(op->getLoc(), mesh) + mesh::ProcessMultiIndexOp::create(builder, op->getLoc(), mesh) .getResults(); - shapeForDevice = builder.create( + shapeForDevice = mesh::ShardShapeOp::create(builder, op->getLoc(), oldType.getShape(), spmdizedOperands, newSharding->getResult(0), device); } @@ -92,7 +92,7 @@ struct CreatorOpShardingInterface newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); } } - newOp = builder.create(op->getLoc(), shardType, newOperands); + newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands); spmdizationMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index a3e863254405c..b19a413d6d9fd 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -56,8 +57,8 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder, if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) return op; if (complex::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, - llvm::cast(value)); + return complex::ConstantOp::create(builder, loc, type, + llvm::cast(value)); return nullptr; } @@ -110,7 +111,7 @@ FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, // Create empty tensor. Value emptyTensor = - b.create(loc, mixedSizes, tensorType.getElementType()); + tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType()); return emptyTensor; } @@ -681,8 +682,8 @@ FailureOr> ConcatOp::decomposeOperation(OpBuilder &builder) { inputShapes.emplace_back(std::move(inputShape)); } - Value replacement = builder.create( - loc, outputShape, getType().getElementType()); + Value replacement = tensor::EmptyOp::create(builder, loc, outputShape, + getType().getElementType()); int64_t rank = getType().getRank(); OpFoldResult one = builder.getIndexAttr(1); @@ -690,12 +691,12 @@ FailureOr> ConcatOp::decomposeOperation(OpBuilder &builder) { SmallVector offsets(rank, zero); for (auto [index, input] : llvm::enumerate(getInputs())) { offsets[concatDim] = concatOffsets[index]; - auto insertSlice = builder.create( - loc, input, replacement, offsets, inputShapes[index], strides); + auto insertSlice = tensor::InsertSliceOp::create( + builder, loc, input, replacement, offsets, 
inputShapes[index], strides); replacement = insertSlice.getResult(); } if (replacement.getType() != getType()) { - replacement = builder.create(loc, getType(), replacement); + replacement = tensor::CastOp::create(builder, loc, getType(), replacement); } return SmallVector{replacement}; } @@ -726,7 +727,7 @@ ConcatOp::reifyResultShapes(OpBuilder &builder, builder.getIndexAttr(inferredResultType.getDimSize(i))); } else { reifiedReturnShapes[0][i] = - builder.create(init.getLoc(), init, i).getResult(); + tensor::DimOp::create(builder, init.getLoc(), init, i).getResult(); } } @@ -826,8 +827,8 @@ struct InferConcatOperandTypes : public OpRewritePattern { // Use refined operand type and create cast from original operand. auto castOp = - rewriter.create(concatOp->getLoc(), inferredOperandType, - concatOp.getOperand(operandIdx)); + CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType, + concatOp.getOperand(operandIdx)); rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] { concatOp->setOperand(operandIdx, castOp->getResult(0)); }); @@ -867,8 +868,9 @@ struct InferConcatResultType : public OpRewritePattern { return failure(); } - auto newConcatOp = rewriter.create( - concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands()); + auto newConcatOp = + ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim, + concatOp->getOperands()); rewriter.replaceOpWithNewOp(concatOp, concatOp.getResultType(), newConcatOp); @@ -895,7 +897,7 @@ void DimOp::getAsmResultNames(function_ref setNameFn) { void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { auto loc = result.location; - Value indexValue = builder.create(loc, index); + Value indexValue = arith::ConstantIndexOp::create(builder, loc, index); build(builder, result, source, indexValue); } @@ -1039,10 +1041,10 @@ struct DimOfReshapeOp : public OpRewritePattern { rewriter.setInsertionPointAfter(dim); Location loc = dim.getLoc(); Value extract = - 
rewriter.create(loc, reshape.getShape(), dim.getIndex()); + ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex()); if (extract.getType() != dim.getType()) extract = - rewriter.create(loc, dim.getType(), extract); + arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract); rewriter.replaceOp(dim, extract); return success(); } @@ -1153,8 +1155,8 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern { if (foldedTensorType == op.getType()) return failure(); - auto newOp = rewriter.create(op.getLoc(), foldedTensorType, - foldedDynamicSizes); + auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType, + foldedDynamicSizes); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); return success(); } @@ -1329,8 +1331,8 @@ struct ExtractFromCollapseShape : public OpRewritePattern { SmallVector basis = llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); - auto delinearize = rewriter.create( - extractOp.getLoc(), index, basis, /*hasOuterBound=*/true); + auto delinearize = affine::AffineDelinearizeIndexOp::create( + rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true); llvm::append_range(sourceIndices, delinearize.getResults()); } if (collapseOp.getReassociationIndices().empty()) { @@ -1501,8 +1503,8 @@ struct ExtractElementFromIndexCast Type elementTy = getElementTypeOrSelf(indexCast.getIn()); - auto newExtract = rewriter.create( - loc, elementTy, indexCast.getIn(), extract.getIndices()); + auto newExtract = tensor::ExtractOp::create( + rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices()); rewriter.replaceOpWithNewOp(extract, extract.getType(), newExtract); @@ -1739,7 +1741,7 @@ struct StaticTensorGenerate : public OpRewritePattern { auto loc = generateOp.getLoc(); auto newOp = - rewriter.create(loc, foldedTensorType, foldedDynamicSizes); + GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes); rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(), 
newOp.getBody().begin()); rewriter.replaceOpWithNewOp(generateOp, @@ -2164,9 +2166,9 @@ struct FoldCollapseOfCastOp : public OpRewritePattern { collapseShapeOp.getSrcMutable().assign(castOp.getSource()); }); } else { - auto newOp = rewriter.create( - collapseShapeOp.getLoc(), newResultType, castOp.getSource(), - collapseShapeOp.getReassociation()); + auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(), + newResultType, castOp.getSource(), + collapseShapeOp.getReassociation()); rewriter.replaceOpWithNewOp( collapseShapeOp, collapseShapeOp.getResultType(), newOp); } @@ -2243,10 +2245,10 @@ struct ConvertToStaticExpandShape : public OpRewritePattern { newInputShape, expandOp.getSrcType().getElementType()); auto outputType = RankedTensorType::get( newOutputShape, expandOp.getSrcType().getElementType()); - auto inputCast = rewriter.create(expandOp.getLoc(), inputType, - expandOp.getSrc()); - auto newExpand = rewriter.create( - expandOp.getLoc(), outputType, inputCast.getResult(), + auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType, + expandOp.getSrc()); + auto newExpand = ExpandShapeOp::create( + rewriter, expandOp.getLoc(), outputType, inputCast.getResult(), expandOp.getReassociationIndices(), outputOfr); rewriter.replaceOpWithNewOp(expandOp, expandOp.getType(), newExpand.getResult()); @@ -2558,10 +2560,11 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern { // Create folded extract. 
Location loc = sliceOp.getLoc(); - Value newResult = rewriter.create( - loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(), - sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), - sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); + Value newResult = ExtractSliceOp::create( + rewriter, loc, sliceOp.getType(), castOp.getSource(), + sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(), + sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(), + sliceOp.getStaticStrides()); rewriter.replaceOp(sliceOp, newResult); return success(); } @@ -2712,8 +2715,8 @@ struct SliceCanonicalizer { ExtractSliceOp newOp) { Value replacement = newOp.getResult(); if (replacement.getType() != op.getType()) - replacement = rewriter.create(op.getLoc(), op.getType(), - replacement); + replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(), + replacement); rewriter.replaceOp(op, replacement); } }; @@ -2981,8 +2984,8 @@ class InsertSliceOpConstantArgumentFolder final // the parallel case. if (std::is_same::value) rewriter.setInsertionPoint(insertSliceOp->getParentOp()); - toInsert = rewriter.create(insertSliceOp.getLoc(), - sourceType, toInsert); + toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(), + sourceType, toInsert); } rewriter.replaceOpWithNewOp( insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets, @@ -3078,17 +3081,18 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern { if (!sliceResult.isValid) return failure(); - Operation *replacement = rewriter.create( - insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(), - mixedSizes, insertSliceOp.getMixedStrides()); + Operation *replacement = + InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst, + insertSliceOp.getMixedOffsets(), mixedSizes, + insertSliceOp.getMixedStrides()); // In the parallel case there is no result and so nothing to cast. 
bool isParallelInsert = std::is_same::value; if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) { - replacement = rewriter.create(insertSliceOp.getLoc(), - insertSliceOp.getDestType(), - replacement->getResult(0)); + replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(), + insertSliceOp.getDestType(), + replacement->getResult(0)); } rewriter.replaceOp(insertSliceOp, replacement->getResults()); return success(); @@ -3157,8 +3161,8 @@ struct InsertSliceOpSourceCastInserter final // parallel case. if (std::is_same::value) rewriter.setInsertionPoint(insertSliceOp->getParentOp()); - Value cast = rewriter.create( - insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource()); + Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(), + newSrcType, insertSliceOp.getSource()); rewriter.replaceOpWithNewOp( insertSliceOp, cast, insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), @@ -3356,7 +3360,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, // a guard to reset the insertion point of the builder after it is destroyed. 
OpBuilder::InsertionGuard guard(b); b.createBlock(region, region->end(), blockArgTypes, blockArgLocs); - b.create(result.location, constantPadValue); + tensor::YieldOp::create(b, result.location, constantPadValue); } llvm::SmallBitVector PadOp::getPaddedDims() { @@ -3410,10 +3414,11 @@ struct FoldSourceTensorCast : public OpRewritePattern { padTensorOp.getSourceMutable().assign(castOp.getSource()); }); } else { - auto newOp = rewriter.create( - padTensorOp->getLoc(), newResultType, padTensorOp.getSource(), - padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), - padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(), + auto newOp = PadOp::create( + rewriter, padTensorOp->getLoc(), newResultType, + padTensorOp.getSource(), padTensorOp.getStaticLow(), + padTensorOp.getStaticHigh(), padTensorOp.getLow(), + padTensorOp.getHigh(), padTensorOp.getNofold(), getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); IRMapping mapper; padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper); @@ -3442,8 +3447,8 @@ struct FoldTargetTensorCast : public OpRewritePattern { tensorCastOp.getDest().getType())) return failure(); - auto replacementOp = rewriter.create( - padTensorOp.getLoc(), tensorCastOp.getDest().getType(), + auto replacementOp = PadOp::create( + rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(), padTensorOp.getSource(), padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(), @@ -3600,11 +3605,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern { // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs // the two paddings in one step. 
- auto newSliceOp = rewriter.create( - padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes, - innerSliceOp.getMixedStrides()); - auto newPadOp = rewriter.create( - padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(), + auto newSliceOp = ExtractSliceOp::create( + rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets, + newSizes, innerSliceOp.getMixedStrides()); + auto newPadOp = PadOp::create( + rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(), padOp.getMixedLowPad(), newHighPad, padOp.getNofold(), getPrunedAttributeList(padOp, PadOp::getAttributeNames())); rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(), @@ -3700,9 +3705,9 @@ struct FoldStaticPadding : public OpRewritePattern { // Rewrite the op using the new static type. auto newResultType = RankedTensorType::get( newOutDims, padTensorOp.getType().getElementType()); - auto newOp = rewriter.create( - padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh, - newLows, newHighs, padTensorOp.getNofold(), + auto newOp = PadOp::create( + rewriter, padTensorOp->getLoc(), newResultType, input, staticLow, + staticHigh, newLows, newHighs, padTensorOp.getNofold(), getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); IRMapping mapper; @@ -3780,9 +3785,9 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern { SmallVector newLowPad = addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad()); - auto newPadOp = rewriter.create( - padOp.getLoc(), padOp.getResultType(), producerPad.getSource(), - newLowPad, newHighPad, padOp.getNofold(), + auto newPadOp = tensor::PadOp::create( + rewriter, padOp.getLoc(), padOp.getResultType(), + producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(), getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames())); rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(), newPadOp.getRegion().begin()); @@ -3806,7 +3811,7 @@ PadOp::reifyResultShapes(OpBuilder 
&b, } Location loc = getLoc(); Value dim = b.createOrFold( - loc, getSource(), b.create(loc, i)); + loc, getSource(), arith::ConstantIndexOp::create(b, loc, i)); AffineExpr d0, d1, d2; bindDims(b.getContext(), d0, d1, d2); @@ -4111,8 +4116,8 @@ struct FoldTensorCastProducerOp for (auto [oldResult, newResult] : llvm::zip(op->getResults(), newOp->getResults())) { if (newResult.getType() != oldResult.getType()) { - replacements.push_back(rewriter.create( - op->getLoc(), oldResult.getType(), newResult)); + replacements.push_back(tensor::CastOp::create( + rewriter, op->getLoc(), oldResult.getType(), newResult)); } else { replacements.push_back(newResult); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 92540bd56ecbc..e6cb8bdbc472a 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -211,13 +211,13 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, if (isZeroInteger(newLength)) { hasZeroLen = true; } else if (!hasZeroLen) { - Value check = b.create( + Value check = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, getValueOrCreateConstantIndexOp(b, loc, newLength), getValueOrCreateConstantIndexOp(b, loc, zero)); dynHasZeroLenCond = dynHasZeroLenCond - ? b.create(loc, check, dynHasZeroLenCond) + ? arith::OrIOp::create(b, loc, check, dynHasZeroLenCond) : check; } @@ -241,7 +241,7 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, auto castResult = [&](Value val) -> Value { if (resultType == val.getType()) return val; - return b.create(loc, resultType, val); + return tensor::CastOp::create(b, loc, resultType, val); }; // In cases where the original data source is unused: Emit a GenerateOp and @@ -249,10 +249,10 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // have a dimension of size 0, the semantics of which is unclear.) auto createGenerateOp = [&]() { // Create GenerateOp. 
- auto generateOp = b.create( + auto generateOp = tensor::GenerateOp::create(b, loc, resultType, dynDims, [&](OpBuilder &builder, Location gLoc, ValueRange indices) { - builder.create(gLoc, padValue); + tensor::YieldOp::create(builder, gLoc, padValue); }); return generateOp; }; @@ -261,9 +261,9 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // the result shape of the new SliceOp has a zero dimension. auto createPadOfExtractSlice = [&]() { // Create pad(extract_slice(x)). - auto newSliceOp = b.create( + auto newSliceOp = tensor::ExtractSliceOp::create(b, loc, padOp.getSource(), newOffsets, newLengths, newStrides); - auto newPadOp = b.create( + auto newPadOp = PadOp::create(b, loc, Type(), newSliceOp, newLows, newHighs, /*nofold=*/padOp.getNofold(), getPrunedAttributeList(padOp, PadOp::getAttributeNames())); @@ -291,17 +291,17 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, Operation *thenOp; Operation *elseOp; Operation *sliceOp; - auto result = b.create( + auto result = scf::IfOp::create(b, loc, dynHasZeroLenCond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { thenOp = createGenerateOp(); - b.create(loc, castResult(thenOp->getResult(0))); + scf::YieldOp::create(b, loc, castResult(thenOp->getResult(0))); }, /*elseBuilder=*/ [&](OpBuilder &b, Location loc) { std::tie(elseOp, sliceOp) = createPadOfExtractSlice(); - b.create(loc, castResult(elseOp->getResult(0))); + scf::YieldOp::create(b, loc, castResult(elseOp->getResult(0))); }); return TilingResult{ {elseOp}, SmallVector(result->getResults()), {sliceOp}}; diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 723731b8bed61..b8b1a592fdca2 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 
#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -165,7 +166,7 @@ void transform::TypeConversionCastShapeDynamicDimsOp:: if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { return Value(); } - return builder.create(loc, resultType, input).getResult(); + return tensor::CastOp::create(builder, loc, resultType, input).getResult(); }); converter.addTargetMaterialization([](OpBuilder &builder, Type resultType, ValueRange inputs, @@ -177,7 +178,7 @@ void transform::TypeConversionCastShapeDynamicDimsOp:: if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { return Value(); } - return builder.create(loc, resultType, input).getResult(); + return tensor::CastOp::create(builder, loc, resultType, input).getResult(); }); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 829b2ab92ac24..381cea7b93e41 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -223,7 +223,7 @@ struct CollapseShapeOpInterface MemRefType::get(collapseShapeOp.getSrcType().getShape(), collapseShapeOp.getSrcType().getElementType(), AffineMap(), bufferType.getMemorySpace()); - buffer = rewriter.create( + buffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(), memrefType, *tensorAlloc); } @@ -350,7 +350,7 @@ struct ExpandShapeOpInterface if (failed(buffer)) return failure(); - auto memrefExpandShape = rewriter.create( + auto memrefExpandShape = memref::ExpandShapeOp::create(rewriter, op->getLoc(), tensorResultType.getShape(), *buffer, expandShapeOp.getReassociationIndices(), expandShapeOp.getMixedOutputShape()); @@ -399,7 +399,7 @@ struct ExtractSliceOpInterface extractSliceOp.getResult(), options, state); if (failed(resultMemrefType)) return failure(); - Value 
subView = rewriter.create( + Value subView = memref::SubViewOp::create(rewriter, loc, llvm::cast(*resultMemrefType), *srcMemref, mixedOffsets, mixedSizes, mixedStrides); @@ -470,7 +470,7 @@ static void createStores(RewriterBase &rewriter, Location loc, int dim, if (dim == static_cast(shape.size()) - 1) { for (int i = 0; i < shape.back(); ++i) { indices.back() = constants[i]; - rewriter.create(loc, *elementIt, buffer, indices); + memref::StoreOp::create(rewriter, loc, *elementIt, buffer, indices); ++elementIt; } return; @@ -508,7 +508,7 @@ struct FromElementsOpInterface bufferization::getBufferType(*tensorAlloc, options, state); if (failed(memrefType)) return failure(); - Value buffer = rewriter.create( + Value buffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(), *memrefType, *tensorAlloc); // Case: tensor<0xelem_type>. @@ -519,7 +519,7 @@ struct FromElementsOpInterface // Case: tensor. if (shape.empty()) { - rewriter.create( + memref::StoreOp::create(rewriter, loc, fromElementsOp.getElements().front(), buffer); replaceOpWithBufferizedValues(rewriter, op, buffer); return success(); @@ -530,7 +530,7 @@ struct FromElementsOpInterface SmallVector constants; constants.reserve(maxDim); for (int i = 0; i < maxDim; ++i) - constants.push_back(rewriter.create(loc, i)); + constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i)); // Traverse all `elements` and create `memref.store` ops. auto elementIt = fromElementsOp.getElements().begin(); @@ -577,7 +577,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, // Create linalg::MapOp. 
OpBuilder::InsertionGuard g(rewriter); auto linalgOp = - rewriter.create(loc, tensorType, /*inputs=*/ValueRange(), + linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(), /*init=*/tensorDestination); Block &linalgBody = linalgOp.getMapper().emplaceBlock(); @@ -585,7 +585,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, rewriter.setInsertionPointToStart(&linalgBody); SmallVector indices; for (int64_t dim = 0; dim < tensorType.getRank(); ++dim) - indices.push_back(rewriter.create(loc, dim)); + indices.push_back(linalg::IndexOp::create(rewriter, loc, dim)); // Move over body. rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices); @@ -645,7 +645,7 @@ struct InsertOpInterface getBuffer(rewriter, insertOp.getDest(), options, state); if (failed(destMemref)) return failure(); - rewriter.create(insertOp.getLoc(), insertOp.getScalar(), + memref::StoreOp::create(rewriter, insertOp.getLoc(), insertOp.getScalar(), *destMemref, insertOp.getIndices()); replaceOpWithBufferizedValues(rewriter, op, *destMemref); return success(); @@ -714,7 +714,7 @@ struct InsertSliceOpInterface memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getShape(), dstMemrefType, mixedOffsets, mixedSizes, mixedStrides); - Value subView = rewriter.create( + Value subView = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, mixedStrides); @@ -797,13 +797,13 @@ struct PadOpInterface for (int64_t i = 0; i < resultType.getRank(); ++i) { if (!resultType.isDynamicDim(i)) continue; - Value srcDim = rewriter.create(loc, padOp.getSource(), i); + Value srcDim = tensor::DimOp::create(rewriter, loc, padOp.getSource(), i); Value lowPad = toValue(mixedLowPad[i]); Value highPad = toValue(mixedHighPad[i]); AffineExpr s0, s1, s2; bindSymbols(op->getContext(), s0, s1, s2); AffineExpr sumExpr = s0 + s1 + s2; - Value sum = rewriter.create( + Value sum = affine::AffineApplyOp::create(rewriter, 
loc, sumExpr, ValueRange{srcDim, lowPad, highPad}); dynamicSizes.push_back(sum); } @@ -996,7 +996,7 @@ struct ParallelInsertSliceOpInterface parallelInsertSliceOp.getMixedOffsets(), parallelInsertSliceOp.getMixedSizes(), parallelInsertSliceOp.getMixedStrides()); - Value subview = rewriter.create( + Value subview = memref::SubViewOp::create(rewriter, parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer, parallelInsertSliceOp.getMixedOffsets(), parallelInsertSliceOp.getMixedSizes(), @@ -1067,13 +1067,13 @@ struct SplatOpInterface return op->emitError("memory space not implemented yet"); auto linalgOp = - rewriter.create(loc, tensorType, /*inputs=*/ValueRange(), + linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(), /*init=*/*tensorAlloc); Block &linalgBody = linalgOp.getMapper().emplaceBlock(); // Create linalg::IndexOps. rewriter.setInsertionPointToStart(&linalgBody); - rewriter.create(loc, splatOp.getInput()); + linalg::YieldOp::create(rewriter, loc, splatOp.getInput()); rewriter.replaceOp(splatOp, linalgOp.getResult()[0]); return success(); @@ -1127,7 +1127,7 @@ struct ConcatOpInterface MemRefType memrefType = MemRefType::get(concatOp.getResultType().getShape(), concatOp.getResultType().getElementType(), layout); - Value dstBuffer = rewriter.create( + Value dstBuffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(), memrefType, *tensorAlloc); // Extract the dimension for the concat op @@ -1143,7 +1143,7 @@ struct ConcatOpInterface for (const auto &[dimIdx, dimSize] : llvm::enumerate(tensorType.getShape())) { if (dimSize == ShapedType::kDynamic) { - auto dimOp = rewriter.create(loc, dstBuffer, dimIdx); + auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx); sizes.push_back(dimOp.getResult()); if (dimIdx == concatDim) dynamicConcatDim = true; @@ -1158,7 +1158,7 @@ struct ConcatOpInterface if (dynamicConcatDim) { // One or more operands have dynamic size, so we must accumulate the // offset with arith ops. 
- dynamicOffset = rewriter.create(loc, 0); + dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); } for (auto operand : concatOp.getInputs()) { @@ -1175,7 +1175,7 @@ struct ConcatOpInterface if (dynamicConcatDim) { offsets[concatDim] = dynamicOffset.value(); - dynamicSize = rewriter.create(loc, *srcBuffer, concatDim) + dynamicSize = memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim) .getResult(); sizes[concatDim] = dynamicSize.value(); } else { @@ -1189,7 +1189,7 @@ struct ConcatOpInterface memref::SubViewOp::inferRankReducedResultType( operandTensorType.getShape(), dstMemrefType, offsets, sizes, strides); - Value subview = rewriter.create( + Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, dstBuffer, offsets, sizes, strides); // Copy the source buffer into the destination subview. @@ -1197,7 +1197,7 @@ struct ConcatOpInterface return failure(); if (dynamicConcatDim) { - dynamicOffset = rewriter.create( + dynamicOffset = arith::AddIOp::create(rewriter, loc, dynamicOffset.value(), dynamicSize.value()); } else { concatDimOffset += operandConcatDimSize; diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp index fa748cf01977f..5b28791126bd7 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp @@ -43,7 +43,7 @@ struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { // Create new tensor.empty op. // TODO: Do not drop tensor type encoding. 
- Value emptyTensor = rewriter.create( + Value emptyTensor = EmptyOp::create(rewriter, loc, resultShapes[0], reshapeOp.getResultType().getElementType()); if (emptyTensor.getType() != reshapeOp.getResultType()) { rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp index e0acaee9f6626..50c72c329ed95 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp @@ -55,7 +55,7 @@ static ValueRange invertCollapseShapeIndexing( for (int64_t i : reassociation[dim]) basis.push_back(reshapeSourceShape[i]); auto delinearized = - b.create(loc, indexValue, basis); + AffineDelinearizeIndexOp::create(b, loc, indexValue, basis); return delinearized->getResults(); } @@ -144,14 +144,14 @@ tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody( SmallVector extractParams = helper.getExtractSliceParams(builder.getContext(), multiIndices); - Value subTileResult = builder.create( + Value subTileResult = tensor::ExtractSliceOp::create(builder, loc, collapseShapeOp.getSrc(), extractParams); SmallVector insertParams = helper.getInsertSliceParams(builder.getContext(), tileInductionVars); // Collapse the dimensions of the source slice back down. 
- Value collapsedResult = builder.create( + Value collapsedResult = tensor::CollapseShapeOp::create(builder, loc, subTileResult, reassociationIndices); return std::make_pair(collapsedResult, insertParams); } @@ -175,7 +175,7 @@ tensor::simplifyCollapseShapeWithRankReducingExtractSlice( SmallVector sizes = tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc()); SmallVector strides(sourceType.getRank(), one); - auto sliceOp = rewriter.create( + auto sliceOp = tensor::ExtractSliceOp::create(rewriter, op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides); if (!info->newReassociationIndices.has_value()) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp index a787b485f7162..fd40ece63276e 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -115,7 +115,7 @@ TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp( extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), indices, sourceIndices); - Operation *newOp = rewriter.create( + Operation *newOp = vector::TransferReadOp::create(rewriter, readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices, AffineMapAttr::get(expandDimsToRank( diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp index 4655fa3cf0d23..e3a2854038bdd 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp @@ -65,7 +65,7 @@ FailureOr tensor::buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, return padOp.getResult(); // Create a new tensor::PadOp. 
- auto newPadOp = b.create( + auto newPadOp = PadOp::create(b, loc, padOp.getResultType(), padOp.getSource(), newMixedLow, newMixedHigh, constantPadding, padOp.getNofold(), /*attrs=*/ArrayRef{}); @@ -84,7 +84,7 @@ FailureOr tensor::buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, offsets.push_back(b.getIndexAttr(0)); } else { offsets.push_back( - b.create( + affine::AffineApplyOp::create(b, loc, b.getAffineDimExpr(0) - b.getAffineDimExpr(1), std::initializer_list{cast(newMixedLow[i]), cast(prevLow)}) @@ -100,7 +100,7 @@ FailureOr tensor::buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, strides.push_back(b.getIndexAttr(1)); } - return b.create(loc, newPadOp, offsets, sizes, strides) + return ExtractSliceOp::create(b, loc, newPadOp, offsets, sizes, strides) .getResult(); } @@ -125,7 +125,7 @@ FailureOr tensor::buildIndependentOp(OpBuilder &b, // Create a new tensor::EmptyOp. Value newEmptyOp = - b.create(loc, newSizes, emptyOp.getType().getElementType()); + EmptyOp::create(b, loc, newSizes, emptyOp.getType().getElementType()); // Create a tensor::ExtractSliceOp. 
SmallVector offsets(newSizes.size(), b.getIndexAttr(0)); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 657624b817af2..247ea15640b39 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -77,7 +77,7 @@ struct FoldUnPaddingCollapseIntoExtract return rewriter.notifyMatchFailure(collapseShapeOp, "expected unpadding collapse"); - Value unPaddedExtractSlice = rewriter.create( + Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(), extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); @@ -271,7 +271,7 @@ struct BubbleUpExpandThroughParallelCollapse // matches the number of dimensions of the result, then the expand_shape // is a no-op. if (newExpandReInds.size() != newExpandSizes.size()) { - newCollapseSrc = rewriter.create( + newCollapseSrc = tensor::ExpandShapeOp::create(rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds, newExpandSizes); } @@ -281,7 +281,7 @@ struct BubbleUpExpandThroughParallelCollapse // is a no-op. Value replacement = newCollapseSrc; if (newCollapseReInds.size() != newExpandSizes.size()) { - replacement = rewriter.create( + replacement = tensor::CollapseShapeOp::create(rewriter, loc, newCollapseSrc, newCollapseReInds); } rewriter.replaceOp(expandOp, replacement); @@ -406,7 +406,7 @@ struct BubbleUpExpandShapeThroughExtractSlice shape, expandShapeOp.getResultType().getElementType()); // Create a new ExtractSliceOp and ExpandShapeOp. 
- Value newSliceOp = rewriter.create( + Value newSliceOp = tensor::ExtractSliceOp::create(rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes, collapsedStrides); rewriter.replaceOpWithNewOp( @@ -736,7 +736,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice groupExpandedOffsets.rend()); } - Value newSliceOp = rewriter.create( + Value newSliceOp = tensor::ExtractSliceOp::create(rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets, expandedSizes, expandedStrides); rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp index 7c9fced540adb..69e649d2eebe8 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp @@ -196,7 +196,7 @@ struct PadOpToConstant final : public OpRewritePattern { "tensor type not supported"); if (newOp.getType() != resultType) - newOp = rewriter.create(loc, resultType, newOp); + newOp = tensor::CastOp::create(rewriter, loc, resultType, newOp); rewriter.replaceOp(padTensorOp, newOp); return success(); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 6138821ee8c61..8522c3ed28745 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -48,12 +48,12 @@ struct CastOpInterface if (isa(srcType)) { // Check rank. 
- Value srcRank = builder.create(loc, castOp.getSource()); + Value srcRank = RankOp::create(builder, loc, castOp.getSource()); Value resultRank = - builder.create(loc, resultType.getRank()); - Value isSameRank = builder.create( + arith::ConstantIndexOp::create(builder, loc, resultType.getRank()); + Value isSameRank = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); - builder.create( + cf::AssertOp::create(builder, loc, isSameRank, RuntimeVerifiableOpInterface::generateErrorMessage(op, "rank mismatch")); @@ -71,12 +71,12 @@ struct CastOpInterface continue; Value srcDimSz = - builder.create(loc, castOp.getSource(), it.index()); + DimOp::create(builder, loc, castOp.getSource(), it.index()); Value resultDimSz = - builder.create(loc, it.value()); - Value isSameSz = builder.create( + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameSz = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); - builder.create( + cf::AssertOp::create(builder, loc, isSameSz, RuntimeVerifiableOpInterface::generateErrorMessage( op, "size mismatch of dim " + std::to_string(it.index()))); @@ -90,9 +90,9 @@ struct DimOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto dimOp = cast(op); - Value rank = builder.create(loc, dimOp.getSource()); - Value zero = builder.create(loc, 0); - builder.create( + Value rank = RankOp::create(builder, loc, dimOp.getSource()); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + cf::AssertOp::create(builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), RuntimeVerifiableOpInterface::generateErrorMessage( op, "index is out of bounds")); @@ -125,7 +125,7 @@ struct ExtractInsertOpInterface } auto indices = extractInsertOp.getIndices(); - auto zero = builder.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); Value assertCond; for (auto i : llvm::seq(0, rank)) { 
Value dimOp = builder.createOrFold(loc, tensor, i); @@ -135,7 +135,7 @@ struct ExtractInsertOpInterface i > 0 ? builder.createOrFold(loc, assertCond, inBounds) : inBounds; } - builder.create( + cf::AssertOp::create(builder, loc, assertCond, RuntimeVerifiableOpInterface::generateErrorMessage( op, "out-of-bounds access")); @@ -153,8 +153,8 @@ struct ExtractSliceOpInterface // For each dimension, assert that: // 0 <= offset < dim_size // 0 <= offset + (size - 1) * stride < dim_size - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { Value offset = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedOffsets()[i]); @@ -168,20 +168,20 @@ struct ExtractSliceOpInterface loc, extractSliceOp.getSource(), i); Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - builder.create( + cf::AssertOp::create(builder, loc, offsetInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "offset " + std::to_string(i) + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. 
- Value sizeMinusOne = builder.create(loc, size, one); + Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = - builder.create(loc, sizeMinusOne, stride); + arith::MulIOp::create(builder, loc, sizeMinusOne, stride); Value lastPos = - builder.create(loc, offset, sizeMinusOneTimesStride); + arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - builder.create( + cf::AssertOp::create(builder, loc, lastPosInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp index d50d7c62b789c..21f70e52b8706 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp @@ -59,7 +59,7 @@ struct InsertSliceLikeOpSubsetInsertionOpInterface Value buildSubsetExtraction(Operation *op, OpBuilder &builder, Location loc) const { auto insertSliceOp = cast(op); - auto extractOp = builder.create( + auto extractOp = tensor::ExtractSliceOp::create(builder, loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp index 289296a07d9d3..728feab6a6d13 100644 --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -58,7 +58,7 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source, high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1, {outDim, sourceDim}); } - return b.create(loc, resType, source, low, high, pad, nofold); + return PadOp::create(b, loc, resType, source, low, high, 
pad, nofold); } SmallVector mlir::tensor::createDynamicDimValues(OpBuilder &b, @@ -69,7 +69,7 @@ SmallVector mlir::tensor::createDynamicDimValues(OpBuilder &b, for (const auto &en : llvm::enumerate(tensorTy.getShape())) { if (en.value() == ShapedType::kDynamic) dynamicDims.push_back( - b.create(loc, rankedTensor, en.index())); + tensor::DimOp::create(b, loc, rankedTensor, en.index())); } return dynamicDims; } @@ -121,7 +121,7 @@ mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src, reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end())); nextDimToGroup = setBit + 1; } - return b.create(loc, src, reassocMaps); + return tensor::CollapseShapeOp::create(b, loc, src, reassocMaps); } bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 1d21096e8920b..477d35fc8d9c5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -851,7 +851,7 @@ struct PadSliceOptimization : public OpRewritePattern { getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings); auto newPadTy = RankedTensorType::get(newPadShape, inputTy.getElementType()); - auto newPadOp = rewriter.create( + auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp, padOp.getPadConst()); @@ -903,7 +903,7 @@ struct SliceDynamicSizeCanonicalization } auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes); - auto newSliceOp = rewriter.create( + auto newSliceOp = tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(), sliceOp.getStart(), size_op); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 5170a11523845..48ae70c51c279 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -182,11 +182,11 @@ Operation 
*TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, // Tosa dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. if (llvm::isa(type) && llvm::isa(value)) { - return builder.create( + return tosa::ConstShapeOp::create(builder, loc, type, llvm::cast(value)); } if (llvm::isa(value)) - return builder.create(loc, type, + return tosa::ConstOp::create(builder, loc, type, llvm::cast(value)); return nullptr; } @@ -325,7 +325,7 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc, builder.getFloatAttr(srcElemType, val)) : DenseElementsAttr::get(padConstEType, builder.getIntegerAttr(srcElemType, val))}; - return builder.create(loc, padConstType, padConstAttr); + return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr); } //===----------------------------------------------------------------------===// @@ -2417,7 +2417,7 @@ LogicalResult TransposeOp::reifyResultShapes( int32_t dimInInput = transposePerms[dim]; if (inputType.isDynamicDim(dimInInput)) returnedDims[dim] = - builder.create(getLoc(), input, dimInInput) + tensor::DimOp::create(builder, getLoc(), input, dimInInput) .getResult(); else returnedDims[dim] = @@ -3948,12 +3948,12 @@ std::optional mlir::tosa::createZeroPointTensor(OpBuilder &builder, if (llvm::isa(srcElemType)) { auto zpAttr = DenseElementsAttr::get( zpType, builder.getFloatAttr(srcElemType, static_cast(zp))); - return builder.create(loc, zpType, zpAttr); + return tosa::ConstOp::create(builder, loc, zpType, zpAttr); } if (llvm::isa(srcElemType)) { auto zpAttr = DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp)); - return builder.create(loc, zpType, zpAttr); + return tosa::ConstOp::create(builder, loc, zpType, zpAttr); } llvm::errs() << "zero point is not allowed for unsupported data types\n"; return std::nullopt; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp 
b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 9b4cf85c480d3..72e3c0590183c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -91,12 +91,12 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { if (inputETy != resultETy) { inputType = inputType.clone(resultETy); - input = rewriter.create(op.getLoc(), inputType, input); + input = tosa::CastOp::create(rewriter, op.getLoc(), inputType, input); } if (weightETy != resultETy) { weightType = weightType.clone(resultETy); - weight = rewriter.create(op.getLoc(), weightType, weight); + weight = tosa::CastOp::create(rewriter, op.getLoc(), weightType, weight); } if (iZp != 0 || wZp != 0) { @@ -110,8 +110,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { auto zpTy = RankedTensorType::get(shape, ety); auto zpAttr = DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp)); - auto zpVal = rewriter.create(op.getLoc(), zpTy, zpAttr); - return rewriter.create(op.getLoc(), val.getType(), val, + auto zpVal = tosa::ConstOp::create(rewriter, op.getLoc(), zpTy, zpAttr); + return tosa::SubOp::create(rewriter, op.getLoc(), val.getType(), val, zpVal); }; @@ -139,9 +139,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { auto padTy = RankedTensorType::get({1}, inputETy); auto padAttr = DenseElementsAttr::get(padTy, zeroAttr); Value padVal = - rewriter.create(op->getLoc(), padTy, padAttr); + tosa::ConstOp::create(rewriter, op->getLoc(), padTy, padAttr); inputType = RankedTensorType::get(newShape, inputETy); - input = rewriter.create(op->getLoc(), inputType, input, + input = tosa::PadOp::create(rewriter, op->getLoc(), inputType, input, padSizeVal, padVal); } @@ -162,7 +162,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { auto shiftZeroAttr = DenseElementsAttr::get( shiftType, rewriter.getIntegerAttr(shiftElementType, 0)); Value constZero = - rewriter.create(op.getLoc(), shiftType, 
shiftZeroAttr); + tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr); Value mulValue = rewriter .create(op.getLoc(), mulShapeType, input, weight, constZero) @@ -175,7 +175,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { dyn_cast(input.getType()).getElementType()); auto outputShapeValue = getTosaConstShape(rewriter, op->getLoc(), outputShape); - Value outputValue = rewriter.create( + Value outputValue = tosa::ReshapeOp::create(rewriter, op.getLoc(), outputShapeType, mulValue, outputShapeValue); Value bias = op.getBias(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index ea6ac981b53cc..12af0ce56a8bc 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -64,12 +64,12 @@ class TransposeConvNonStridedConverter convPad[2] = kernelWidth - 1 + pad[2]; convPad[3] = kernelWidth - 1 + pad[3]; - auto reverse1 = rewriter.create( + auto reverse1 = tosa::ReverseOp::create(rewriter, loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1)); - auto reverse2 = rewriter.create( + auto reverse2 = tosa::ReverseOp::create(rewriter, loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2)); - Value conv2d = rewriter.create( + Value conv2d = tosa::Conv2DOp::create(rewriter, loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(), rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), @@ -218,7 +218,7 @@ class TransposeConvStridedConverter inputPaddingVal, inputPadConst); // We use a zero bias as we need to broadcast the bias. 
- auto zeroBias = rewriter.create( + auto zeroBias = tosa::ConstOp::create(rewriter, loc, RankedTensorType::get({outputChannels * stride[0] * stride[1]}, biasETy), diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp index 29ec9f8db2615..9a648015d23df 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -114,7 +114,7 @@ class TypeModificationState { OpBuilder builder{value.getContext()}; builder.setInsertionPointAfter(value.getDefiningOp()); castValue = - builder.create(value.getLoc(), oldType, value); + tensor::CastOp::create(builder, value.getLoc(), oldType, value); } use->set(castValue); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index 7f85cd52f6bde..51a093bd0cc38 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -405,7 +405,7 @@ std::optional TosaReduceTransposes::buildMappedToValue( return std::nullopt; } ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter); - auto foldedReshape = rewriter.create( + auto foldedReshape = ReshapeOp::create(rewriter, reshapeOp.getLoc(), RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms), reshapeOutputType.getElementType()), @@ -425,7 +425,7 @@ std::optional TosaReduceTransposes::buildMappedToValue( if (!maybeNewDenseAttr.has_value()) return std::nullopt; auto newDenseAttr = maybeNewDenseAttr.value(); - auto newConstOp = rewriter.create( + auto newConstOp = ConstOp::create(rewriter, constOp.getLoc(), newDenseAttr.getType(), newDenseAttr); return newConstOp->getResult(0); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp index 3b697a2ee3e47..677d8e9904a67 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp 
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp @@ -37,7 +37,7 @@ void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) { if (inputs.size() != 1) return Value(); - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType, @@ -46,7 +46,7 @@ void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) { if (inputs.size() != 1) return Value(); - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); } diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index 9844abcc34cb1..69ef0ba40d72e 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -33,18 +33,18 @@ mlir::tosa::condenseValues(const SmallVector &values) { Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter) { - Value minValue = rewriter.create(loc, arg, max); - return rewriter.create(loc, minValue, min); + Value minValue = arith::MinimumFOp::create(rewriter, loc, arg, max); + return arith::MaximumFOp::create(rewriter, loc, minValue, min); } Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned) { if (isUnsigned) { - auto minOrArg = rewriter.create(loc, min, arg); - return rewriter.create(loc, max, minOrArg); + auto minOrArg = arith::MaxUIOp::create(rewriter, loc, min, arg); + return arith::MinUIOp::create(rewriter, loc, max, minOrArg); } - auto minOrArg = rewriter.create(loc, min, arg); - return rewriter.create(loc, max, minOrArg); + auto minOrArg = arith::MaxSIOp::create(rewriter, loc, min, arg); + return arith::MinSIOp::create(rewriter, loc, max, minOrArg); } bool 
mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) { @@ -144,7 +144,7 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder, ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape); - auto reshapeLower = builder.create( + auto reshapeLower = tosa::ReshapeOp::create(builder, reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue); if (input1Rank > input2Rank) { @@ -162,7 +162,7 @@ Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef shape) { auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size()); - mlir::Operation *mlir_op = builder.create(type, attr); + mlir::Operation *mlir_op = tosa::ConstShapeOp::create(builder, type, attr); return mlir_op->getResult(0); } diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp index d69535169f956..003d81fbc7771 100644 --- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp +++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp index 12257da878a40..2d6d4e17aff66 100644 --- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include 
"mlir/IR/OpImplementation.h" #include "llvm/Support/InterleavedRange.h" diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 0db0317461c03..d4a9d4ef1d656 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp index 93b0bc591ca02..15c2ae22f56ea 100644 --- a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp index 34d6221d15fb0..fc6ea93b49218 100644 --- a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp index 85f61245eb734..9d5c749f02dfd 100644 --- 
a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Rewrite/PatternApplicator.h" diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp index 5b2cfe7bf4264..442a53ca8162a 100644 --- a/mlir/lib/Dialect/UB/IR/UBOps.cpp +++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc" @@ -52,7 +53,7 @@ void UBDialect::initialize() { Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto attr = dyn_cast(value)) - return builder.create(loc, type, attr); + return PoisonOp::create(builder, loc, type, attr); return nullptr; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 14e626a6b23e3..a8277a2a805e9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -125,7 +126,7 @@ static MaskFormat getMaskFormat(Value mask) { /// Default callback to build a region with a 'vector.yield' terminator with no /// arguments. 
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) { - builder.create(loc); + vector::YieldOp::create(builder, loc); } // Helper for verifying combining kinds in contractions and reductions. @@ -597,16 +598,16 @@ struct ElideUnitDimsInMultiDimReduction VectorType newMaskType = VectorType::get(dstVecType.getShape(), rewriter.getI1Type(), dstVecType.getScalableDims()); - mask = rewriter.create(loc, newMaskType, mask); + mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask); } - cast = rewriter.create( - loc, reductionOp.getDestType(), reductionOp.getSource()); + cast = vector::ShapeCastOp::create( + rewriter, loc, reductionOp.getDestType(), reductionOp.getSource()); } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. if (mask) - mask = rewriter.create(loc, mask); - cast = rewriter.create(loc, reductionOp.getSource()); + mask = vector::ExtractOp::create(rewriter, loc, mask); + cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource()); } Value result = @@ -673,36 +674,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, switch (op) { case arith::AtomicRMWKind::addf: case arith::AtomicRMWKind::addi: - return builder.create(vector.getLoc(), - CombiningKind::ADD, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::ADD, vector); case arith::AtomicRMWKind::mulf: case arith::AtomicRMWKind::muli: - return builder.create(vector.getLoc(), - CombiningKind::MUL, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MUL, vector); case arith::AtomicRMWKind::minimumf: - return builder.create(vector.getLoc(), - CombiningKind::MINIMUMF, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINIMUMF, vector); case arith::AtomicRMWKind::mins: - return builder.create(vector.getLoc(), - CombiningKind::MINSI, vector); + return 
vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINSI, vector); case arith::AtomicRMWKind::minu: - return builder.create(vector.getLoc(), - CombiningKind::MINUI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINUI, vector); case arith::AtomicRMWKind::maximumf: - return builder.create(vector.getLoc(), - CombiningKind::MAXIMUMF, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXIMUMF, vector); case arith::AtomicRMWKind::maxs: - return builder.create(vector.getLoc(), - CombiningKind::MAXSI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXSI, vector); case arith::AtomicRMWKind::maxu: - return builder.create(vector.getLoc(), - CombiningKind::MAXUI, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXUI, vector); case arith::AtomicRMWKind::andi: - return builder.create(vector.getLoc(), - CombiningKind::AND, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::AND, vector); case arith::AtomicRMWKind::ori: - return builder.create(vector.getLoc(), - CombiningKind::OR, vector); + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::OR, vector); // TODO: Add remaining reduction operations. default: (void)emitOptionalError(loc, "Reduction operation type not supported"); @@ -741,8 +742,8 @@ struct ElideSingleElementReduction : public OpRewritePattern { Location loc = reductionOp.getLoc(); if (mask) - mask = rewriter.create(loc, mask); - Value result = rewriter.create(loc, reductionOp.getVector()); + mask = ExtractOp::create(rewriter, loc, mask); + Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector()); if (Value acc = reductionOp.getAcc()) result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), @@ -4171,8 +4172,8 @@ class StridedSliceBroadcast final // just a single scalar. 
bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1); if (!lowerDimMatch && !isScalarSrc) { - source = rewriter.create( - op->getLoc(), source, + source = ExtractStridedSliceOp::create( + rewriter, op->getLoc(), source, getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff), getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff), getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff)); @@ -4266,8 +4267,8 @@ class ContiguousExtractStridedSliceToExtract final SmallVector offsets = getI64SubArray(op.getOffsets()); auto extractOffsets = ArrayRef(offsets).take_front(numOffsets); - Value extract = rewriter.create(op->getLoc(), source, - extractOffsets); + Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source, + extractOffsets); rewriter.replaceOpWithNewOp(op, op.getType(), extract); return success(); } @@ -4297,7 +4298,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, Type elemType = llvm::cast(source.getType()).getElementType(); if (!padding) - padding = builder.create(result.location, elemType); + padding = ub::PoisonOp::create(builder, result.location, elemType); build(builder, result, vectorType, source, indices, permutationMapAttr, *padding, /*mask=*/Value(), inBoundsAttr); } @@ -4315,7 +4316,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, SmallVector(vectorType.getRank(), false)); Type elemType = llvm::cast(source.getType()).getElementType(); if (!padding) - padding = builder.create(result.location, elemType); + padding = ub::PoisonOp::create(builder, result.location, elemType); build(builder, result, vectorType, source, indices, *padding, permutationMapAttr, inBoundsAttr); } @@ -4334,7 +4335,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, SmallVector(vectorType.getRank(), false)); Type elemType = llvm::cast(source.getType()).getElementType(); if (!padding) - padding = builder.create(result.location, elemType); + padding = 
ub::PoisonOp::create(builder, result.location, elemType); build(builder, result, vectorType, source, indices, permutationMapAttr, *padding, /*mask=*/Value(), inBoundsAttr); @@ -4859,7 +4860,7 @@ struct TransferReadAfterWriteToBroadcast VectorType broadcastedType = VectorType::get( broadcastShape, defWrite.getVectorType().getElementType(), broadcastScalableFlags); - vec = rewriter.create(loc, broadcastedType, vec); + vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec); SmallVector transposePerm(permutation.begin(), permutation.end()); rewriter.replaceOpWithNewOp(readOp, vec, transposePerm); @@ -5337,13 +5338,14 @@ struct SwapExtractSliceOfTransferWrite // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp. // Set all in_bounds to false and let the folder infer them. SmallVector newInBounds(vectorShape.size(), false); - auto newExtractOp = rewriter.create( - extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(), - insertOp.getMixedOffsets(), insertOp.getMixedSizes(), - insertOp.getMixedStrides()); - auto newTransferWriteOp = rewriter.create( - transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(), - transferOp.getIndices(), transferOp.getPermutationMapAttr(), + auto newExtractOp = tensor::ExtractSliceOp::create( + rewriter, extractOp.getLoc(), insertOp.getSourceType(), + insertOp.getDest(), insertOp.getMixedOffsets(), + insertOp.getMixedSizes(), insertOp.getMixedStrides()); + auto newTransferWriteOp = TransferWriteOp::create( + rewriter, transferOp.getLoc(), transferOp.getVector(), + newExtractOp.getResult(), transferOp.getIndices(), + transferOp.getPermutationMapAttr(), rewriter.getBoolArrayAttr(newInBounds)); rewriter.modifyOpInPlace(insertOp, [&]() { insertOp.getSourceMutable().assign(newTransferWriteOp.getResult()); @@ -6867,7 +6869,7 @@ void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) { OpBuilder opBuilder(builder.getContext()); Operation *maskedOp = &block.front(); 
opBuilder.setInsertionPointToEnd(&block); - opBuilder.create(loc, maskedOp->getResults()); + vector::YieldOp::create(opBuilder, loc, maskedOp->getResults()); } LogicalResult MaskOp::verify() { @@ -7202,7 +7204,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder, // Create a block and move the op to that block. insBlock->getOperations().splice( insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp); - builder.create(maskableOp->getLoc(), maskableOp->getResults()); + YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults()); } /// Creates a vector.mask operation around a maskable operation. Returns the @@ -7214,12 +7216,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder, if (!mask) return maskableOp; if (passthru) - return builder.create(maskableOp->getLoc(), - maskableOp->getResultTypes(), mask, passthru, - maskableOp, createMaskOpRegion); - return builder.create(maskableOp->getLoc(), - maskableOp->getResultTypes(), mask, maskableOp, - createMaskOpRegion); + return MaskOp::create(builder, maskableOp->getLoc(), + maskableOp->getResultTypes(), mask, passthru, + maskableOp, createMaskOpRegion); + return MaskOp::create(builder, maskableOp->getLoc(), + maskableOp->getResultTypes(), mask, maskableOp, + createMaskOpRegion); } /// Creates a vector select operation that picks values from `newValue` or @@ -7234,8 +7236,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, if (!mask) return newValue; - return builder.create(newValue.getLoc(), newValue.getType(), - mask, newValue, passthru); + return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(), + mask, newValue, passthru); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 125c3d918284c..bb051650897b4 100644 --- 
a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 9da051150e409..c750017a6b504 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -116,7 +116,7 @@ struct TransferWriteOpInterface getBuffer(rewriter, writeOp.getBase(), options, state); if (failed(resultBuffer)) return failure(); - rewriter.create( + vector::TransferWriteOp::create(rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(), writeOp.getMask(), writeOp.getInBoundsAttr()); @@ -241,7 +241,7 @@ struct MaskOpInterface // Create a new vector.mask op. 
ValueRange newYieldedValuesRange(newYieldedValues); TypeRange newResultTypes(newYieldedValuesRange); - auto newOp = rewriter.create( + auto newOp = vector::MaskOp::create(rewriter, op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(), /*maskableOp=*/nullptr, /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {}); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp index 89930a6bd35fa..4c3a04cfb5bfa 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp @@ -64,14 +64,14 @@ class UnrollBitCastOp final : public OpRewritePattern { VectorType::get(shape, resultType.getElementType(), scalableDims); Location loc = op.getLoc(); - Value result = rewriter.create(loc, resultType); + Value result = ub::PoisonOp::create(rewriter, loc, resultType); for (auto position : *unrollIterator) { Value extract = - rewriter.create(loc, op.getSource(), position); + vector::ExtractOp::create(rewriter, loc, op.getSource(), position); Value bitcast = - rewriter.create(loc, bitcastResType, extract); + vector::BitCastOp::create(rewriter, loc, bitcastResType, extract); result = - rewriter.create(loc, bitcast, result, position); + vector::InsertOp::create(rewriter, loc, bitcast, result, position); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index 11dcfe421e0c4..cb8e566869cfd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -52,7 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern { // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. 
if (srcRank <= 1 && dstRank == 1) { - Value ext = rewriter.create(loc, op.getSource()); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); rewriter.replaceOpWithNewOp(op, dstType, ext); return success(); } @@ -70,10 +70,10 @@ class BroadcastOpLowering : public OpRewritePattern { // Duplication. VectorType resType = VectorType::Builder(dstType).dropDim(0); Value bcst = - rewriter.create(loc, resType, op.getSource()); - Value result = rewriter.create(loc, dstType); + vector::BroadcastOp::create(rewriter, loc, resType, op.getSource()); + Value result = ub::PoisonOp::create(rewriter, loc, dstType); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) - result = rewriter.create(loc, bcst, result, d); + result = vector::InsertOp::create(rewriter, loc, bcst, result, d); rewriter.replaceOp(op, result); return success(); } @@ -111,13 +111,13 @@ class BroadcastOpLowering : public OpRewritePattern { VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType, dstType.getScalableDims().drop_front()); - Value result = rewriter.create(loc, dstType); + Value result = ub::PoisonOp::create(rewriter, loc, dstType); if (m == 0) { // Stetch at start. - Value ext = rewriter.create(loc, op.getSource(), 0); - Value bcst = rewriter.create(loc, resType, ext); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0); + Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) - result = rewriter.create(loc, bcst, result, d); + result = vector::InsertOp::create(rewriter, loc, bcst, result, d); } else { // Stetch not at start. 
if (dstType.getScalableDims()[0]) { @@ -125,9 +125,9 @@ class BroadcastOpLowering : public OpRewritePattern { return failure(); } for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { - Value ext = rewriter.create(loc, op.getSource(), d); - Value bcst = rewriter.create(loc, resType, ext); - result = rewriter.create(loc, bcst, result, d); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d); + Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext); + result = vector::InsertOp::create(rewriter, loc, bcst, result, d); } } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index c6627b5ec0d77..b2486c5cbbcf8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -90,17 +90,17 @@ static Value reshapeLoad(Location loc, Value val, VectorType type, // At extraction dimension? if (index == 0) - return rewriter.create(loc, val, pos); + return vector::ExtractOp::create(rewriter, loc, val, pos); // Unroll leading dimensions. VectorType vType = VectorType::Builder(type).dropDim(0); VectorType resType = VectorType::Builder(type).dropDim(index); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, val, d); + Value ext = vector::ExtractOp::create(rewriter, loc, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, load, result, d); + result = vector::InsertOp::create(rewriter, loc, load, result, d); } return result; } @@ -115,15 +115,15 @@ static Value reshapeStore(Location loc, Value val, Value result, return val; // At insertion dimension? 
if (index == 0) - return rewriter.create(loc, val, result, pos); + return vector::InsertOp::create(rewriter, loc, val, result, pos); // Unroll leading dimensions. VectorType vType = VectorType::Builder(type).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, result, d); - Value ins = rewriter.create(loc, val, d); + Value ext = vector::ExtractOp::create(rewriter, loc, result, d); + Value ins = vector::ExtractOp::create(rewriter, loc, val, d); Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, sto, result, d); + result = vector::InsertOp::create(rewriter, loc, sto, result, d); } return result; } @@ -141,7 +141,7 @@ createContractArithOp(Location loc, Value x, Value y, Value acc, kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF) // Only valid for floating point types. return std::nullopt; - mul = rewriter.create(loc, x, y); + mul = arith::MulIOp::create(rewriter, loc, x, y); } else { // Float case. if (kind == CombiningKind::AND || kind == CombiningKind::MINUI || @@ -152,14 +152,14 @@ createContractArithOp(Location loc, Value x, Value y, Value acc, return std::nullopt; // Special case for fused multiply-add. if (acc && isa(acc.getType()) && kind == CombiningKind::ADD) { - Value fma = rewriter.create(loc, x, y, acc); + Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc); if (mask) // The fma op doesn't need explicit masking. However, fma ops used in // reductions must preserve previous 'acc' values for masked-out lanes. 
fma = selectPassthru(rewriter, mask, fma, acc); return fma; } - mul = rewriter.create(loc, x, y); + mul = arith::MulFOp::create(rewriter, loc, x, y); } if (!acc) @@ -195,8 +195,8 @@ static std::optional getDimPosition(AffineMap map, unsigned dim) { static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter) { if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); + return arith::AddIOp::create(rewriter, loc, x, y); + return arith::AddFOp::create(rewriter, loc, x, y); } /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using @@ -204,8 +204,8 @@ static Value createAdd(Location loc, Value x, Value y, bool isInt, static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter) { if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); + return arith::MulIOp::create(rewriter, loc, x, y); + return arith::MulFOp::create(rewriter, loc, x, y); } namespace { @@ -411,7 +411,7 @@ struct UnrolledOuterProductGenerator Value t(Value v, ArrayRef perm = {1, 0}) { if (!v) return v; - return rewriter.create(loc, v, perm); + return vector::TransposeOp::create(rewriter, loc, v, perm); } Value promote(Value v, Type dstElementType) { @@ -425,8 +425,8 @@ struct UnrolledOuterProductGenerator if (vecType) promotedType = vecType.clone(promotedType); if (isa(dstElementType)) - return rewriter.create(loc, promotedType, v); - return rewriter.create(loc, promotedType, v); + return arith::ExtFOp::create(rewriter, loc, promotedType, v); + return arith::ExtSIOp::create(rewriter, loc, promotedType, v); } FailureOr outerProd(Value lhs, Value rhs, Value res, @@ -438,16 +438,16 @@ struct UnrolledOuterProductGenerator Type resElementType = cast(res.getType()).getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { - Value extractA = rewriter.create(loc, lhs, k); - Value extractB = rewriter.create(loc, rhs, k); + Value extractA = 
vector::ExtractOp::create(rewriter, loc, lhs, k); + Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k); extractA = promote(extractA, resElementType); extractB = promote(extractB, resElementType); Value extractMask; if (maybeMask.has_value() && maybeMask.value()) extractMask = - rewriter.create(loc, maybeMask.value(), k); + vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k); - Operation *outerProdOp = rewriter.create( + Operation *outerProdOp = vector::OuterProductOp::create(rewriter, loc, res.getType(), extractA, extractB, res, kind); res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0); } @@ -698,28 +698,28 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( // Two outer parallel, one inner reduction (matmat flavor). // if (maps == infer({{m, k}, {k, n}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { // No need to permute anything. } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); - rhs = rewriter.create(loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { // This is the classical row-major matmul. Just permute the lhs. 
Value tmp = lhs; - lhs = rewriter.create(loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); rhs = tmp; } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { std::swap(lhs, rhs); } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { Value tmp = lhs; - lhs = rewriter.create(loc, rhs, perm); - rhs = rewriter.create(loc, tmp, perm); + lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm); } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { Value tmp = rhs; - rhs = rewriter.create(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); lhs = tmp; } else { return failure(); @@ -732,12 +732,12 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( if (maps == infer({{m, n}, {n}, {m}})) { // No need to permute anything. } else if (maps == infer({{n, m}, {n}, {m}})) { - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{n}, {m, n}, {m}})) { std::swap(lhs, rhs); } else if (maps == infer({{n}, {n, m}, {m}})) { std::swap(lhs, rhs); - lhs = rewriter.create(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else { return failure(); } @@ -754,31 +754,31 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; // ExtractOp does not allow dynamic indexing, we must unroll explicitly. 
- Value res = rewriter.create(loc, dstType, + Value res = arith::ConstantOp::create(rewriter, loc, dstType, rewriter.getZeroAttr(dstType)); bool isInt = isa(dstType.getElementType()); llvm::SmallVector extractedCols; extractedCols.reserve(dstColumns); for (unsigned r = 0; r < dstRows; ++r) { - Value rowLhs = rewriter.create(op.getLoc(), lhs, r); + Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { // Extract each respective row and column of the LHS and RHS once to // avoid having duplicate SSA values pointing to the same rows/columns. if (r == 0) { Value colRhs = rank == 1 ? rhs - : rewriter.create(op.getLoc(), rhs, c); + : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c); extractedCols.push_back(colRhs); } Value extractedColRhs = extractedCols[c]; Value product = createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter); - Value sum = rewriter.create( + Value sum = vector::ReductionOp::create(rewriter, op.getLoc(), vector::CombiningKind::ADD, product); SmallVector pos = rank == 1 ? 
SmallVector{r} : SmallVector{r, c}; - res = rewriter.create(op.getLoc(), sum, res, pos); + res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos); } } if (auto acc = op.getAcc()) @@ -879,21 +879,21 @@ struct ContractOpToElementwise lhsDims.append(lhsShape.begin(), lhsShape.end()); auto expandedType = VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); - newLhs = rewriter.create(loc, expandedType, newLhs); + newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs); } if (!rhsDims.empty()) { rhsDims.append(rhsShape.begin(), rhsShape.end()); auto expandedType = VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); - newRhs = rewriter.create(loc, expandedType, newRhs); + newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs); } bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); - newLhs = rewriter.create(loc, newLhs, lhsTranspose); - newRhs = rewriter.create(loc, newRhs, rhsTranspose); + newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose); + newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose); SmallVector lhsOffsets(lhsReductionDims.size(), 0); SmallVector rhsOffsets(rhsReductionDims.size(), 0); - newLhs = rewriter.create(loc, newLhs, lhsOffsets); - newRhs = rewriter.create(loc, newRhs, rhsOffsets); + newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets); + newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets); std::optional result = createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), contractOp.getKind(), rewriter, isInt); @@ -1097,7 +1097,7 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); // Unroll into a series of lower dimensional vector.contract ops. 
Location loc = op.getLoc(); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0; d < dimSize; ++d) { @@ -1110,7 +1110,7 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, lowMask = reshapeLoad(loc, mask, cast(mask.getType()), iterIndex, d, rewriter); - Operation *lowContract = rewriter.create( + Operation *lowContract = vector::ContractionOp::create(rewriter, loc, lhs, rhs, acc, lowAffine, lowIter); lowContract = maskOperation(rewriter, lowContract, lowMask); result = reshapeStore(loc, lowContract->getResult(0), result, resType, @@ -1161,8 +1161,8 @@ FailureOr ContractionOpLowering::lowerReduction( Value acc = op.getAcc(); Operation *reductionOp = - acc ? rewriter.create(loc, kind, m, acc) - : rewriter.create(loc, kind, m); + acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc) + : vector::ReductionOp::create(rewriter, loc, kind, m); return maskOperation(rewriter, reductionOp, mask)->getResult(0); } // Construct new iterator types and affine map array attribute. @@ -1186,7 +1186,7 @@ FailureOr ContractionOpLowering::lowerReduction( newMask = reshapeLoad(loc, mask, cast(mask.getType()), iterIndex, d, rewriter); - Operation *newContract = rewriter.create( + Operation *newContract = vector::ContractionOp::create(rewriter, loc, lhs, rhs, result, lowAffine, lowIter); result = maskOperation(rewriter, newContract, newMask)->getResult(0); } @@ -1240,7 +1240,7 @@ class OuterProductOpLowering : public OpRewritePattern { if (!rhsType) { // Special case: AXPY operation. 
- Value b = rewriter.create(loc, lhsType, op.getRhs()); + Value b = vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs()); std::optional mult = createContractArithOp( loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); if (!mult.has_value()) @@ -1249,23 +1249,23 @@ class OuterProductOpLowering : public OpRewritePattern { return success(); } - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { - Value x = rewriter.create(loc, op.getLhs(), d); - Value a = rewriter.create(loc, rhsType, x); + Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d); + Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x); Value r = nullptr; if (acc) - r = rewriter.create(loc, acc, d); + r = vector::ExtractOp::create(rewriter, loc, acc, d); Value extrMask; if (mask) - extrMask = rewriter.create(loc, mask, d); + extrMask = vector::ExtractOp::create(rewriter, loc, mask, d); std::optional m = createContractArithOp( loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); - result = rewriter.create(loc, *m, result, d); + result = vector::InsertOp::create(rewriter, loc, *m, result, d); } rewriter.replaceOp(rootOp, result); @@ -1335,7 +1335,7 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( Value lhs = op.getLhs(); auto lhsMap = op.getIndexingMapsArray()[0]; if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) - lhs = rew.create(loc, lhs, ArrayRef{1, 0}); + lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef{1, 0}); else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) return failure(); @@ -1343,7 +1343,7 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( Value rhs = op.getRhs(); auto rhsMap = op.getIndexingMapsArray()[1]; if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) - rhs = rew.create(loc, rhs, ArrayRef{1, 0}); + rhs = 
vector::TransposeOp::create(rew, loc, rhs, ArrayRef{1, 0}); else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) return failure(); @@ -1356,15 +1356,15 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( Type flattenedLHSType = VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - lhs = rew.create(loc, flattenedLHSType, lhs); + lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs); Type flattenedRHSType = VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - rhs = rew.create(loc, flattenedRHSType, rhs); + rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs); - Value mul = rew.create(loc, lhs, rhs, lhsRows, lhsColumns, + Value mul = vector::MatmulOp::create(rew, loc, lhs, rhs, lhsRows, lhsColumns, rhsColumns); - mul = rew.create( + mul = vector::ShapeCastOp::create(rew, loc, VectorType::get({lhsRows, rhsColumns}, getElementTypeOrSelf(op.getAcc().getType())), @@ -1373,15 +1373,15 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( // ACC must be C(m, n) or C(n, m). auto accMap = op.getIndexingMapsArray()[2]; if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) - mul = rew.create(loc, mul, ArrayRef{1, 0}); + mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef{1, 0}); else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) llvm_unreachable("invalid contraction semantics"); Value res = isa(elementType) - ? static_cast(rew.create(loc, op.getAcc(), mul)) + ? 
static_cast(arith::AddIOp::create(rew, loc, op.getAcc(), mul)) : static_cast( - rew.create(loc, op.getAcc(), mul)); + arith::AddFOp::create(rew, loc, op.getAcc(), mul)); return res; } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 3000204c8ce17..c586059ca55bc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -75,7 +75,7 @@ struct UnrollGather : OpRewritePattern { Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resultTy, rewriter.getZeroAttr(resultTy)); VectorType subTy = VectorType::Builder(resultTy).dropDim(0); @@ -84,16 +84,16 @@ struct UnrollGather : OpRewritePattern { int64_t thisIdx[1] = {i}; Value indexSubVec = - rewriter.create(loc, indexVec, thisIdx); + vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); Value maskSubVec = - rewriter.create(loc, maskVec, thisIdx); + vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = - rewriter.create(loc, passThruVec, thisIdx); - Value subGather = rewriter.create( + vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); + Value subGather = vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec, passThruSubVec); result = - rewriter.create(loc, subGather, result, thisIdx); + vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); } rewriter.replaceOp(op, result); @@ -159,22 +159,22 @@ struct RemoveStrideFromGatherSource : OpRewritePattern { // 1. Collapse the input memref so that it's "flat". SmallVector reassoc = {{0, 1}}; - Value collapsed = rewriter.create( + Value collapsed = memref::CollapseShapeOp::create(rewriter, op.getLoc(), subview.getSource(), reassoc); // 2. Generate new gather indices that will model the // strided access. 
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim); VectorType vType = op.getIndexVec().getType(); - Value mulCst = rewriter.create( + Value mulCst = arith::ConstantOp::create(rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride)); Value newIdxs = - rewriter.create(op.getLoc(), op.getIndexVec(), mulCst); + arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst); // 3. Create an updated gather op with the collapsed input memref and the // updated indices. - Value newGather = rewriter.create( + Value newGather = vector::GatherOp::create(rewriter, op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(), newIdxs, op.getMask(), op.getPassThru()); rewriter.replaceOp(op, newGather); @@ -229,8 +229,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern { for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) { int64_t thisIdx[1] = {i}; Value condition = - rewriter.create(loc, condMask, thisIdx); - Value index = rewriter.create(loc, indexVec, thisIdx); + vector::ExtractOp::create(rewriter, loc, condMask, thisIdx); + Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); baseOffsets.back() = rewriter.createOrFold(loc, lastBaseOffset, index); @@ -240,19 +240,19 @@ struct Gather1DToConditionalLoads : OpRewritePattern { // `vector.load` does not support scalar result; emit a vector load // and extract the single result instead. 
Value load = - b.create(loc, elemVecTy, base, baseOffsets); + vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets); int64_t zeroIdx[1] = {0}; - extracted = b.create(loc, load, zeroIdx); + extracted = vector::ExtractOp::create(b, loc, load, zeroIdx); } else { - extracted = b.create(loc, base, baseOffsets); + extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets); } Value newResult = - b.create(loc, extracted, result, thisIdx); - b.create(loc, newResult); + vector::InsertOp::create(b, loc, extracted, result, thisIdx); + scf::YieldOp::create(b, loc, newResult); }; auto passThruBuilder = [result](OpBuilder &b, Location loc) { - b.create(loc, result); + scf::YieldOp::create(b, loc, result); }; result = diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index cab0f213b14a9..b0affbb699d9e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -60,14 +60,14 @@ class UnrollInterleaveOp final : public OpRewritePattern { return failure(); auto loc = op.getLoc(); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resultType, rewriter.getZeroAttr(resultType)); for (auto position : *unrollIterator) { - Value extractLhs = rewriter.create(loc, op.getLhs(), position); - Value extractRhs = rewriter.create(loc, op.getRhs(), position); + Value extractLhs = ExtractOp::create(rewriter, loc, op.getLhs(), position); + Value extractRhs = ExtractOp::create(rewriter, loc, op.getRhs(), position); Value interleave = - rewriter.create(loc, extractLhs, extractRhs); - result = rewriter.create(loc, interleave, result, position); + InterleaveOp::create(rewriter, loc, extractLhs, extractRhs); + result = InsertOp::create(rewriter, loc, interleave, result, position); } rewriter.replaceOp(op, result); @@ -123,19 +123,19 @@ class UnrollDeinterleaveOp final return failure(); 
auto loc = op.getLoc(); - Value emptyResult = rewriter.create( + Value emptyResult = arith::ConstantOp::create(rewriter, loc, resultType, rewriter.getZeroAttr(resultType)); Value evenResult = emptyResult; Value oddResult = emptyResult; for (auto position : *unrollIterator) { auto extractSrc = - rewriter.create(loc, op.getSource(), position); + vector::ExtractOp::create(rewriter, loc, op.getSource(), position); auto deinterleave = - rewriter.create(loc, extractSrc); - evenResult = rewriter.create( + vector::DeinterleaveOp::create(rewriter, loc, extractSrc); + evenResult = vector::InsertOp::create(rewriter, loc, deinterleave.getRes1(), evenResult, position); - oddResult = rewriter.create(loc, deinterleave.getRes2(), + oddResult = vector::InsertOp::create(rewriter, loc, deinterleave.getRes2(), oddResult, position); } rewriter.replaceOp(op, ValueRange{evenResult, oddResult}); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index ba21092d2af3c..edd840d8fbc7b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -67,19 +67,19 @@ class CreateMaskOpLowering : public OpRewritePattern { Value idx = op.getOperand(0); VectorType lowType = VectorType::Builder(dstType).dropDim(0); - Value trueVal = rewriter.create( + Value trueVal = vector::CreateMaskOp::create(rewriter, loc, lowType, op.getOperands().drop_front()); - Value falseVal = rewriter.create( + Value falseVal = arith::ConstantOp::create(rewriter, loc, lowType, rewriter.getZeroAttr(lowType)); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, dstType, rewriter.getZeroAttr(dstType)); for (int64_t d = 0; d < dim; d++) { Value bnd = - rewriter.create(loc, rewriter.getIndexAttr(d)); - Value val = rewriter.create(loc, arith::CmpIPredicate::slt, + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(d)); + Value val = 
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, bnd, idx); - Value sel = rewriter.create(loc, val, trueVal, falseVal); - result = rewriter.create(loc, sel, result, d); + Value sel = arith::SelectOp::create(rewriter, loc, val, trueVal, falseVal); + result = vector::InsertOp::create(rewriter, loc, sel, result, d); } rewriter.replaceOp(op, result); return success(); @@ -146,12 +146,12 @@ class ConstantMaskOpLowering : public OpRewritePattern { op, "Cannot unroll leading scalable dim in dstType"); VectorType lowType = VectorType::Builder(dstType).dropDim(0); - Value trueVal = rewriter.create( + Value trueVal = vector::ConstantMaskOp::create(rewriter, loc, lowType, dimSizes.drop_front()); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, dstType, rewriter.getZeroAttr(dstType)); for (int64_t d = 0; d < trueDimSize; d++) - result = rewriter.create(loc, trueVal, result, d); + result = vector::InsertOp::create(rewriter, loc, trueVal, result, d); rewriter.replaceOp(op, result); return success(); @@ -261,7 +261,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern { PatternRewriter &rewriter) const override { Value passthru = maskingOp.hasPassthru() ? maskingOp.getPassthru() - : rewriter.create( + : arith::ConstantOp::create(rewriter, gatherOp.getLoc(), rewriter.getZeroAttr(gatherOp.getVectorType())); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index ce524b259d8d4..c5631ba00bbc6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -103,12 +103,12 @@ class InnerOuterDimReductionConversion // If masked, transpose the original mask. 
Value transposedMask; if (maskableOp.isMasked()) { - transposedMask = rewriter.create( + transposedMask = vector::TransposeOp::create(rewriter, loc, maskableOp.getMaskingOp().getMask(), indices); } // Transpose reduction source. - auto transposeOp = rewriter.create(loc, src, indices); + auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { if (useInnerDimsForReduction) @@ -117,7 +117,7 @@ class InnerOuterDimReductionConversion reductionMask[i] = true; } - Operation *newMultiRedOp = rewriter.create( + Operation *newMultiRedOp = vector::MultiDimReductionOp::create(rewriter, multiReductionOp.getLoc(), transposeOp.getResult(), multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind()); newMultiRedOp = @@ -256,13 +256,13 @@ class ReduceMultiDimReductionRank vectorShape, llvm::cast(vectorMask.getType()).getElementType()); newVectorMask = - rewriter.create(loc, maskCastedType, vectorMask); + vector::ShapeCastOp::create(rewriter, loc, maskCastedType, vectorMask); } auto castedType = VectorType::get( vectorShape, multiReductionOp.getSourceVectorType().getElementType(), scalableDims); - Value cast = rewriter.create( + Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType, multiReductionOp.getSource()); Value acc = multiReductionOp.getAcc(); @@ -271,11 +271,11 @@ class ReduceMultiDimReductionRank {flattenedParallelDim}, multiReductionOp.getSourceVectorType().getElementType(), /*scalableDims=*/{isParallelDimScalable}); - acc = rewriter.create(loc, accType, acc); + acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc); } // 6. Creates the flattened form of vector.multi_reduction with inner/outer // most dim as reduction. 
- Operation *newMultiDimRedOp = rewriter.create( + Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(rewriter, loc, cast, acc, mask, multiReductionOp.getKind()); newMultiDimRedOp = mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask); @@ -339,11 +339,11 @@ struct TwoDimMultiReductionToElementWise Value result = multiReductionOp.getAcc(); for (int64_t i = 0; i < srcShape[0]; i++) { - auto operand = rewriter.create( + auto operand = vector::ExtractOp::create(rewriter, loc, multiReductionOp.getSource(), i); Value extractMask = nullptr; if (mask) { - extractMask = rewriter.create(loc, mask, i); + extractMask = vector::ExtractOp::create(rewriter, loc, mask, i); } result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand, @@ -383,27 +383,27 @@ struct TwoDimMultiReductionToReduction } auto loc = multiReductionOp.getLoc(); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, multiReductionOp.getDestType(), rewriter.getZeroAttr(multiReductionOp.getDestType())); int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; for (int i = 0; i < outerDim; ++i) { - auto v = rewriter.create( + auto v = vector::ExtractOp::create(rewriter, loc, multiReductionOp.getSource(), ArrayRef{i}); - auto acc = rewriter.create( + auto acc = vector::ExtractOp::create(rewriter, loc, multiReductionOp.getAcc(), ArrayRef{i}); - Operation *reductionOp = rewriter.create( + Operation *reductionOp = vector::ReductionOp::create(rewriter, loc, multiReductionOp.getKind(), v, acc); // If masked, slice the mask and mask the new reduction operation. 
if (maskableOp.isMasked()) { - Value mask = rewriter.create( + Value mask = vector::ExtractOp::create(rewriter, loc, maskableOp.getMaskingOp().getMask(), ArrayRef{i}); reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask); } - result = rewriter.create(loc, reductionOp->getResult(0), + result = vector::InsertOp::create(rewriter, loc, reductionOp->getResult(0), result, i); } @@ -459,9 +459,9 @@ struct OneDimMultiReductionToTwoDim SmallVector reductionMask{false, true}; /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) - Value cast = rewriter.create( + Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType, multiReductionOp.getSource()); - Value castAcc = rewriter.create( + Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType, multiReductionOp.getAcc()); Value castMask; if (maskableOp.isMasked()) { @@ -470,10 +470,10 @@ struct OneDimMultiReductionToTwoDim ArrayRef{1, maskType.getShape().back()}, maskType.getElementType(), ArrayRef{false, maskType.getScalableDims().back()}); - castMask = rewriter.create(loc, castMaskType, mask); + castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask); } - Operation *newOp = rewriter.create( + Operation *newOp = vector::MultiDimReductionOp::create(rewriter, loc, cast, castAcc, reductionMask, multiReductionOp.getKind()); newOp = vector::maskOperation(rewriter, newOp, castMask); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index a1f67bd0e9ed3..a6985b96b59df 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -122,7 +122,7 @@ struct ScanToArithOps : public OpRewritePattern { return failure(); VectorType resType = VectorType::get(destShape, elType); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); int64_t reductionDim = 
scanOp.getReductionDim(); bool inclusive = scanOp.getInclusive(); @@ -144,7 +144,7 @@ struct ScanToArithOps : public OpRewritePattern { for (int i = 0; i < destShape[reductionDim]; i++) { offsets[reductionDim] = i; ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); - Value input = rewriter.create( + Value input = vector::ExtractStridedSliceOp::create(rewriter, loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, scanStrides); Value output; @@ -154,10 +154,10 @@ struct ScanToArithOps : public OpRewritePattern { } else { if (initialValueRank == 0) { // ShapeCastOp cannot handle 0-D vectors - output = rewriter.create( + output = vector::BroadcastOp::create(rewriter, loc, input.getType(), scanOp.getInitialValue()); } else { - output = rewriter.create( + output = vector::ShapeCastOp::create(rewriter, loc, input.getType(), scanOp.getInitialValue()); } } @@ -166,7 +166,7 @@ struct ScanToArithOps : public OpRewritePattern { output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(), lastOutput, y); } - result = rewriter.create( + result = vector::InsertStridedSliceOp::create(rewriter, loc, output, result, offsets, strides); lastOutput = output; lastInput = input; @@ -174,11 +174,11 @@ struct ScanToArithOps : public OpRewritePattern { Value reduction; if (initialValueRank == 0) { - Value v = rewriter.create(loc, lastOutput, 0); + Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0); reduction = - rewriter.create(loc, initialValueType, v); + vector::BroadcastOp::create(rewriter, loc, initialValueType, v); } else { - reduction = rewriter.create(loc, initialValueType, + reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType, lastOutput); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 39c16fab21c4e..69d3666a13879 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -137,10 +137,10 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { const int64_t resultLeading = delta > 0 ? 0 : -delta; const Value source = shapeCast.getSource(); - const Value poison = rewriter.create(loc, resultType); - const Value extracted = rewriter.create( + const Value poison = ub::PoisonOp::create(rewriter, loc, resultType); + const Value extracted = vector::ExtractOp::create(rewriter, loc, source, SmallVector(sourceLeading, 0)); - const Value result = rewriter.create( + const Value result = vector::InsertOp::create(rewriter, loc, extracted, poison, SmallVector(resultLeading, 0)); rewriter.replaceOp(shapeCast, result); @@ -171,13 +171,13 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { SmallVector extractIndex(sourceDim, 0); SmallVector insertIndex(resultDim, 0); - Value result = rewriter.create(loc, resultType); + Value result = ub::PoisonOp::create(rewriter, loc, resultType); for (int i = 0; i < nSlices; ++i) { Value extracted = - rewriter.create(loc, source, extractIndex); + vector::ExtractOp::create(rewriter, loc, source, extractIndex); - result = rewriter.create(loc, extracted, result, + result = vector::InsertOp::create(rewriter, loc, extracted, result, insertIndex); inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex); @@ -276,9 +276,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { Value extracted = {}; Value extractedStrided = {}; Value insertedSlice = {}; - Value result = rewriter.create(loc, resultType); + Value result = ub::PoisonOp::create(rewriter, loc, resultType); const Value partResult = - rewriter.create(loc, insertStridedType); + ub::PoisonOp::create(rewriter, loc, insertStridedType); for (size_t i = 0; i < nAtomicSlices; ++i) { @@ -288,14 +288,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { // vector.extract if (extractStridedPhase == 0) { extracted = - rewriter.create(loc, source, extractIndex); + 
vector::ExtractOp::create(rewriter, loc, source, extractIndex); inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim), extractIndex); } // vector.extract_strided_slice extractOffsets[0] = extractStridedPhase * greatestCommonDivisor; - extractedStrided = rewriter.create( + extractedStrided = vector::ExtractStridedSliceOp::create(rewriter, loc, extracted, extractOffsets, atomicShape, sizes); // vector.insert_strided_slice @@ -303,12 +303,12 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { insertedSlice = partResult; } insertOffsets[0] = insertStridedPhase * greatestCommonDivisor; - insertedSlice = rewriter.create( + insertedSlice = vector::InsertStridedSliceOp::create(rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes); // vector.insert if (insertStridedPhase + 1 == insertPeriod) { - result = rewriter.create(loc, insertedSlice, result, + result = vector::InsertOp::create(rewriter, loc, insertedSlice, result, insertIndex); inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim), insertIndex); @@ -394,7 +394,7 @@ class ScalableShapeCastOpRewritePattern auto extractionVectorType = VectorType::get( {minExtractionSize}, sourceVectorType.getElementType(), {true}); - Value result = rewriter.create(loc, resultVectorType); + Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType); SmallVector srcIdx(srcRank, 0); SmallVector resIdx(resRank, 0); @@ -406,7 +406,7 @@ class ScalableShapeCastOpRewritePattern // 1. Extract a scalable subvector from the source vector. 
if (!currentSourceScalableVector) { if (srcRank != 1) { - currentSourceScalableVector = rewriter.create( + currentSourceScalableVector = vector::ExtractOp::create(rewriter, loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back()); } else { currentSourceScalableVector = op.getSource(); @@ -414,7 +414,7 @@ class ScalableShapeCastOpRewritePattern } Value sourceSubVector = currentSourceScalableVector; if (minExtractionSize < minSourceTrailingSize) { - sourceSubVector = rewriter.create( + sourceSubVector = vector::ScalableExtractOp::create(rewriter, loc, extractionVectorType, sourceSubVector, srcIdx.back()); } @@ -423,14 +423,14 @@ class ScalableShapeCastOpRewritePattern if (minExtractionSize == minResultTrailingSize) { currentResultScalableVector = sourceSubVector; } else if (resRank != 1) { - currentResultScalableVector = rewriter.create( + currentResultScalableVector = vector::ExtractOp::create(rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back()); } else { currentResultScalableVector = result; } } if (minExtractionSize < minResultTrailingSize) { - currentResultScalableVector = rewriter.create( + currentResultScalableVector = vector::ScalableInsertOp::create(rewriter, loc, sourceSubVector, currentResultScalableVector, resIdx.back()); } @@ -439,7 +439,7 @@ class ScalableShapeCastOpRewritePattern currentResultScalableVector != result) { // Finished row of result. Insert complete scalable vector into result // (n-D) vector. 
- result = rewriter.create( + result = vector::InsertOp::create(rewriter, loc, currentResultScalableVector, result, llvm::ArrayRef(resIdx).drop_back()); currentResultScalableVector = {}; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 5b81d0d33d484..fe02cd5656571 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -47,7 +47,7 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, originalVecType.getScalableDims().end()); VectorType newVecType = VectorType::get( newShape, originalVecType.getElementType(), newScalableDims); - return builder.create(loc, newVecType, vec); + return vector::BroadcastOp::create(builder, loc, newVecType, vec); } /// Extend the rank of a vector Value by `addedRanks` by adding inner unit @@ -62,7 +62,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, permutation.push_back(i); for (int64_t i = 0; i < addedRank; ++i) permutation.push_back(i); - return builder.create(loc, broadcasted, permutation); + return vector::TransposeOp::create(builder, loc, broadcasted, permutation); } //===----------------------------------------------------------------------===// @@ -138,7 +138,7 @@ struct TransferReadPermutationLowering // Generate new transfer_read operation. VectorType newReadType = VectorType::get( newVectorShape, op.getVectorType().getElementType(), newScalableDims); - Value newRead = rewriter.create( + Value newRead = vector::TransferReadOp::create(rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); @@ -209,11 +209,11 @@ struct TransferWritePermutationLowering inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation); // Generate new transfer_write operation. 
- Value newVec = rewriter.create( + Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(), op.getVector(), indices); auto newMap = AffineMap::getMinorIdentityMap( map.getNumDims(), map.getNumResults(), rewriter.getContext()); - auto newWrite = rewriter.create( + auto newWrite = vector::TransferWriteOp::create(rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr); if (newWrite.hasPureTensorSemantics()) @@ -299,7 +299,7 @@ struct TransferWriteNonPermutationLowering newInBoundsValues.push_back(op.isDimInBounds(i)); } ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); - auto newWrite = rewriter.create( + auto newWrite = vector::TransferWriteOp::create(rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), newMask, newInBoundsAttr); if (newWrite.hasPureTensorSemantics()) @@ -370,7 +370,7 @@ struct TransferOpReduceRank ? rewriter.getArrayAttr( op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) : ArrayAttr(); - Value newRead = rewriter.create( + Value newRead = vector::TransferReadOp::create(rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); @@ -471,20 +471,20 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = rewriter.create( + Value fill = vector::SplatOp::create(rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); - res = rewriter.create( + res = vector::MaskedLoadOp::create(rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), read.getIndices(), read.getMask(), fill); } else { - res = rewriter.create(read.getLoc(), + res = vector::LoadOp::create(rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), read.getIndices()); } // Insert a broadcasting op if required. 
if (!broadcastedDims.empty()) - res = rewriter.create( + res = vector::BroadcastOp::create(rewriter, read.getLoc(), read.getVectorType(), res->getResult(0)); return res->getResult(0); } @@ -569,11 +569,11 @@ struct TransferWriteToVectorStoreLowering << write; }); - rewriter.create( + vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(), write.getIndices(), write.getMask(), write.getVector()); } else { - rewriter.create(write.getLoc(), write.getVector(), + vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(), write.getBase(), write.getIndices()); } // There's no return value for StoreOps. Use Value() to signal success to diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 732e316c93381..e9591bfcf6feb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -80,7 +80,7 @@ getUnpackShufflePermFor128Lane(ArrayRef vals, int numBits) { static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - return b.create( + return vector::ShuffleOp::create(b, v1, v2, getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits)); } @@ -94,7 +94,7 @@ static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - return b.create( + return vector::ShuffleOp::create(b, v1, v2, getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3}, numBits)); @@ -109,7 +109,7 @@ static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - auto shuffle = b.create( + auto shuffle = vector::ShuffleOp::create(b, v1, v2, getUnpackShufflePermFor128Lane({0, numElem, 1, 
numElem + 1}, numBits)); return shuffle; @@ -124,7 +124,7 @@ static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits) { int numElem = numBits / 32; - return b.create( + return vector::ShuffleOp::create(b, v1, v2, getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3}, numBits)); @@ -181,7 +181,7 @@ static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, appendToMask(0, b23); appendToMask(16, b45); appendToMask(16, b67); - return b.create(v1, v2, shuffleMask); + return vector::ShuffleOp::create(b, v1, v2, shuffleMask); } /// Lowers the value to a vector.shuffle op. The `source` is expected to be a @@ -192,7 +192,7 @@ static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) { for (int64_t j = 0; j < n; ++j) for (int64_t i = 0; i < m; ++i) mask.push_back(i * n + j); - return b.create(source.getLoc(), source, source, mask); + return vector::ShuffleOp::create(b, source.getLoc(), source, source, mask); } /// Lowers the value to a sequence of vector.shuffle ops. 
The `source` is @@ -284,9 +284,9 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, auto reshInputType = VectorType::get( {m, n}, cast(source.getType()).getElementType()); - Value res = b.create(reshInputType); + Value res = ub::PoisonOp::create(b, reshInputType); for (int64_t i = 0; i < m; ++i) - res = b.create(vs[i], res, i); + res = vector::InsertOp::create(b, vs[i], res, i); return res; } @@ -335,10 +335,10 @@ class TransposeOpLowering : public OpRewritePattern { Type flattenedType = VectorType::get(resType.getNumElements(), resType.getElementType()); auto matrix = - rewriter.create(loc, flattenedType, input); + vector::ShapeCastOp::create(rewriter, loc, flattenedType, input); auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); - Value trans = rewriter.create( + Value trans = vector::FlatTransposeOp::create(rewriter, loc, flattenedType, matrix, rows, columns); rewriter.replaceOpWithNewOp(op, resType, trans); return success(); @@ -359,7 +359,7 @@ class TransposeOpLowering : public OpRewritePattern { // of the leftmost transposed dimensions. We traverse every transpose // element using a linearized index that we delinearize to generate the // appropriate indices for the extract/insert operations. 
- Value result = rewriter.create(loc, resType); + Value result = ub::PoisonOp::create(rewriter, loc, resType); int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); for (int64_t linearIdx = 0; linearIdx < numTransposedElements; @@ -482,14 +482,14 @@ class TransposeOp2DToShuffleLowering Location loc = op.getLoc(); auto flattenedType = VectorType::get({n * m}, srcType.getElementType()); auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); - auto reshInput = rewriter.create(loc, flattenedType, + auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType, op.getVector()); Value res; if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 && m == 16 && n == 16) { reshInput = - rewriter.create(loc, reshInputType, reshInput); + vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput); res = transposeToShuffle16x16(rewriter, reshInput, m, n); } else { // Fallback to shuffle on 1D approach. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index fb99e22c77ea0..b555c3f6d9f85 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -114,7 +114,7 @@ struct DistributedLoadStoreHelper { "preregistered sequential value."); // Scalar case can directly use memref.store. if (!isa(val.getType())) - return b.create(loc, val, buffer, zero); + return memref::StoreOp::create(b, loc, val, buffer, zero); // Vector case must use vector::TransferWriteOp which will later lower to // vector.store of memref.store depending on further lowerings. 
@@ -127,7 +127,7 @@ struct DistributedLoadStoreHelper { } } SmallVector inBounds(indices.size(), true); - return b.create( + return vector::TransferWriteOp::create(b, loc, val, buffer, indices, ArrayRef(inBounds.begin(), inBounds.end())); } @@ -156,7 +156,7 @@ struct DistributedLoadStoreHelper { // Scalar case can directly use memref.store. if (!isa(type)) - return b.create(loc, buffer, zero); + return memref::LoadOp::create(b, loc, buffer, zero); // Other cases must be vector atm. // Vector case must use vector::TransferReadOp which will later lower to @@ -172,7 +172,7 @@ struct DistributedLoadStoreHelper { } } SmallVector inBounds(indices.size(), true); - return b.create( + return vector::TransferReadOp::create(b, loc, cast(type), buffer, indices, /*padding=*/std::nullopt, ArrayRef(inBounds.begin(), inBounds.end())); } @@ -243,10 +243,10 @@ struct WarpOpToScfIfPattern : public WarpDistributionPattern { rewriter.setInsertionPoint(warpOp); // Step 1: Create scf.if op. - Value c0 = rewriter.create(loc, 0); - Value isLane0 = rewriter.create( + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value isLane0 = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0); - auto ifOp = rewriter.create(loc, isLane0, + auto ifOp = scf::IfOp::create(rewriter, loc, isLane0, /*withElseRegion=*/false); rewriter.eraseOp(ifOp.thenBlock()->getTerminator()); @@ -325,7 +325,7 @@ struct WarpOpToScfIfPattern : public WarpDistributionPattern { // Step 7. Delete terminator and add empty scf.yield. rewriter.eraseOp(yieldOp); rewriter.setInsertionPointToEnd(ifOp.thenBlock()); - rewriter.create(yieldLoc); + scf::YieldOp::create(rewriter, yieldLoc); // Compute replacements for WarpOp results. rewriter.replaceOp(warpOp, replacements); @@ -512,7 +512,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); // Create a second warp op that contains only writeOp. 
- auto secondWarpOp = rewriter.create( + auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); Block &body = secondWarpOp.getBodyRegion().front(); rewriter.setInsertionPointToStart(&body); @@ -521,7 +521,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { newWriteOp.getValueToStoreMutable().assign( newWarpOp.getResult(newRetIndices[0])); rewriter.eraseOp(writeOp); - rewriter.create(newWarpOp.getLoc()); + gpu::YieldOp::create(rewriter, newWarpOp.getLoc()); return success(); } @@ -698,7 +698,7 @@ struct WarpOpConstant : public WarpDistributionPattern { cast(warpOp.getResult(operandIndex).getType()), scalarAttr); Location loc = warpOp.getLoc(); rewriter.setInsertionPointAfter(warpOp); - Value distConstant = rewriter.create(loc, newAttr); + Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr); rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); rewriter.finalizeOpModification(warpOp); return success(); @@ -823,7 +823,7 @@ struct WarpOpTransferRead : public WarpDistributionPattern { Value newMask = hasMask ? 
newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1]) : Value(); - auto newRead = rewriter.create( + auto newRead = vector::TransferReadOp::create(rewriter, read.getLoc(), distributedVal.getType(), read.getBase(), newIndices, read.getPermutationMapAttr(), newPadding, newMask, read.getInBoundsAttr()); @@ -965,7 +965,7 @@ struct WarpOpBroadcast : public WarpDistributionPattern { WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - Value broadcasted = rewriter.create( + Value broadcasted = vector::BroadcastOp::create(rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), broadcasted); @@ -1008,7 +1008,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern { rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - Value newCast = rewriter.create( + Value newCast = vector::ShapeCastOp::create(rewriter, oldCastOp.getLoc(), castResultType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); @@ -1091,7 +1091,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern { } auto newMask = - rewriter.create(loc, distType, newOperands); + vector::CreateMaskOp::create(rewriter, loc, distType, newOperands); rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask); rewriter.finalizeOpModification(warpOp); return success(); @@ -1182,7 +1182,7 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern { Value distributedDest = newWarpOp->getResult(newRetIndices[1]); // Create a new insert strided slice op that inserts distributed source into // distributed dest. 
- Value newInsert = rewriter.create( + Value newInsert = vector::InsertStridedSliceOp::create(rewriter, insertOp.getLoc(), distributedDest.getType(), distributedSource, distributedDest, insertOp.getOffsets(), insertOp.getStrides()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert); @@ -1277,7 +1277,7 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern { // Create a new extract strided slice op that extracts from the // distributed vector. Value distributedVec = newWarpOp->getResult(newRetIndices[0]); - Value newExtract = rewriter.create( + Value newExtract = vector::ExtractStridedSliceOp::create(rewriter, extractOp.getLoc(), distributedType, distributedVec, extractOp.getOffsets(), ArrayAttr::get(rewriter.getContext(), distributedSizes), @@ -1323,7 +1323,7 @@ struct WarpOpExtract : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. - Value newExtract = rewriter.create( + Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); @@ -1352,7 +1352,7 @@ struct WarpOpExtract : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); // Extract from distributed vector. 
- Value newExtract = rewriter.create( + Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec, extractOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); @@ -1422,7 +1422,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern { Value newExtract; SmallVector indices(extractSrcType.getRank(), 0); newExtract = - rewriter.create(loc, distributedVec, indices); + vector::ExtractOp::create(rewriter, loc, distributedVec, indices); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); @@ -1442,11 +1442,11 @@ struct WarpOpExtractScalar : public WarpDistributionPattern { // Extract at position: pos % elementsPerLane Value newPos = elementsPerLane == 1 - ? rewriter.create(loc, 0).getResult() + ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult() : affine::makeComposedAffineApply(rewriter, loc, sym0 % elementsPerLane, pos); Value extracted = - rewriter.create(loc, distributedVec, newPos); + vector::ExtractOp::create(rewriter, loc, distributedVec, newPos); // Shuffle the extracted value to all lanes. Value shuffled = warpShuffleFromIdxFn( @@ -1535,7 +1535,7 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { if (pos) { indices.push_back(pos); } - newInsert = rewriter.create(loc, newSource, + newInsert = vector::InsertOp::create(rewriter, loc, newSource, distributedVec, indices); // Broadcast: Simply move the vector.insert op out. 
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), @@ -1552,7 +1552,7 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { // Insert position: pos % elementsPerLane OpFoldResult newPos = affine::makeComposedFoldedAffineApply( rewriter, loc, sym0 % elementsPerLane, pos); - Value isInsertingLane = rewriter.create( + Value isInsertingLane = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); Value newResult = rewriter @@ -1560,13 +1560,13 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { loc, isInsertingLane, /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { - Value newInsert = builder.create( + Value newInsert = vector::InsertOp::create(builder, loc, newSource, distributedVec, newPos); - builder.create(loc, newInsert); + scf::YieldOp::create(builder, loc, newInsert); }, /*elseBuilder=*/ [&](OpBuilder &builder, Location loc) { - builder.create(loc, distributedVec); + scf::YieldOp::create(builder, loc, distributedVec); }) .getResult(0); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); @@ -1603,7 +1603,7 @@ struct WarpOpInsert : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); Value distributedDest = newWarpOp->getResult(newRetIndices[1]); - Value newResult = rewriter.create( + Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); @@ -1653,7 +1653,7 @@ struct WarpOpInsert : public WarpDistributionPattern { Value newResult; if (distrSrcDim >= 0) { // Every lane inserts a small piece. - newResult = rewriter.create( + newResult = vector::InsertOp::create(rewriter, loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); } else { // One lane inserts the entire source vector. 
@@ -1661,19 +1661,19 @@ struct WarpOpInsert : public WarpDistributionPattern { SmallVector pos = insertOp.getMixedPosition(); SmallVector newPos = getAsIntegers(pos); // tid of inserting lane: pos / elementsPerLane - Value insertingLane = rewriter.create( + Value insertingLane = arith::ConstantIndexOp::create(rewriter, loc, newPos[distrDestDim] / elementsPerLane); - Value isInsertingLane = rewriter.create( + Value isInsertingLane = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); // Insert position: pos % elementsPerLane newPos[distrDestDim] %= elementsPerLane; auto insertingBuilder = [&](OpBuilder &builder, Location loc) { - Value newInsert = builder.create( + Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc, distributedDest, newPos); - builder.create(loc, newInsert); + scf::YieldOp::create(builder, loc, newInsert); }; auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { - builder.create(loc, distributedDest); + scf::YieldOp::create(builder, loc, distributedDest); }; newResult = rewriter .create(loc, isInsertingLane, @@ -1802,7 +1802,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Create a new for op outside the region with a WarpExecuteOnLane0Op // region inside. 
- auto newForOp = rewriter.create( + auto newForOp = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newOperands); rewriter.setInsertionPointToStart(newForOp.getBody()); @@ -1817,7 +1817,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { argIndexMapping[escapingValues[i]] = warpInputType.size(); warpInputType.push_back(inputTypes[i]); } - auto innerWarp = rewriter.create( + auto innerWarp = WarpExecuteOnLane0Op::create(rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), newWarpOp.getWarpSize(), warpInput, warpInputType); @@ -1833,10 +1833,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern { rewriter.eraseOp(forOp.getBody()->getTerminator()); rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); rewriter.setInsertionPointToEnd(innerWarp.getBody()); - rewriter.create(innerWarp.getLoc(), yieldOperands); + gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); if (!innerWarp.getResults().empty()) - rewriter.create(forOp.getLoc(), innerWarp.getResults()); + scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); rewriter.eraseOp(forOp); // Replace the warpOp result coming from the original ForOp. 
for (const auto &res : llvm::enumerate(resultIdx)) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 067d4e3491391..e9718a7795fc5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -77,7 +77,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim Location loc = extractOp.getLoc(); - Value newSrcVector = rewriter.create( + Value newSrcVector = vector::ExtractOp::create(rewriter, loc, extractOp.getVector(), splatZero(dropCount)); // The offsets/sizes/strides attribute can have a less number of elements @@ -89,7 +89,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim auto newStrides = rewriter.getArrayAttr( extractOp.getStrides().getValue().drop_front(dropCount)); - auto newExtractOp = rewriter.create( + auto newExtractOp = vector::ExtractStridedSliceOp::create(rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); rewriter.replaceOpWithNewOp(extractOp, oldDstType, @@ -120,9 +120,9 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Trim leading one dimensions from both operands. 
Location loc = insertOp.getLoc(); - Value newSrcVector = rewriter.create( + Value newSrcVector = vector::ExtractOp::create(rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount)); - Value newDstVector = rewriter.create( + Value newDstVector = vector::ExtractOp::create(rewriter, loc, insertOp.getDest(), splatZero(dstDropCount)); auto newOffsets = rewriter.getArrayAttr( @@ -130,7 +130,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim auto newStrides = rewriter.getArrayAttr( insertOp.getStrides().getValue().take_back(newSrcType.getRank())); - auto newInsertOp = rewriter.create( + auto newInsertOp = vector::InsertStridedSliceOp::create(rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); rewriter.replaceOpWithNewOp(insertOp, oldDstType, @@ -169,10 +169,10 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { Value newSrcVector = insertOp.getValueToStore(); if (oldSrcRank != 0) { - newSrcVector = rewriter.create( + newSrcVector = vector::ExtractOp::create(rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount)); } - Value newDstVector = rewriter.create( + Value newDstVector = vector::ExtractOp::create(rewriter, loc, insertOp.getDest(), splatZero(dstDropCount)); // New position rank needs to be computed in two steps: (1) if destination @@ -187,7 +187,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { newPosition.resize(newDstType.getRank() - newSrcRank, rewriter.getI64IntegerAttr(0)); - auto newInsertOp = rewriter.create( + auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector, newDstVector, newPosition); rewriter.replaceOpWithNewOp(insertOp, oldDstType, @@ -209,9 +209,9 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, if (vector::isBroadcastableTo(newMaskType, oldMaskType) == BroadcastableToResult::Success) { int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); - return b.create(loc, mask, splatZero(dropDim)); + return 
vector::ExtractOp::create(b, loc, mask, splatZero(dropDim)); } - return b.create(loc, newMaskType, mask); + return vector::ShapeCastOp::create(b, loc, newMaskType, mask); } // Turns vector.transfer_read on vector with leading 1 dimensions into @@ -259,7 +259,7 @@ struct CastAwayTransferReadLeadingOneDim newType, newMap, maskType); } - auto newRead = rewriter.create( + auto newRead = vector::TransferReadOp::create(rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(), AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); rewriter.replaceOpWithNewOp(read, oldType, newRead); @@ -306,7 +306,7 @@ struct CastAwayTransferWriteLeadingOneDim inBoundsAttr = rewriter.getArrayAttr( write.getInBoundsAttr().getValue().take_back(newType.getRank())); - auto newVector = rewriter.create( + auto newVector = vector::ExtractOp::create(rewriter, write.getLoc(), write.getVector(), splatZero(dropDim)); if (write.getMask()) { @@ -444,20 +444,20 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, // Extract if its a valid extraction, otherwise use the operand // without extraction. newOperands.push_back( - validExtract ? rewriter.create( + validExtract ? vector::ExtractOp::create(rewriter, loc, operands[it.index()], splatZero(dropDim)) : operands[it.index()]); } // Depending on whether this vector.contract is masked, the replacing Op // should either be a new vector.contract Op or vector.mask Op. 
- Operation *newOp = rewriter.create( + Operation *newOp = vector::ContractionOp::create(rewriter, loc, newOperands[0], newOperands[1], newOperands[2], rewriter.getAffineMapArrayAttr(newIndexingMaps), rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); if (maskingOp) { - auto newMask = rewriter.create(loc, maskingOp.getMask(), + auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(), splatZero(dropDim)); newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); @@ -519,7 +519,7 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { SmallVector newOperands; for (Value operand : op->getOperands()) { if (auto opVecType = dyn_cast(operand.getType())) { - newOperands.push_back(rewriter.create( + newOperands.push_back(vector::ExtractOp::create(rewriter, op->getLoc(), operand, splatZero(dropDim))); } else { newOperands.push_back(operand); @@ -559,7 +559,7 @@ struct CastAwayConstantMaskLeadingOneDim SmallVector newDimSizes = {flatLeadingSize}; newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); - auto newMask = rewriter.create( + auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(), newType, newDimSizes); rewriter.replaceOpWithNewOp(mask, oldType, newMask); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index 8cc7008d80b3e..3411e010f3499 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -65,26 +65,26 @@ struct VectorMaskedLoadOpConverter final Value base = maskedLoadOp.getBase(); Value iValue = maskedLoadOp.getPassThru(); auto indices = llvm::to_vector_of(maskedLoadOp.getIndices()); - Value one = rewriter.create( + Value one = arith::ConstantOp::create(rewriter, loc, indexType, IntegerAttr::get(indexType, 1)); for (int64_t i = 0; i < maskLength; ++i) { - auto 
maskBit = rewriter.create(loc, mask, i); + auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i); - auto ifOp = rewriter.create( + auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, [&](OpBuilder &builder, Location loc) { auto loadedValue = - builder.create(loc, base, indices); + memref::LoadOp::create(builder, loc, base, indices); auto combinedValue = - builder.create(loc, loadedValue, iValue, i); - builder.create(loc, combinedValue.getResult()); + vector::InsertOp::create(builder, loc, loadedValue, iValue, i); + scf::YieldOp::create(builder, loc, combinedValue.getResult()); }, [&](OpBuilder &builder, Location loc) { - builder.create(loc, iValue); + scf::YieldOp::create(builder, loc, iValue); }); iValue = ifOp.getResult(0); - indices.back() = rewriter.create(loc, indices.back(), one); + indices.back() = arith::AddIOp::create(rewriter, loc, indices.back(), one); } rewriter.replaceOp(maskedLoadOp, iValue); @@ -132,18 +132,18 @@ struct VectorMaskedStoreOpConverter final Value base = maskedStoreOp.getBase(); Value value = maskedStoreOp.getValueToStore(); auto indices = llvm::to_vector_of(maskedStoreOp.getIndices()); - Value one = rewriter.create( + Value one = arith::ConstantOp::create(rewriter, loc, indexType, IntegerAttr::get(indexType, 1)); for (int64_t i = 0; i < maskLength; ++i) { - auto maskBit = rewriter.create(loc, mask, i); + auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i); - auto ifOp = rewriter.create(loc, maskBit, /*else=*/false); + auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - auto extractedValue = rewriter.create(loc, value, i); - rewriter.create(loc, extractedValue, base, indices); + auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i); + memref::StoreOp::create(rewriter, loc, extractedValue, base, indices); rewriter.setInsertionPointAfter(ifOp); - indices.back() = rewriter.create(loc, indices.back(), one); + 
indices.back() = arith::AddIOp::create(rewriter, loc, indices.back(), one); } rewriter.eraseOp(maskedStoreOp); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 004beadc9ec7d..12b548bcee646 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -132,7 +132,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, SmallVector newMaskOperands(maskOperands.drop_back()); newMaskOperands.push_back( getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex)); - return rewriter.create(loc, newMaskType, + return vector::CreateMaskOp::create(rewriter, loc, newMaskType, newMaskOperands); }) .Case( @@ -143,7 +143,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, int64_t &maskIndex = maskDimSizes.back(); maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, numSrcElemsPerDest); - return rewriter.create(loc, newMaskType, + return vector::ConstantMaskOp::create(rewriter, loc, newMaskType, maskDimSizes); }) .Case([&](auto constantOp) @@ -182,7 +182,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, } compressedMaskValues.push_back(combinedValue); } - return rewriter.create( + return arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(newMaskType, compressedMaskValues)); }); @@ -190,7 +190,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, return failure(); while (!extractOps.empty()) { - newMask = rewriter.create( + newMask = vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition()); extractOps.pop_back(); } @@ -258,7 +258,7 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, auto offsets = rewriter.getI64ArrayAttr({offset}); auto strides = rewriter.getI64ArrayAttr({1}); - return rewriter.create(loc, destVecTy, src, + return 
vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src, dest, offsets, strides); } @@ -301,11 +301,11 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, for (int i = 0; i < numElemsToExtract; ++i) { Value extractLoc = (i == 0) ? dyn_cast(offset) - : rewriter.create( + : arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), dyn_cast(offset), - rewriter.create(loc, i)); - auto extractOp = rewriter.create(loc, src, extractLoc); - dest = rewriter.create(loc, extractOp, dest, i); + arith::ConstantIndexOp::create(rewriter, loc, i)); + auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc); + dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i); } return dest; } @@ -346,11 +346,11 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, for (int64_t i = 0; i < numElemsToInsert; ++i) { auto insertLoc = i == 0 ? destOffsetVal - : rewriter.create( + : arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), destOffsetVal, - rewriter.create(loc, i)); - auto extractOp = rewriter.create(loc, src, i); - dest = rewriter.create(loc, extractOp, dest, insertLoc); + arith::ConstantIndexOp::create(rewriter, loc, i)); + auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i); + dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc); } return dest; } @@ -369,10 +369,10 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, Type containerElemTy) { auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() / emulatedElemTy.getIntOrFloatBitWidth(); - auto newLoad = rewriter.create( + auto newLoad = vector::LoadOp::create(rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); - return rewriter.create( + return vector::BitCastOp::create(rewriter, loc, VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem, 
emulatedElemTy), @@ -390,16 +390,16 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, upcastType.getNumElements() * upcastType.getElementTypeBitWidth() && "expected input and output number of bits to match"); if (trueValue.getType() != downcastType) { - trueValue = builder.create(loc, downcastType, trueValue); + trueValue = vector::BitCastOp::create(builder, loc, downcastType, trueValue); } if (falseValue.getType() != downcastType) { falseValue = - builder.create(loc, downcastType, falseValue); + vector::BitCastOp::create(builder, loc, downcastType, falseValue); } Value selectedType = - builder.create(loc, mask, trueValue, falseValue); + arith::SelectOp::create(builder, loc, mask, trueValue, falseValue); // Upcast the selected value to the new type. - return builder.create(loc, upcastType, selectedType); + return vector::BitCastOp::create(builder, loc, upcastType, selectedType); } /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a @@ -422,7 +422,7 @@ static void atomicRMW(OpBuilder &builder, Location loc, // Create an atomic load-modify-write region using // `memref.generic_atomic_rmw`. - auto atomicOp = builder.create( + auto atomicOp = memref::GenericAtomicRMWOp::create(builder, loc, linearizedMemref, ValueRange{storeIdx}); Value origValue = atomicOp.getCurrentValue(); @@ -432,7 +432,7 @@ static void atomicRMW(OpBuilder &builder, Location loc, // Load the original value from memory, and cast it to the original element // type. auto oneElemVecType = VectorType::get({1}, origValue.getType()); - Value origVecValue = builder.create( + Value origVecValue = vector::FromElementsOp::create(builder, loc, oneElemVecType, ValueRange{origValue}); // Construct the final masked value and yield it. 
@@ -440,8 +440,8 @@ static void atomicRMW(OpBuilder &builder, Location loc, downcastSelectAndUpcast(builder, loc, valueToStore.getType(), oneElemVecType, mask, valueToStore, origVecValue); auto scalarMaskedValue = - builder.create(loc, maskedValue, 0); - builder.create(loc, scalarMaskedValue); + vector::ExtractOp::create(builder, loc, maskedValue, 0); + memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue); } /// Generate a non-atomic read-modify-write sequence for storing to the emulated @@ -453,15 +453,15 @@ static void nonAtomicRMW(OpBuilder &builder, Location loc, auto oneElemVecType = VectorType::get({1}, linearizedMemref.getType().getElementType()); - Value origVecValue = builder.create( + Value origVecValue = vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex}); - origVecValue = builder.create(loc, valueToStore.getType(), + origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(), origVecValue); Value maskedValue = downcastSelectAndUpcast(builder, loc, valueToStore.getType(), oneElemVecType, mask, valueToStore, origVecValue); - builder.create(loc, maskedValue, linearizedMemref, + vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref, linearizedIndex); } @@ -489,7 +489,7 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter, assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 && "vector element must be a valid sub-byte type"); auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth(); - auto emptyByteVector = rewriter.create( + auto emptyByteVector = arith::ConstantOp::create(rewriter, loc, VectorType::get({emulatedPerContainerElem}, vectorElementType), rewriter.getZeroAttr( VectorType::get({emulatedPerContainerElem}, vectorElementType))); @@ -602,7 +602,7 @@ struct ConvertVectorStore final : OpConversionPattern { ShapedType::isDynamic(trailingDim) || trailingDim == origElements; auto stridedMetadata = - rewriter.create(loc, 
op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); // FIXME: ATM, we do not test cases where offsets, sizes, or strides are // non-zero. As such, this is not needed. @@ -664,7 +664,7 @@ struct ConvertVectorStore final : OpConversionPattern { if (!emulationRequiresPartialStores) { // Basic case: storing full bytes. auto numElements = origElements / emulatedPerContainerElem; - auto bitCast = rewriter.create( + auto bitCast = vector::BitCastOp::create(rewriter, loc, VectorType::get(numElements, containerElemTy), op.getValueToStore()); rewriter.replaceOpWithNewOp( @@ -732,7 +732,7 @@ struct ConvertVectorStore final : OpConversionPattern { std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem, *foldedNumFrontPadElems, true); } - auto frontMask = rewriter.create( + auto frontMask = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues)); currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems); @@ -751,8 +751,8 @@ struct ConvertVectorStore final : OpConversionPattern { // Increment the destination index by 1 to align to the emulated width // boundary. - auto constantOne = rewriter.create(loc, 1); - currentDestIndex = rewriter.create( + auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1); + currentDestIndex = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne); // 2. Full width store for the inner output bytes. 
@@ -772,15 +772,15 @@ struct ConvertVectorStore final : OpConversionPattern { auto storeType = VectorType::get( {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType); - auto bitCast = rewriter.create(loc, storeType, + auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType, fullWidthStorePart); - rewriter.create(loc, bitCast.getResult(), memrefBase, + vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase, currentDestIndex); currentSourceIndex += numNonFullWidthElements; - currentDestIndex = rewriter.create( + currentDestIndex = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), currentDestIndex, - rewriter.create(loc, fullWidthStoreSize)); + arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize)); } // 3. Partial width store for the trailing output byte. @@ -795,7 +795,7 @@ struct ConvertVectorStore final : OpConversionPattern { // Generate back mask. auto maskValues = SmallVector(emulatedPerContainerElem, 0); std::fill_n(maskValues.begin(), remainingElements, 1); - auto backMask = rewriter.create( + auto backMask = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); storeFunc(rewriter, loc, memrefBase, currentDestIndex, @@ -848,7 +848,7 @@ struct ConvertVectorMaskedStore final return failure(); auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndicesOfr; memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndicesOfr) = @@ -901,21 +901,21 @@ struct ConvertVectorMaskedStore final auto numElements = (origElements + emulatedPerContainerElem - 1) / emulatedPerContainerElem; auto newType = VectorType::get(numElements, containerElemTy); - auto passThru = rewriter.create( + auto passThru = arith::ConstantOp::create(rewriter, loc, newType, rewriter.getZeroAttr(newType)); - auto newLoad = rewriter.create( + auto newLoad = 
vector::MaskedLoadOp::create(rewriter, loc, newType, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0), passThru); auto newBitCastType = VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy); Value valueToStore = - rewriter.create(loc, newBitCastType, newLoad); - valueToStore = rewriter.create( + vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad); + valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(), op.getValueToStore(), valueToStore); valueToStore = - rewriter.create(loc, newType, valueToStore); + vector::BitCastOp::create(rewriter, loc, newType, valueToStore); rewriter.replaceOpWithNewOp( op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0), @@ -990,7 +990,7 @@ struct ConvertVectorLoad final : OpConversionPattern { bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndices; memref::LinearizedMemRefInfo linearizedInfo; @@ -1016,7 +1016,7 @@ struct ConvertVectorLoad final : OpConversionPattern { numElements, emulatedElemTy, containerElemTy); if (!foldedIntraVectorOffset) { - auto resultVector = rewriter.create( + auto resultVector = arith::ConstantOp::create(rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector( rewriter, loc, dyn_cast>(result), resultVector, @@ -1111,7 +1111,7 @@ struct ConvertVectorMaskedLoad final bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndices; memref::LinearizedMemRefInfo linearizedInfo; std::tie(linearizedInfo, linearizedIndices) = @@ -1142,7 +1142,7 @@ struct ConvertVectorMaskedLoad final auto newBitcastType = VectorType::get(numElements * 
emulatedPerContainerElem, emulatedElemTy); - auto emptyVector = rewriter.create( + auto emptyVector = arith::ConstantOp::create(rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); if (!foldedIntraVectorOffset) { passthru = dynamicallyInsertSubVector( @@ -1153,10 +1153,10 @@ struct ConvertVectorMaskedLoad final *foldedIntraVectorOffset); } auto newPassThru = - rewriter.create(loc, loadType, passthru); + vector::BitCastOp::create(rewriter, loc, loadType, passthru); // Generating the new masked load. - auto newLoad = rewriter.create( + auto newLoad = vector::MaskedLoadOp::create(rewriter, loc, loadType, adaptor.getBase(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newMask.value()->getResult(0), newPassThru); @@ -1164,13 +1164,13 @@ struct ConvertVectorMaskedLoad final // Setting the part that originally was not effectively loaded from memory // to pass through. auto bitCast = - rewriter.create(loc, newBitcastType, newLoad); + vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad); Value mask = op.getMask(); auto newSelectMaskType = VectorType::get( numElements * emulatedPerContainerElem, rewriter.getI1Type()); // TODO: try to fold if op's mask is constant - auto emptyMask = rewriter.create( + auto emptyMask = arith::ConstantOp::create(rewriter, loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)); if (!foldedIntraVectorOffset) { mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask, @@ -1182,7 +1182,7 @@ struct ConvertVectorMaskedLoad final } Value result = - rewriter.create(loc, mask, bitCast, passthru); + arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru); if (!foldedIntraVectorOffset) { result = dynamicallyExtractSubVector( rewriter, loc, result, op.getPassThru(), @@ -1268,11 +1268,11 @@ struct ConvertVectorTransferRead final bool isDivisibleInSize = fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy); - auto newPadding = rewriter.create(loc, containerElemTy, + auto 
newPadding = arith::ExtUIOp::create(rewriter, loc, containerElemTy, adaptor.getPadding()); auto stridedMetadata = - rewriter.create(loc, op.getBase()); + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase()); OpFoldResult linearizedIndices; memref::LinearizedMemRefInfo linearizedInfo; @@ -1293,19 +1293,19 @@ struct ConvertVectorTransferRead final auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements, emulatedPerContainerElem); - auto newRead = rewriter.create( + auto newRead = vector::TransferReadOp::create(rewriter, loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newPadding); - auto bitCast = rewriter.create( + auto bitCast = vector::BitCastOp::create(rewriter, loc, VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy), newRead); Value result = bitCast->getResult(0); if (!foldedIntraVectorOffset) { - auto zeros = rewriter.create( + auto zeros = arith::ConstantOp::create(rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType())); result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset, @@ -1679,32 +1679,32 @@ Value BitCastRewriter::genericRewriteStep( PatternRewriter &rewriter, Location loc, Value initialValue, Value runningResult, const BitCastRewriter::Metadata &metadata) { // Create vector.shuffle from the metadata. - auto shuffleOp = rewriter.create( + auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue, initialValue, metadata.shuffles); // Intersect with the mask. VectorType shuffledVectorType = shuffleOp.getResultVectorType(); - auto constOp = rewriter.create( + auto constOp = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks)); - Value andValue = rewriter.create(loc, shuffleOp, constOp); + Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp); // Align right on 0. 
- auto shiftRightConstantOp = rewriter.create( + auto shiftRightConstantOp = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts)); Value shiftedRight = - rewriter.create(loc, andValue, shiftRightConstantOp); + arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp); // Shift bits left into their final position. - auto shiftLeftConstantOp = rewriter.create( + auto shiftLeftConstantOp = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts)); Value shiftedLeft = - rewriter.create(loc, shiftedRight, shiftLeftConstantOp); + arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp); runningResult = runningResult - ? rewriter.create(loc, runningResult, shiftedLeft) + ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft) : shiftedLeft; return runningResult; @@ -1727,7 +1727,7 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, // Adjust last dimension of the vector, so the total size remains the same. 
vecShape.back() = vecShape.back() / numSrcElemsPerByte; auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type()); - return rewriter.create(loc, i8VecType, subByteVec); + return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec); } /// Extracts a signed N-bit sequence from each element of a vector of bytes, @@ -1755,15 +1755,15 @@ static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter, assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 && "Invalid bitIdx range"); if (bitsToShiftLeft != 0) { - Value shiftLeftValues = rewriter.create( + Value shiftLeftValues = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft)); - shl = rewriter.create(loc, src, shiftLeftValues); + shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues); } int8_t bitsToShiftRight = 8 - numBits; - Value shiftRightValues = rewriter.create( + Value shiftRightValues = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); - Value shr = rewriter.create(loc, shl, shiftRightValues); + Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues); return shr; } @@ -1797,17 +1797,17 @@ static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter, int8_t bitsToShiftRight = bitIdx; Value shr = src; if (bitsToShiftRight != 0) { - Value shiftRightValues = rewriter.create( + Value shiftRightValues = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); - shr = rewriter.create(loc, src, shiftRightValues); + shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues); } if (bitIdx + numBits == 8) { return shr; } uint8_t lowBitsMask = (1 << numBits) - 1; - Value lowBitsMaskValues = rewriter.create( + Value lowBitsMaskValues = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask)); - return rewriter.create(loc, shr, lowBitsMaskValues); + return 
arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues); } using ExtractNBitsFn = @@ -1830,7 +1830,7 @@ static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc, Value high = extFn(rewriter, loc, i8Vector, 4, 4); // 3. Interleave low and high i8 elements. - return rewriter.create(loc, low, high); + return vector::InterleaveOp::create(rewriter, loc, low, high); } /// Rewrite the i2 -> i8 extension into a sequence of shuffles and @@ -1863,9 +1863,9 @@ static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc, // 02 = [0,2,0,2,0,2,0,2],... // 13 = [1,3,1,3,1,3,1,3],... // 0213 = [0,1,2,3,...],... - Value interleave02 = rewriter.create(loc, vec0, vec2); - Value interleave13 = rewriter.create(loc, vec1, vec3); - return rewriter.create(loc, interleave02, interleave13); + Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2); + Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3); + return vector::InterleaveOp::create(rewriter, loc, interleave02, interleave13); } /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise @@ -1877,29 +1877,29 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, "Expected i8 type"); // 1. De-interleave low and high i8 elements. - auto deinterleaveOp = rewriter.create(loc, srcValue); + auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue); // 2. Zero out the upper side of each low i8 element. constexpr int8_t i8LowBitMask = 0x0F; VectorType deinterI8VecType = deinterleaveOp.getResultVectorType(); - Value zeroOutMask = rewriter.create( + Value zeroOutMask = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask)); - Value zeroOutLow = rewriter.create( + Value zeroOutLow = arith::AndIOp::create(rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask); // 3. Move high i4 values to upper side of the byte. 
constexpr int8_t bitsToShift = 4; - auto shiftValues = rewriter.create( + auto shiftValues = arith::ConstantOp::create(rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift)); - Value shlHigh = rewriter.create(loc, deinterleaveOp.getRes2(), + Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(), shiftValues); // 4. Merge high and low i4 values. - auto mergedHiLowOp = rewriter.create(loc, zeroOutLow, shlHigh); + auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh); // 5. Generate a bitcast vector -> vector<2Xxi4>. auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type()); - return rewriter.create(loc, i4VecType, mergedHiLowOp); + return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp); } namespace { @@ -2141,7 +2141,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { Location loc = truncOp.getLoc(); auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type()); Value i8TruncVal = - rewriter.create(loc, i8VecType, srcValue); + arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue); // Rewrite the i8 -> i4 truncation part. Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal); @@ -2189,9 +2189,9 @@ struct RewriteVectorTranspose : OpRewritePattern { // support is available. 
auto srcNativeVecType = srcSubByteVecType.cloneWith( std::nullopt, rewriter.getIntegerType(minNativeBitwidth)); - Value extOp = rewriter.create(loc, srcNativeVecType, + Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType, transposeOp.getVector()); - Value newTranspose = rewriter.create( + Value newTranspose = vector::TransposeOp::create(rewriter, loc, extOp, transposeOp.getPermutation()); VectorType dstSubByteVecType = transposeOp.getResultVectorType(); rewriter.replaceOpWithNewOp(transposeOp, dstSubByteVecType, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index d834a99076834..baeca29298334 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -53,14 +53,14 @@ class DecomposeDifferentRankInsertStridedSlice int64_t rankRest = dstType.getRank() - rankDiff; // Extract / insert the subvector of matching rank and InsertStridedSlice // on it. - Value extracted = rewriter.create( + Value extracted = ExtractOp::create(rewriter, loc, op.getDest(), getI64SubArray(op.getOffsets(), /*dropFront=*/0, /*dropBack=*/rankRest)); // A different pattern will kick in for InsertStridedSlice with matching // ranks. - auto stridedSliceInnerOp = rewriter.create( + auto stridedSliceInnerOp = InsertStridedSliceOp::create(rewriter, loc, op.getValueToStore(), extracted, getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff), getI64SubArray(op.getStrides(), /*dropFront=*/0)); @@ -131,7 +131,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle SmallVector offsets(nDest, 0); for (int64_t i = 0; i < nSrc; ++i) offsets[i] = i; - Value scaledSource = rewriter.create( + Value scaledSource = ShuffleOp::create(rewriter, loc, op.getValueToStore(), op.getValueToStore(), offsets); // 2. 
Create a mask where we take the value from scaledSource of dest @@ -156,21 +156,21 @@ class ConvertSameRankInsertStridedSliceIntoShuffle off += stride, ++idx) { // 1. extract the proper subvector (or element) from source Value extractedSource = - rewriter.create(loc, op.getValueToStore(), idx); + ExtractOp::create(rewriter, loc, op.getValueToStore(), idx); if (isa(extractedSource.getType())) { // 2. If we have a vector, extract the proper subvector from destination // Otherwise we are at the element level and no need to recurse. Value extractedDest = - rewriter.create(loc, op.getDest(), off); + ExtractOp::create(rewriter, loc, op.getDest(), off); // 3. Reduce the problem to lowering a new InsertStridedSlice op with // smaller rank. - extractedSource = rewriter.create( + extractedSource = InsertStridedSliceOp::create(rewriter, loc, extractedSource, extractedDest, getI64SubArray(op.getOffsets(), /* dropFront=*/1), getI64SubArray(op.getStrides(), /* dropFront=*/1)); } // 4. Insert the extractedSource into the res vector. 
- res = rewriter.create(loc, extractedSource, res, off); + res = InsertOp::create(rewriter, loc, extractedSource, res, off); } rewriter.replaceOp(op, res); @@ -250,12 +250,12 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final SmallVector elements; elements.reserve(size); for (int64_t i = offset, e = offset + size * stride; i < e; i += stride) - elements.push_back(rewriter.create(loc, op.getVector(), i)); + elements.push_back(ExtractOp::create(rewriter, loc, op.getVector(), i)); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(op.getType())); for (int64_t i = 0; i < size; ++i) - result = rewriter.create(loc, elements[i], result, i); + result = InsertOp::create(rewriter, loc, elements[i], result, i); rewriter.replaceOp(op, result); return success(); @@ -301,17 +301,17 @@ class DecomposeNDExtractStridedSlice return failure(); // Extract/insert on a lower ranked extract strided slice op. - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = rewriter.create(loc, dstType, zero); + Value res = SplatOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { - Value one = rewriter.create(loc, op.getVector(), off); - Value extracted = rewriter.create( + Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); + Value extracted = ExtractStridedSliceOp::create(rewriter, loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), getI64SubArray(op.getSizes(), /* dropFront=*/1), getI64SubArray(op.getStrides(), /* dropFront=*/1)); - res = rewriter.create(loc, extracted, res, idx); + res = InsertOp::create(rewriter, loc, extracted, res, idx); } rewriter.replaceOp(op, res); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 
7cac1cbafdd64..880506142bfdd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -634,7 +634,7 @@ struct LinearizeVectorCreateMask final // The result of the comparison is then multiplied with // the second operand of create_mask to get the 1D mask. auto firstOperand = adaptor.getOperands().front(); - auto zero = rewriter.create(loc, 0); + auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); auto isNonZero = rewriter.createOrFold( loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); auto isNonZeroIndex = rewriter.createOrFold( @@ -644,7 +644,7 @@ struct LinearizeVectorCreateMask final loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); auto newMask = - rewriter.create(loc, dstTy, maskSize); + mlir::vector::CreateMaskOp::create(rewriter, loc, dstTy, maskSize); rewriter.replaceOp(createMaskOp, newMask); return success(); } @@ -721,7 +721,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, if (!isa(type) || !isa(value.getType())) return nullptr; - return builder.create(loc, type, value); + return vector::ShapeCastOp::create(builder, loc, type, value); }; typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp index 363108238e596..70d786a516859 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp @@ -84,7 +84,7 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter, // Replace createMaskOp with an all-true constant. This should result in the // mask being removed in most cases (as xfer ops + vector.mask have folds to // remove all-true masks). 
- auto allTrue = rewriter.create( + auto allTrue = vector::ConstantMaskOp::create(rewriter, createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue); rewriter.replaceAllUsesWith(createMaskOp, allTrue); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index efdae93e730bd..aa38fc20dfb70 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -286,7 +286,7 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, if (resultType.canonicalizeStridedLayout() == inputType.canonicalizeStridedLayout()) return input; - return rewriter.create(loc, resultType, input, offsets, + return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets, sizes, strides); } @@ -395,11 +395,11 @@ class TransferReadDropUnitDimsPattern Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); - Value c0 = rewriter.create(loc, 0); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); SmallVector inBounds(reducedVectorType.getRank(), true); - Operation *newTransferReadOp = rewriter.create( + Operation *newTransferReadOp = vector::TransferReadOp::create(rewriter, loc, reducedVectorType, reducedShapeSource, zeros, identityMap, transferReadOp.getPadding(), maskOp, rewriter.getBoolArrayAttr(inBounds)); @@ -477,13 +477,13 @@ class TransferWriteDropUnitDimsPattern } Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); - Value c0 = rewriter.create(loc, 0); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); SmallVector inBounds(reducedVectorType.getRank(), true); auto 
shapeCastSrc = rewriter.createOrFold( loc, reducedVectorType, vector); - Operation *newXferWrite = rewriter.create( + Operation *newXferWrite = vector::TransferWriteOp::create(rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds)); @@ -520,7 +520,7 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i) collapsedIndices.push_back(i); reassociation.push_back(collapsedIndices); - return rewriter.create(loc, input, reassociation); + return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation); } /// Returns the new indices that collapses the inner dimensions starting from @@ -559,7 +559,7 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, // one would get the following offset: // %offset = %arg0 * 43 OpFoldResult collapsedOffset = - rewriter.create(loc, 0).getResult(); + arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); auto collapsedStrides = computeSuffixProduct( ArrayRef(shape.begin() + firstDimToCollapse, shape.end())); @@ -573,7 +573,7 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, if (auto value = dyn_cast(collapsedOffset)) { indicesAfterCollapsing.push_back(value); } else { - indicesAfterCollapsing.push_back(rewriter.create( + indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(rewriter, loc, *getConstantIntValue(collapsedOffset))); } @@ -659,7 +659,7 @@ class FlattenContiguousRowMajorTransferReadPattern // 3. 
Create new vector.transfer_read that reads from the collapsed memref VectorType flatVectorType = VectorType::get({vectorType.getNumElements()}, vectorType.getElementType()); - vector::TransferReadOp flatRead = rewriter.create( + vector::TransferReadOp flatRead = vector::TransferReadOp::create(rewriter, loc, flatVectorType, collapsedSource, collapsedIndices, transferReadOp.getPadding(), collapsedMap); flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); @@ -757,9 +757,9 @@ class FlattenContiguousRowMajorTransferWritePattern VectorType flatVectorType = VectorType::get({vectorType.getNumElements()}, vectorType.getElementType()); Value flatVector = - rewriter.create(loc, flatVectorType, vector); + vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector); vector::TransferWriteOp flatWrite = - rewriter.create( + vector::TransferWriteOp::create(rewriter, loc, flatVector, collapsedSource, collapsedIndices, collapsedMap); flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); @@ -855,7 +855,7 @@ class RewriteScalarExtractElementOfTransferRead newIndices[newIndices.size() - 1] = value; } else { newIndices[newIndices.size() - 1] = - rewriter.create(loc, + arith::ConstantIndexOp::create(rewriter, loc, *getConstantIntValue(ofr)); } } @@ -917,7 +917,7 @@ class RewriteScalarExtractOfTransferRead if (auto value = dyn_cast(composedIdx)) { newIndices[idx] = value; } else { - newIndices[idx] = rewriter.create( + newIndices[idx] = arith::ConstantIndexOp::create(rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx)); } } @@ -952,7 +952,7 @@ class RewriteScalarWrite : public OpRewritePattern { return failure(); // Only float and integer element types are supported. Value scalar = - rewriter.create(xferOp.getLoc(), xferOp.getVector()); + vector::ExtractOp::create(rewriter, xferOp.getLoc(), xferOp.getVector()); // Construct a scalar store. 
if (isa(xferOp.getBase().getType())) { rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index 256c8cb69b1ba..a1f36a9744932 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -64,12 +64,12 @@ static Value createInBoundsCond(RewriterBase &b, if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz) return; Value cond = - b.create(loc, arith::CmpIPredicate::sle, + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle, getValueOrCreateConstantIndexOp(b, loc, sum), getValueOrCreateConstantIndexOp(b, loc, dimSz)); // Conjunction over all dims for which we are in-bounds. if (inBoundsCond) - inBoundsCond = b.create(loc, inBoundsCond, cond); + inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond); else inBoundsCond = cond; }); @@ -177,11 +177,11 @@ static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, sourceType = MemRefType::get( sourceType.getShape(), sourceType.getElementType(), sourceType.getLayout(), compatibleMemRefType.getMemorySpace()); - res = b.create(memref.getLoc(), sourceType, res); + res = memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res); } if (sourceType == compatibleMemRefType) return res; - return b.create(memref.getLoc(), compatibleMemRefType, res); + return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res); } /// Operates under a scoped context to build the intersection between the @@ -203,15 +203,15 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { using MapList = ArrayRef>; Value dimMemRef = - b.create(xferOp.getLoc(), xferOp.getBase(), indicesIdx); - Value dimAlloc = b.create(loc, alloc, resultIdx); + 
memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx); + Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx); Value index = xferOp.getIndices()[indicesIdx]; AffineExpr i, j, k; bindDims(xferOp.getContext(), i, j, k); SmallVector maps = AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext()); // affine_min(%dimMemRef - %index, %dimAlloc) - Value affineMin = b.create( + Value affineMin = affine::AffineMinOp::create(b, loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc}); sizes.push_back(affineMin); }); @@ -220,9 +220,9 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; })); SmallVector destIndices(memrefRank, b.getIndexAttr(0)); SmallVector strides(memrefRank, b.getIndexAttr(1)); - auto copySrc = b.create( + auto copySrc = memref::SubViewOp::create(b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides); - auto copyDest = b.create( + auto copyDest = memref::SubViewOp::create(b, loc, isaWrite ? 
xferOp.getBase() : alloc, destIndices, sizes, strides); return std::make_pair(copySrc, copyDest); } @@ -251,18 +251,18 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); - Value zero = b.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); - return b.create( + return scf::IfOp::create(b, loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; llvm::append_range(viewAndIndices, xferOp.getIndices()); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { - b.create(loc, ValueRange{xferOp.getPadding()}, + linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()}, ValueRange{alloc}); // Take partial subview of memref which guarantees no dimension // overflows. 
@@ -270,13 +270,13 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, std::pair copyArgs = createSubViewIntersection( rewriter, cast(xferOp.getOperation()), alloc); - b.create(loc, copyArgs.first, copyArgs.second); + memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second); Value casted = castToCompatibleMemRefType(b, alloc, compatibleMemRefType); scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }); } @@ -304,22 +304,22 @@ static scf::IfOp createFullPartialVectorTransferRead( Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); scf::IfOp fullPartialIfOp; - Value zero = b.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); - return b.create( + return scf::IfOp::create(b, loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; llvm::append_range(viewAndIndices, xferOp.getIndices()); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { Operation *newXfer = b.clone(*xferOp.getOperation()); Value vector = cast(newXfer).getVector(); - b.create( + memref::StoreOp::create(b, loc, vector, - b.create( + vector::TypeCastOp::create(b, loc, MemRefType::get({}, vector.getType()), alloc)); Value casted = @@ -327,7 +327,7 @@ static scf::IfOp createFullPartialVectorTransferRead( scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }); } @@ -351,7 +351,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType 
compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); - Value zero = b.create(loc, 0); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); return b .create( @@ -361,7 +361,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, castToCompatibleMemRefType(b, memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; llvm::append_range(viewAndIndices, xferOp.getIndices()); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { Value casted = @@ -369,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); - b.create(loc, viewAndIndices); + scf::YieldOp::create(b, loc, viewAndIndices); }) ->getResults(); } @@ -391,15 +391,15 @@ static void createFullPartialLinalgCopy(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc) { Location loc = xferOp.getLoc(); - auto notInBounds = b.create( - loc, inBoundsCond, b.create(loc, true, 1)); - b.create(loc, notInBounds, [&](OpBuilder &b, Location loc) { + auto notInBounds = arith::XOrIOp::create(b, + loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1)); + scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) { IRRewriter rewriter(b); std::pair copyArgs = createSubViewIntersection( rewriter, cast(xferOp.getOperation()), alloc); - b.create(loc, copyArgs.first, copyArgs.second); - b.create(loc, ValueRange{}); + memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second); + scf::YieldOp::create(b, loc, ValueRange{}); }); } @@ -420,18 +420,18 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b, Value inBoundsCond, Value alloc) { Location loc = xferOp.getLoc(); - auto notInBounds = b.create( - loc, inBoundsCond, b.create(loc, true, 1)); - b.create(loc, notInBounds, 
[&](OpBuilder &b, Location loc) { + auto notInBounds = arith::XOrIOp::create(b, + loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1)); + scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) { IRMapping mapping; - Value load = b.create( + Value load = memref::LoadOp::create(b, loc, - b.create( + vector::TypeCastOp::create(b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc), ValueRange()); mapping.map(xferOp.getVector(), load); b.clone(*xferOp.getOperation(), mapping); - b.create(loc, ValueRange{}); + scf::YieldOp::create(b, loc, ValueRange{}); }); } @@ -561,7 +561,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( b.setInsertionPointToStart(&scope->getRegion(0).front()); auto shape = xferOp.getVectorType().getShape(); Type elementType = xferOp.getVectorType().getElementType(); - alloc = b.create(scope->getLoc(), + alloc = memref::AllocaOp::create(b, scope->getLoc(), MemRefType::get(shape, elementType), ValueRange{}, b.getI64IntegerAttr(32)); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index bcaea1c79471f..2e54c89df7800 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -381,7 +381,7 @@ FailureOr combineContractAndBroadcast(vector::ContractionOp contractOp, if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) return failure(); - Operation *newOp = rewriter.create( + Operation *newOp = vector::ContractionOp::create(rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(), rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); @@ -534,7 +534,7 @@ struct ReorderElementwiseOpsOnTranspose final // This is a constant. Create a reverse transpose op for it. 
auto vectorType = srcType.clone(cast(operand.getType()).getElementType()); - srcValues.push_back(rewriter.create( + srcValues.push_back(vector::TransposeOp::create(rewriter, operand.getLoc(), vectorType, operand, invOrder)); } } @@ -608,12 +608,12 @@ struct BubbleDownVectorBitCastForExtract // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> Location loc = extractOp.getLoc(); - Value packedValue = rewriter.create( + Value packedValue = vector::ExtractOp::create(rewriter, loc, castOp.getSource(), index / expandRatio); Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType()); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType, rewriter.getZeroAttr(packedVecType)); - packedValue = rewriter.create(loc, packedValue, zero, + packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero, /*position=*/0); // Cast it to a vector with the desired scalar's type. @@ -621,7 +621,7 @@ struct BubbleDownVectorBitCastForExtract VectorType packedType = VectorType::get({expandRatio}, castDstType.getElementType()); Value castedValue = - rewriter.create(loc, packedType, packedValue); + vector::BitCastOp::create(rewriter, loc, packedType, packedValue); // Finally extract the desired scalar. rewriter.replaceOpWithNewOp(extractOp, castedValue, @@ -700,7 +700,7 @@ struct BubbleDownBitCastForStridedSliceExtract VectorType newExtractType = VectorType::get(dims, castSrcType.getElementType()); - auto newExtractOp = rewriter.create( + auto newExtractOp = vector::ExtractStridedSliceOp::create(rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, newSizes, extractOp.getStrides()); @@ -761,7 +761,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern { isNumElemsShrink ? 
srcDims.back() / ratio : srcDims.back() * ratio; VectorType newCastSrcType = VectorType::get(srcDims, castDstType.getElementType()); - auto newCastSrcOp = rewriter.create( + auto newCastSrcOp = vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore()); SmallVector dstDims(insertOp.getDestVectorType().getShape()); @@ -771,7 +771,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern { VectorType::get(dstDims, castDstType.getElementType()); // Bitcast the destination. - auto newCastDstOp = rewriter.create( + auto newCastDstOp = vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); // Generate new insert. @@ -852,7 +852,7 @@ struct BubbleUpBitCastForStridedSliceInsert VectorType newCastSrcType = VectorType::get(srcDims, castDstType.getElementType()); - auto newCastSrcOp = rewriter.create( + auto newCastSrcOp = vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore()); SmallVector dstDims = @@ -861,7 +861,7 @@ struct BubbleUpBitCastForStridedSliceInsert VectorType newCastDstType = VectorType::get(dstDims, castDstType.getElementType()); - auto newCastDstOp = rewriter.create( + auto newCastDstOp = vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); rewriter.replaceOpWithNewOp( @@ -936,9 +936,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern { Type elemType = castDstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); - Value zero = rewriter.create( + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = rewriter.create(loc, castDstType, zero); + Value res = SplatOp::create(rewriter, loc, castDstType, zero); SmallVector sliceShape = {castDstLastDim}; SmallVector strides = {1}; @@ -947,12 +947,12 @@ struct BreakDownVectorBitCast : public OpRewritePattern { castDstType.getElementType()); for (int i = 0, e = shrinkRatio; 
i < e; ++i) { - Value extracted = rewriter.create( + Value extracted = ExtractStridedSliceOp::create(rewriter, loc, bitcastOp.getSource(), ArrayRef{i * castDstLastDim}, sliceShape, strides); Value bitcast = - rewriter.create(loc, newCastDstType, extracted); - res = rewriter.create( + BitCastOp::create(rewriter, loc, newCastDstType, extracted); + res = InsertStridedSliceOp::create(rewriter, loc, bitcast, res, ArrayRef{i * castDstLastDim / shrinkRatio}, strides); } @@ -1097,7 +1097,7 @@ class ExtractOpFromElementwise final Location loc = eltwise->getLoc(); SmallVector pos = op.getMixedPosition(); for (Value arg : eltwise->getOperands()) { - Value newArg = rewriter.create(loc, arg, pos); + Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos); mapping.map(arg, newArg); } @@ -1286,18 +1286,18 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, indicesAttr = rewriter.getI64VectorAttr( llvm::to_vector<4>(llvm::seq(0, dim))); } - Value indices = rewriter.create(loc, indicesAttr); + Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr); // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = rewriter.create(loc, indices.getType(), o); - indices = rewriter.create(loc, ov, indices); + Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. 
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - rewriter.create(loc, indices.getType(), bound); - return rewriter.create(loc, arith::CmpIPredicate::slt, indices, + vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); } @@ -1329,15 +1329,15 @@ struct MaterializeTransferMask : public OpRewritePattern { Value off = xferOp.getIndices()[lastIndex]; Value dim = vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex); - Value b = rewriter.create(loc, dim.getType(), dim, off); - Value mask = rewriter.create( + Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off); + Value mask = vector::CreateMaskOp::create(rewriter, loc, VectorType::get(vtp.getShape(), rewriter.getI1Type(), vtp.getScalableDims()), b); if (xferOp.getMask()) { // Intersect the in-bounds with the mask specified as an op parameter. - mask = rewriter.create(loc, mask, xferOp.getMask()); + mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask()); } rewriter.modifyOpInPlace(xferOp, [&]() { @@ -1542,11 +1542,11 @@ class DropInnerMostUnitDimsTransferRead strides); ArrayAttr inBoundsAttr = rewriter.getArrayAttr( readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); - Value rankedReducedView = rewriter.create( + Value rankedReducedView = memref::SubViewOp::create(rewriter, loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); - Value result = rewriter.create( + Value result = vector::TransferReadOp::create(rewriter, loc, resultTargetVecType, rankedReducedView, readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), readOp.getPadding(), @@ -1633,7 +1633,7 @@ class DropInnerMostUnitDimsTransferWrite ArrayAttr inBoundsAttr = rewriter.getArrayAttr( 
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); - Value rankedReducedView = rewriter.create( + Value rankedReducedView = memref::SubViewOp::create(rewriter, loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); @@ -1702,21 +1702,21 @@ struct CanonicalizeContractMatmulToMMT final auto createTranspose = [&rewriter, loc](Value mat) -> Value { if (auto sext = mat.getDefiningOp()) { Value trans = - rewriter.create(loc, sext.getIn(), perm); + vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm); VectorType newType = cast(trans.getType()) .clone(cast(mat.getType()).getElementType()); - return rewriter.create(loc, newType, trans); + return arith::ExtSIOp::create(rewriter, loc, newType, trans); } if (auto zext = mat.getDefiningOp()) { Value trans = - rewriter.create(loc, zext.getIn(), perm); + vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm); VectorType newType = VectorType::get(cast(trans.getType()).getShape(), cast(mat.getType()).getElementType()); - return rewriter.create(loc, newType, trans); + return arith::ExtUIOp::create(rewriter, loc, newType, trans); } - return rewriter.create(loc, mat, perm); + return vector::TransposeOp::create(rewriter, loc, mat, perm); }; if (maps == infer({{m, k}, {k, n}, {m, n}})) { @@ -1830,7 +1830,7 @@ struct ChainedReduction final : OpRewritePattern { vAdd = rewriter.createOrFold( loc, parentReduction.getVector(), op.getVector()); } else { - vAdd = rewriter.create(loc, parentReduction.getVector(), + vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(), op.getVector()); } rewriter.replaceOpWithNewOp(op, op.getKind(), vAdd, @@ -1919,7 +1919,7 @@ struct DropUnitDimFromElementwiseOps final if (newVType == opVectorType) return rewriter.notifyMatchFailure(op, "No unit dimension to remove."); - auto opSC = rewriter.create(loc, newVType, operand); + auto opSC = 
vector::ShapeCastOp::create(rewriter, loc, newVType, operand); newOperands.push_back(opSC); } @@ -1998,11 +1998,11 @@ struct DropUnitDimsFromTransposeOp final Location loc = op.getLoc(); // Drop the unit dims via shape_cast. - auto dropDimsShapeCast = rewriter.create( + auto dropDimsShapeCast = vector::ShapeCastOp::create(rewriter, loc, sourceTypeWithoutUnitDims, op.getVector()); // Create the new transpose. auto transposeWithoutUnitDims = - rewriter.create(loc, dropDimsShapeCast, newPerm); + vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm); // Restore the unit dims via shape cast. rewriter.replaceOpWithNewOp( op, op.getResultVectorType(), transposeWithoutUnitDims); @@ -2053,7 +2053,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern { // Create a new ForOp with that iter operand replaced. auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) { - return b.create(loc, type, source); + return vector::ShapeCastOp::create(b, loc, type, source); }; Value replacement = @@ -2105,7 +2105,7 @@ struct ReduceRedundantZero final : OpRewritePattern { if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat())) return failure(); - auto newAdd = rewriter.create(vAdd.getLoc(), addLhs.getLhs(), + auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(), addLhs.getLhs(), vAdd.getRhs()); rewriter.replaceOpWithNewOp(op, op.getKind(), newAdd, op.getAcc()); @@ -2148,7 +2148,7 @@ struct BreakDownVectorReduction final : OpRewritePattern { Location loc = op.getLoc(); SmallVector extracted(numElems, nullptr); for (auto [idx, extractedElem] : llvm::enumerate(extracted)) - extractedElem = rewriter.create( + extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(), static_cast(idx)); Value res = extracted.front(); @@ -2228,7 +2228,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) return failure(); - return rewriter.create( + return 
vector::OuterProductOp::create(rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(), broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 693f4f955994d..9fa01061c1bed 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -49,7 +49,7 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, getAffineConstantExpr(elementOffsets[dim.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); slicedIndices[pos] = - builder.create(loc, map, indices[pos]); + affine::AffineApplyOp::create(builder, loc, map, indices[pos]); } return slicedIndices; } @@ -68,9 +68,9 @@ static SmallVector sliceLoadStoreIndices(PatternRewriter &rewriter, auto start = indices.size() - offsets.size(); for (auto [i, offset] : llvm::enumerate(offsets)) { if (offset != 0) { - indices[start + i] = rewriter.create( + indices[start + i] = arith::AddIOp::create(rewriter, loc, originalIndices[start + i], - rewriter.create(loc, offset)); + arith::ConstantIndexOp::create(rewriter, loc, offset)); } } return indices; @@ -172,7 +172,7 @@ struct UnrollTransferReadPattern ArrayRef originalSize = readOp.getVectorType().getShape(); // Prepare the result vector; - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); @@ -185,7 +185,7 @@ struct UnrollTransferReadPattern SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); - auto slicedRead = rewriter.create( + auto slicedRead = vector::TransferReadOp::create(rewriter, loc, targetType, readOp.getBase(), indices, readOp.getPermutationMapAttr(), readOp.getPadding(), 
readOp.getMask(), readOp.getInBoundsAttr()); @@ -236,7 +236,7 @@ struct UnrollTransferWritePattern SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, writeOp.getPermutationMap(), loc, rewriter); - Operation *slicedWrite = rewriter.create( + Operation *slicedWrite = vector::TransferWriteOp::create(rewriter, loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(), indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); // For the tensor case update the destination for the next transfer write. @@ -348,7 +348,7 @@ struct UnrollContractionPattern accCache[dstOffets] = newOp->getResult(0); } // Assemble back the accumulator into a single vector. - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, rewriter.getZeroAttr(dstVecType)); for (const auto &it : accCache) { SmallVector dstStrides(it.first.size(), 1); @@ -427,7 +427,7 @@ struct UnrollMultiReductionPattern accCache[destOffset] = result; } // Assemble back the accumulator into a single vector. - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, reductionOp.getDestType(), rewriter.getZeroAttr(reductionOp.getDestType())); for (const auto &it : accCache) { @@ -468,7 +468,7 @@ struct UnrollElementwisePattern : public RewritePattern { op, "expected input vector rank to match target shape rank"); Location loc = op->getLoc(); // Prepare the result vector. 
- Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, rewriter.getZeroAttr(dstVecType)); SmallVector strides(targetShape->size(), 1); VectorType newVecType = @@ -567,7 +567,7 @@ struct UnrollTransposePattern : public OpRewritePattern { ArrayRef originalSize = originalVectorType.getShape(); // Prepare the result vector; - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); ArrayRef permutation = transposeOp.getPermutation(); @@ -618,7 +618,7 @@ struct UnrollGatherPattern : public OpRewritePattern { ArrayRef originalSize = gatherOp.getVectorType().getShape(); // Prepare the result vector; - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); @@ -638,7 +638,7 @@ struct UnrollGatherPattern : public OpRewritePattern { rewriter.createOrFold( loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); - auto slicedGather = rewriter.create( + auto slicedGather = vector::GatherOp::create(rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), indexSubVec, maskSubVec, passThruSubVec); @@ -671,7 +671,7 @@ struct UnrollLoadPattern : public OpRewritePattern { ArrayRef originalShape = vecType.getShape(); SmallVector strides(targetShape->size(), 1); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, vecType, rewriter.getZeroAttr(vecType)); SmallVector loopOrder = @@ -684,7 +684,7 @@ struct UnrollLoadPattern : public OpRewritePattern { StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { SmallVector indices = sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets); - Value slicedLoad = rewriter.create( + Value slicedLoad = vector::LoadOp::create(rewriter, loc, 
targetVecType, loadOp.getBase(), indices); result = rewriter.createOrFold( loc, slicedLoad, result, offsets, strides); @@ -727,7 +727,7 @@ struct UnrollStorePattern : public OpRewritePattern { sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets); Value slice = rewriter.createOrFold( loc, vector, offsets, *targetShape, strides); - rewriter.create(loc, slice, base, indices); + vector::StoreOp::create(rewriter, loc, slice, base, indices); } rewriter.eraseOp(storeOp); return success(); @@ -755,7 +755,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern { VectorType resType = broadcastOp.getResultVectorType(); VectorType targetType = resType.cloneWith(*targetShape, resType.getElementType()); - Value result = rewriter.create( + Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); SmallVector originalShape = *broadcastOp.getShapeForUnroll(); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 7e4984582b373..7c31ada1644c5 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -331,7 +331,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, assert(padValue.getType() == sourceShapedType.getElementType() && "expected same pad element type to match source element type"); int64_t readRank = inputVectorSizes.size(); - auto zero = builder.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); SmallVector inBoundsVal(readRank, true); if (useInBoundsInsteadOfMasking) { @@ -341,7 +341,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) && ShapedType::isStatic(sourceShape[i]); } - auto transferReadOp = builder.create( + auto transferReadOp = vector::TransferReadOp::create(builder, loc, /*vectorType=*/vectorType, /*source=*/source, @@ -356,7 +356,7 @@ Value 
vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type()); Value mask = - builder.create(loc, maskType, mixedSourceDims); + vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) ->getResult(0); } diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index cc7ab7f3f3895..f7b1c0966fe27 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -63,11 +64,11 @@ SmallVector x86vector::MaskCompressOp::getIntrinsicOperands( if (adaptor.getSrc()) { src = adaptor.getSrc(); } else if (adaptor.getConstantSrc()) { - src = rewriter.create(loc, opType, - adaptor.getConstantSrcAttr()); + src = LLVM::ConstantOp::create(rewriter, loc, opType, + adaptor.getConstantSrcAttr()); } else { auto zeroAttr = rewriter.getZeroAttr(opType); - src = rewriter.create(loc, opType, zeroAttr); + src = LLVM::ConstantOp::create(rewriter, loc, opType, zeroAttr); } return SmallVector{adaptor.getA(), src, adaptor.getK()}; @@ -80,7 +81,7 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef operands, SmallVector intrinsicOperands(operands); // Dot product of all elements, broadcasted to all elements. 
Value scale = - rewriter.create(getLoc(), rewriter.getI8Type(), 0xff); + LLVM::ConstantOp::create(rewriter, getLoc(), rewriter.getI8Type(), 0xff); intrinsicOperands.push_back(scale); return intrinsicOperands; diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp index 63f725084cabf..8f87661a68d96 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp @@ -38,7 +38,7 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm( "=x,x,x"; // Careful: constraint parser is very brittle: no ws! SmallVector asmVals{v1, v2}; auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str(); - auto asmOp = b.create( + auto asmOp = LLVM::InlineAsmOp::create(b, v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr, /*constraints=*/asmCstr, /*has_side_effects=*/false, /*is_align_stack=*/false, LLVM::TailCallKind::None, @@ -49,13 +49,13 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm( Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2) { - return b.create( + return vector::ShuffleOp::create(b, v1, v2, ArrayRef{0, 8, 1, 9, 4, 12, 5, 13}); } Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2) { - return b.create( + return vector::ShuffleOp::create(b, v1, v2, ArrayRef{2, 10, 3, 11, 6, 14, 7, 15}); } /// a a b b a a b b @@ -69,7 +69,7 @@ Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b, MaskHelper::extractShuffle(mask, b01, b23, b45, b67); SmallVector shuffleMask = { b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4}; - return b.create(v1, v2, shuffleMask); + return vector::ShuffleOp::create(b, v1, v2, shuffleMask); } // imm[0:1] out of imm[0:3] is: @@ -97,7 +97,7 @@ Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps( MaskHelper::extractPermute(mask, b03, b47); 
appendToMask(b03); appendToMask(b47); - return b.create(v1, v2, shuffleMask); + return vector::ShuffleOp::create(b, v1, v2, shuffleMask); } /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2. @@ -109,7 +109,7 @@ Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, bool isSet = mask & (1 << i); shuffleMask.push_back(!isSet ? i : i + 8); } - return b.create(v1, v2, shuffleMask); + return vector::ShuffleOp::create(b, v1, v2, shuffleMask); } /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model. @@ -245,13 +245,13 @@ class TransposeOpLowering : public OpRewritePattern { VectorType::get({n * m}, op.getSourceVectorType().getElementType()); auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); auto reshInput = - ib.create(flattenedType, op.getVector()); - reshInput = ib.create(reshInputType, reshInput); + vector::ShapeCastOp::create(ib, flattenedType, op.getVector()); + reshInput = vector::ShapeCastOp::create(ib, reshInputType, reshInput); // Extract 1-D vectors from the higher-order dimension of the input // vector. for (int64_t i = 0; i < m; ++i) - vs.push_back(ib.create(reshInput, i)); + vs.push_back(vector::ExtractOp::create(ib, reshInput, i)); // Transpose set of 1-D vectors. if (m == 4) @@ -261,16 +261,16 @@ class TransposeOpLowering : public OpRewritePattern { // Insert transposed 1-D vectors into the higher-order dimension of the // output vector. - Value res = ib.create(reshInputType, + Value res = arith::ConstantOp::create(ib, reshInputType, ib.getZeroAttr(reshInputType)); for (int64_t i = 0; i < m; ++i) - res = ib.create(vs[i], res, i); + res = vector::InsertOp::create(ib, vs[i], res, i); // The output vector still has the shape of the input vector (e.g., 4x8). // We have to transpose their dimensions and retrieve its original rank // (e.g., 1x8x1x4x1). 
- res = ib.create(flattenedType, res); - res = ib.create(op.getResultVectorType(), res); + res = vector::ShapeCastOp::create(ib, flattenedType, res); + res = vector::ShapeCastOp::create(ib, op.getResultVectorType(), res); rewriter.replaceOp(op, res); return success(); }; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 3bea6cb50ff7b..81687a6fc37f2 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/Support/Debug.h" @@ -416,7 +417,7 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state, int64_t size = static_cast(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); - auto offset = builder.create(loc, type, values); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, TensorDesc, source, offset); } @@ -553,7 +554,7 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, int64_t size = static_cast(offsets.size()); auto type = VectorType::get({size}, builder.getIndexType()); auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); - auto offset = builder.create(loc, type, values); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, tdescTy, tensorDesc, offset); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index c072557c2bd22..613a33fe63d87 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -120,13 +120,13 @@ static Value 
resolveDistributedTy(Value orig, T expected, // If orig is a vector type, create a shape cast op to reconcile the types. if (isa(orig.getType())) { auto castOp = - rewriter.create(orig.getLoc(), expected, orig); + vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig); return castOp.getResult(); } // If orig is a tensor descriptor type, create an unrealized conversion cast // op to reconcile the types. if (isa(orig.getType())) { - auto castOp = rewriter.create(orig.getLoc(), + auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(), expected, orig); castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr()); return castOp.getResult(0); @@ -198,16 +198,16 @@ struct MoveFuncBodyToWarpExecuteOnLane0 })) return failure(); // Create a new function with the same signature. - auto newGpuFunc = rewriter.create( + auto newGpuFunc = gpu::GPUFuncOp::create(rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType()); // Create a WarpExecuteOnLane0Op with same arguments and results as the // original gpuFuncOp. 
rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front()); - auto laneId = rewriter.create( + auto laneId = gpu::LaneIdOp::create(rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(), /** upperBound = **/ mlir::IntegerAttr()); ArrayRef gpuFuncResultType = gpuFuncOp.getFunctionType().getResults(); - auto warpOp = rewriter.create( + auto warpOp = gpu::WarpExecuteOnLane0Op::create(rewriter, laneId.getLoc(), gpuFuncResultType, laneId, xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes()); @@ -216,7 +216,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0 auto origRetunOp = cast(gpuFuncOp.getBlocks().back().getTerminator()); rewriter.setInsertionPointAfter(origRetunOp); - rewriter.create(origRetunOp.getLoc(), + gpu::YieldOp::create(rewriter, origRetunOp.getLoc(), origRetunOp.getOperands()); rewriter.eraseOp(origRetunOp); // Move the original function body to the WarpExecuteOnLane0Op body. @@ -225,7 +225,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0 rewriter.eraseBlock(&warpBodyBlock); // Insert a new ReturnOp after the WarpExecuteOnLane0Op. rewriter.setInsertionPointAfter(warpOp); - rewriter.create(newGpuFunc.getLoc(), warpOp.getResults()); + gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults()); rewriter.replaceOp(gpuFuncOp, newGpuFunc); return success(); } @@ -301,7 +301,7 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern { xegpu::TensorDescType distributedTensorDescTy = descOp.getType().dropLayouts(); // Distributed tensor descriptor type // does not contain layout info. 
- Value newDescOp = rewriter.create( + Value newDescOp = xegpu::CreateNdDescOp::create(rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands, descOp->getAttrs()); @@ -403,7 +403,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern { resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]), distributedTensorDescTy, rewriter)); - rewriter.create( + xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreOperands, removeTemporaryLayoutAttributes(storeOp->getAttrs())); rewriter.eraseOp(storeOp); @@ -494,7 +494,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { loadOp.getTensorDescType().dropLayouts(); // Distributed tensor // descriptor type does not // contain layout info. - auto newLoadOp = rewriter.create( + auto newLoadOp = xegpu::LoadNdOp::create(rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(), resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]), distributedTensorDescTy, rewriter), @@ -630,7 +630,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern { resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]), newDpasOperandExpectedTypes[i], rewriter)); } - Value newDpasOp = rewriter.create( + Value newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(), distributedResultTy, newDpasOperands, removeTemporaryLayoutAttributes(dpasOp->getAttrs())); Value distributedVal = newWarpOp.getResult(operandIdx); @@ -717,7 +717,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern { } } // Create a new update op outside the warp op. 
- Value newUpdateOp = rewriter.create( + Value newUpdateOp = xegpu::UpdateNdOffsetOp::create(rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands, removeTemporaryLayoutAttributes(updateOp->getAttrs())); Value distributedVal = newWarpOp.getResult(operandIdx); @@ -783,7 +783,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); SmallVector newPrefetchOperands = {resolveDistributedTy( newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)}; - rewriter.create( + xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands, removeTemporaryLayoutAttributes(prefetchOp->getAttrs())); rewriter.eraseOp(prefetchOp); @@ -806,7 +806,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { return failure(); // Move the barrier op outside of the warp op. rewriter.setInsertionPointAfter(subgroupOp); - rewriter.create( + gpu::BarrierOp::create(rewriter, barrierOp.getLoc(), barrierOp->getResultTypes(), barrierOp->getOperands(), barrierOp->getAttrs()); rewriter.eraseOp(barrierOp); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 2c48a735bf956..1d9f4dcc6b9e6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -83,7 +83,7 @@ struct UnrollPattern : public OpRewritePattern { rewriter.getUnitAttr()); auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName), rewriter.getDenseI64ArrayAttr(blockSize)); - auto castOp = rewriter.create( + auto castOp = UnrealizedConversionCastOp::create(rewriter, loc, destTy, srcs, ArrayRef({attr, blkAttr})); return castOp.getResult(0); } @@ -109,7 +109,7 @@ struct UnrollPattern : public OpRewritePattern { rewriter.getUnitAttr()); auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName), rewriter.getDenseI64ArrayAttr(blockSize)); - auto castOp = 
rewriter.create( + auto castOp = UnrealizedConversionCastOp::create(rewriter, loc, destTypes, src, ArrayRef({attr, blkAttr})); return castOp.getResults(); } @@ -144,10 +144,10 @@ struct UnrollCreateNdOp : public UnrollPattern { auto addi = [&](OpFoldResult a, int64_t b) -> Value { std::optional maybeInt = getConstantIntValue(a); if (maybeInt) { - return rewriter.create(loc, *maybeInt + b); + return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b); } else { auto aV = llvm::cast(a); - auto bV = rewriter.create(loc, b); + auto bV = arith::ConstantIndexOp::create(rewriter, loc, b); return rewriter.createOrFold(loc, aV, bV); } }; @@ -169,7 +169,7 @@ struct UnrollCreateNdOp : public UnrollPattern { llvm::zip(validIdxes, oldOffsets, offsets)) mixedOffsets[idx] = addi(oldOff, offset); - auto newOp = rewriter.create( + auto newOp = xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(), op.getMixedStrides()); newOps.push_back(newOp); @@ -199,7 +199,7 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern { SmallVector newOps; for (auto t : convertedTdesc) { - auto newOp = rewriter.create( + auto newOp = xegpu::UpdateNdOffsetOp::create(rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets()); newOps.push_back(newOp); } @@ -226,7 +226,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern { op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); for (auto t : convertedTdesc) - rewriter.create(loc, TypeRange(), t, op->getAttrs()); + xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t, op->getAttrs()); rewriter.eraseOp(op); return success(); @@ -257,7 +257,7 @@ struct UnrollLoadNdOp : public UnrollPattern { SmallVector newOps; for (auto t : convertedTdescs) { auto newOp = - rewriter.create(loc, newValueTy, t, op->getAttrs()); + xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs()); newOps.push_back(newOp); } @@ -291,7 +291,7 @@ struct UnrollStoreNdOp : public 
UnrollPattern { op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) - rewriter.create(loc, v, t, op.getL1HintAttr(), + xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); rewriter.eraseOp(op); @@ -384,7 +384,7 @@ struct UnrollDpasOp : public UnrollPattern { if (tmpC) operands.push_back(tmpC); - tmpC = rewriter.create(loc, vecTy, operands, + tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands, op->getAttrs()); } newOps.push_back(tmpC); @@ -436,13 +436,13 @@ struct UnrollCreateDescOp : public UnrollPattern { llvm::zip(convertedIndiceVec, convertedIndiceTypes)) { for (int64_t i = 0; i < numNewChunks; ++i) { // Compute the offset - Value inc = rewriter.create( + Value inc = arith::ConstantIndexOp::create(rewriter, loc, i * blockedChunkSize); - Value incVec = rewriter.create(loc, indiceType, inc); + Value incVec = vector::SplatOp::create(rewriter, loc, indiceType, inc); Value offsetIndice = - rewriter.create(loc, indice, incVec); + arith::AddIOp::create(rewriter, loc, indice, incVec); - auto newOp = rewriter.create( + auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy, op.getSource(), offsetIndice); newOps.push_back(newOp); @@ -450,7 +450,7 @@ struct UnrollCreateDescOp : public UnrollPattern { } } else { for (auto indice : convertedIndiceVec) { - auto newOp = rewriter.create( + auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy, op.getSource(), indice); newOps.push_back(newOp); } @@ -515,7 +515,7 @@ struct UnrollLoadGatherOp : public UnrollPattern { SmallVector newOps; for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) { - auto newOp = rewriter.create( + auto newOp = xegpu::LoadGatherOp::create(rewriter, loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); newOps.push_back(newOp); @@ -547,7 +547,7 @@ struct UnrollPrefetchOp : public UnrollPattern { 
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); for (auto t : convertedTdesc) - rewriter.create(loc, TypeRange(), t, op->getAttrs()); + xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs()); rewriter.eraseOp(op); return success(); @@ -608,7 +608,7 @@ struct UnrollStoreScatterOp : public UnrollPattern { Value v = convertedValues[i]; Value t = convertedTdescs[i]; Value m = op.getMask() ? convertedMasks[i] : nullptr; - rewriter.create(loc, v, t, m, op.getL1HintAttr(), + xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); } @@ -665,7 +665,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) { auto newOp = - rewriter.create(loc, t.getType(), t, o); + xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o); newOps.push_back(newOp); } Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index be7b860dd1729..aa2fc78c77f39 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -121,11 +121,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern { for (size_t i = 0; i < rank; ++i) { size_t dimIdx = originalOffsets.size() - rank + i; Value constOffset = - rewriter.create(loc, distUnitBaseAddr[i]); + arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]); Value offset = rewriter.createOrFold(loc, localOffset[i], constOffset); Value modValue = - rewriter.create(loc, distUnitShape[i]); + arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]); Value offsetMod = rewriter.createOrFold(loc, offset, modValue); Value origOffset = getValueOrCreateConstantIndexOp( @@ -162,7 +162,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern { // 
TODO : Handle order attribute // Get the subgroup ID auto linearSgId = - rewriter.create(loc, /*upper_bound=*/nullptr); + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); // Create constants for layout dimensions SmallVector sgLayoutDim(sgLayout.size()); @@ -170,8 +170,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern { for (size_t i = 0; i < sgLayout.size(); i++) { sgLayoutDim[i] = - rewriter.create(loc, sgLayout[i]); - sgDataDim[i] = rewriter.create(loc, sgShape[i]); + arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]); + sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); } auto deLinearizeSgId = @@ -201,7 +201,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern { calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr, distUnitShape); - auto newCreateNdOp = rewriter.create( + auto newCreateNdOp = xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(), op.getMixedStrides()); newCreateNdOps.push_back(newCreateNdOp); @@ -224,7 +224,7 @@ struct WgToSgLoadNdOp : public OpConversionPattern { dyn_cast(src.getType()); ArrayRef srcShape = tdescTy.getShape(); VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType()); - auto newLoadOp = rewriter.create(op.getLoc(), newResTy, + auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy, src, op->getAttrs()); newLoadOps.push_back(newLoadOp); } @@ -242,7 +242,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern { matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) - rewriter.create(op.getLoc(), v, t, op.getL1HintAttr(), + xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); rewriter.eraseOp(op); @@ -261,7 +261,7 @@ struct WgToSgUpdateNdOffsetOp 
ConversionPatternRewriter &rewriter) const override { llvm::SmallVector newUpdateTileOffsetOps; for (auto tDesc : adaptor.getTensorDesc()) { - auto newUpdateTileOffsetOp = rewriter.create( + auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(), op.getConstOffsets()); newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp); @@ -305,7 +305,7 @@ struct WgToSgDpasOp : public OpConversionPattern { llvm::cast(bVec.getType()).getShape(); VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); - tmpC = rewriter.create(loc, resTy, operands); + tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands); xegpu::setLayoutAttr(cast(tmpC), originalLayout.dropSgLayoutAndData()); @@ -324,7 +324,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { for (auto src : adaptor.getTensorDesc()) - rewriter.create(op.getLoc(), TypeRange(), src, + xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src, op->getAttrs()); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 6b85a66a8bd36..9acb209d031bd 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -198,7 +198,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, SmallVector result; for (SmallVector offsets : StaticTileOffsetRange(srcShape, shape)) { SmallVector staticStrides(offsets.size(), 1); - result.push_back(builder.create( + result.push_back(vector::ExtractStridedSliceOp::create(builder, loc, value, offsets, shape, staticStrides)); } @@ -218,13 +218,13 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc, VectorType resultTy = VectorType::get(shape, elemTy); auto zeroAttr = 
builder.getZeroAttr(elemTy); - Value result = builder.create( + Value result = arith::ConstantOp::create(builder, loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr)); for (auto [src, offsets] : llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) { SmallVector staticStrides(offsets.size(), 1); - result = builder.create( + result = vector::InsertStridedSliceOp::create(builder, loc, src, result, offsets, staticStrides); } return result; @@ -236,7 +236,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, type, inputs) + return UnrealizedConversionCastOp::create(builder, loc, type, inputs) .getResult(0); }; @@ -343,7 +343,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( if (isa(inputTy) && isa(outputTy)) { SmallVector values = xegpu::flattenValues(adaptor.getInputs()); - auto newOp = rewriter.create( + auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(), outputTy, values); rewriter.replaceOp(op, newOp); return success(); @@ -355,7 +355,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( converter.addSourceMaterialization(materializeCast); converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type, ValueRange inputs, Location loc) { - return builder.create(loc, type, inputs) + return UnrealizedConversionCastOp::create(builder, loc, type, inputs) .getResults(); }); diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index 6d7e2aa0ece7d..23988a00c7a2c 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeRange.h" diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp 
index 637e1f3cdef87..8ddb35d75734b 100644 --- a/mlir/lib/Query/Query.cpp +++ b/mlir/lib/Query/Query.cpp @@ -78,7 +78,7 @@ static Operation *extractFunction(std::vector &ops, clonedOp->result_end()); } // Add return operation - builder.create(loc, clonedVals); + func::ReturnOp::create(builder, loc, clonedVals); // Remove unused function arguments size_t currentIndex = 0; diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp index 4e7f1d3185129..c0f19f6dce600 100644 --- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp @@ -37,7 +37,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( return failure(); Type resultType = moduleImport.convertType(inst->getType()); - auto op = builder.create<::mlir::LLVM::CallIntrinsicOp>( + auto op = ::mlir::LLVM::CallIntrinsicOp::create(builder, moduleImport.translateLoc(inst->getDebugLoc()), isa(resultType) ? TypeRange{} : TypeRange{resultType}, StringAttr::get(builder.getContext(), intrinName), diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index bfda223fe0f5f..6e47db6d1dc6a 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -185,7 +185,7 @@ ComdatOp ModuleImport::getGlobalComdatOp() { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToEnd(mlirModule.getBody()); globalComdatOp = - builder.create(mlirModule.getLoc(), getGlobalComdatOpName()); + ComdatOp::create(builder, mlirModule.getLoc(), getGlobalComdatOpName()); globalInsertionOp = globalComdatOp; return globalComdatOp; } @@ -864,7 +864,7 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() { } if (!moduleFlags.empty()) - builder.create(mlirModule.getLoc(), + LLVM::ModuleFlagsOp::create(builder, mlirModule.getLoc(), builder.getArrayAttr(moduleFlags)); return success(); @@ -880,7 +880,7 @@ LogicalResult 
ModuleImport::convertLinkerOptionsMetadata() { options.reserve(node->getNumOperands()); for (const llvm::MDOperand &option : node->operands()) options.push_back(cast(option)->getString()); - builder.create(mlirModule.getLoc(), + LLVM::LinkerOptionsOp::create(builder, mlirModule.getLoc(), builder.getStrArrayAttr(options)); } } @@ -984,7 +984,7 @@ void ModuleImport::processComdat(const llvm::Comdat *comdat) { ComdatOp comdatOp = getGlobalComdatOp(); OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToEnd(&comdatOp.getBody().back()); - auto selectorOp = builder.create( + auto selectorOp = ComdatSelectorOp::create(builder, mlirModule.getLoc(), comdat->getName(), convertComdatFromLLVM(comdat->getSelectionKind())); auto symbolRef = @@ -1346,7 +1346,7 @@ LogicalResult ModuleImport::convertAlias(llvm::GlobalAlias *alias) { OpBuilder::InsertionGuard guard = setGlobalInsertionPoint(); Type type = convertType(alias->getValueType()); - AliasOp aliasOp = builder.create( + AliasOp aliasOp = AliasOp::create(builder, mlirModule.getLoc(), type, convertLinkageFromLLVM(alias->getLinkage()), alias->getName(), /*dso_local=*/alias->isDSOLocal(), @@ -1360,7 +1360,7 @@ LogicalResult ModuleImport::convertAlias(llvm::GlobalAlias *alias) { FailureOr initializer = convertConstantExpr(alias->getAliasee()); if (failed(initializer)) return failure(); - builder.create(aliasOp.getLoc(), *initializer); + ReturnOp::create(builder, aliasOp.getLoc(), *initializer); if (alias->hasAtLeastLocalUnnamedAddr()) aliasOp.setUnnamedAddr(convertUnnamedAddrFromLLVM(alias->getUnnamedAddr())); @@ -1403,7 +1403,7 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { if (globalName.empty()) globalName = getOrCreateNamelessSymbolName(globalVar).getValue(); - GlobalOp globalOp = builder.create( + GlobalOp globalOp = GlobalOp::create(builder, mlirModule.getLoc(), type, globalVar->isConstant(), convertLinkageFromLLVM(globalVar->getLinkage()), StringRef(globalName), valueAttr, 
alignment, /*addr_space=*/globalVar->getAddressSpace(), @@ -1420,7 +1420,7 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { convertConstantExpr(globalVar->getInitializer()); if (failed(initializer)) return failure(); - builder.create(globalOp.getLoc(), *initializer); + ReturnOp::create(builder, globalOp.getLoc(), *initializer); } if (globalVar->hasAtLeastLocalUnnamedAddr()) { globalOp.setUnnamedAddr( @@ -1488,12 +1488,12 @@ ModuleImport::convertGlobalCtorsAndDtors(llvm::GlobalVariable *globalVar) { OpBuilder::InsertionGuard guard = setGlobalInsertionPoint(); if (globalVar->getName() == getGlobalCtorsVarName()) { - globalInsertionOp = builder.create( + globalInsertionOp = LLVM::GlobalCtorsOp::create(builder, mlirModule.getLoc(), builder.getArrayAttr(funcs), builder.getI32ArrayAttr(priorities), builder.getArrayAttr(dataList)); return success(); } - globalInsertionOp = builder.create( + globalInsertionOp = LLVM::GlobalDtorsOp::create(builder, mlirModule.getLoc(), builder.getArrayAttr(funcs), builder.getI32ArrayAttr(priorities), builder.getArrayAttr(dataList)); return success(); @@ -1569,33 +1569,33 @@ FailureOr ModuleImport::convertConstant(llvm::Constant *constant) { if (Attribute attr = getConstantAsAttr(constant)) { Type type = convertType(constant->getType()); if (auto symbolRef = dyn_cast(attr)) { - return builder.create(loc, type, symbolRef.getValue()) + return AddressOfOp::create(builder, loc, type, symbolRef.getValue()) .getResult(); } - return builder.create(loc, type, attr).getResult(); + return ConstantOp::create(builder, loc, type, attr).getResult(); } // Convert null pointer constants. if (auto *nullPtr = dyn_cast(constant)) { Type type = convertType(nullPtr->getType()); - return builder.create(loc, type).getResult(); + return ZeroOp::create(builder, loc, type).getResult(); } // Convert none token constants. 
if (isa(constant)) { - return builder.create(loc).getResult(); + return NoneTokenOp::create(builder, loc).getResult(); } // Convert poison. if (auto *poisonVal = dyn_cast(constant)) { Type type = convertType(poisonVal->getType()); - return builder.create(loc, type).getResult(); + return PoisonOp::create(builder, loc, type).getResult(); } // Convert undef. if (auto *undefVal = dyn_cast(constant)) { Type type = convertType(undefVal->getType()); - return builder.create(loc, type).getResult(); + return UndefOp::create(builder, loc, type).getResult(); } // Convert dso_local_equivalent. @@ -1621,7 +1621,7 @@ FailureOr ModuleImport::convertConstant(llvm::Constant *constant) { getOrCreateNamelessSymbolName(cast(globalObj)); else symbolRef = FlatSymbolRefAttr::get(context, globalName); - return builder.create(loc, type, symbolRef).getResult(); + return AddressOfOp::create(builder, loc, type, symbolRef).getResult(); } // Convert global alias accesses. @@ -1629,7 +1629,7 @@ FailureOr ModuleImport::convertConstant(llvm::Constant *constant) { Type type = convertType(globalAliasObj->getType()); StringRef aliaseeName = globalAliasObj->getName(); FlatSymbolRefAttr symbolRef = FlatSymbolRefAttr::get(context, aliaseeName); - return builder.create(loc, type, symbolRef).getResult(); + return AddressOfOp::create(builder, loc, type, symbolRef).getResult(); } // Convert constant expressions. 
@@ -1680,15 +1680,15 @@ FailureOr ModuleImport::convertConstant(llvm::Constant *constant) { bool isArrayOrStruct = isa(rootType); assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) && "unrecognized aggregate type"); - Value root = builder.create(loc, rootType); + Value root = UndefOp::create(builder, loc, rootType); for (const auto &it : llvm::enumerate(elementValues)) { if (isArrayOrStruct) { - root = builder.create(loc, root, it.value(), it.index()); + root = InsertValueOp::create(builder, loc, root, it.value(), it.index()); } else { Attribute indexAttr = builder.getI32IntegerAttr(it.index()); Value indexValue = - builder.create(loc, builder.getI32Type(), indexAttr); - root = builder.create(loc, rootType, root, it.value(), + ConstantOp::create(builder, loc, builder.getI32Type(), indexAttr); + root = InsertElementOp::create(builder, loc, rootType, root, it.value(), indexValue); } } @@ -1702,7 +1702,7 @@ FailureOr ModuleImport::convertConstant(llvm::Constant *constant) { "target extension type does not support zero-initialization"); // Create llvm.mlir.zero operation to represent zero-initialization of // target extension type. 
- return builder.create(loc, targetExtType).getRes(); + return LLVM::ZeroOp::create(builder, loc, targetExtType).getRes(); } if (auto *blockAddr = dyn_cast(constant)) { @@ -2124,7 +2124,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { } if (!brInst->isConditional()) { - auto brOp = builder.create(loc, succBlockArgs.front(), + auto brOp = LLVM::BrOp::create(builder, loc, succBlockArgs.front(), succBlocks.front()); mapNoResultOp(inst, brOp); return success(); @@ -2132,7 +2132,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { FailureOr condition = convertValue(brInst->getCondition()); if (failed(condition)) return failure(); - auto condBrOp = builder.create( + auto condBrOp = LLVM::CondBrOp::create(builder, loc, *condition, succBlocks.front(), succBlockArgs.front(), succBlocks.back(), succBlockArgs.back()); mapNoResultOp(inst, condBrOp); @@ -2166,7 +2166,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { caseBlocks[it.index()] = lookupBlock(succBB); } - auto switchOp = builder.create( + auto switchOp = SwitchOp::create(builder, loc, *condition, lookupBlock(defaultBB), defaultBlockArgs, caseValues, caseBlocks, caseOperandRefs); mapNoResultOp(inst, switchOp); @@ -2218,14 +2218,14 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // IR). Build the indirect call by passing an empty `callee` operand and // insert into `operands` to include the indirect call target. FlatSymbolRefAttr calleeSym = convertCalleeName(callInst); - Value indirectCallVal = builder.create( + Value indirectCallVal = LLVM::AddressOfOp::create(builder, loc, LLVM::LLVMPointerType::get(context), calleeSym); operands->insert(operands->begin(), indirectCallVal); } else { // Regular direct call using callee name. 
callee = convertCalleeName(callInst); } - CallOp callOp = builder.create(loc, *funcTy, callee, *operands); + CallOp callOp = CallOp::create(builder, loc, *funcTy, callee, *operands); if (failed(convertCallAttributes(callInst, callOp))) return failure(); @@ -2260,7 +2260,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { Type type = convertType(lpInst->getType()); auto lpOp = - builder.create(loc, type, lpInst->isCleanup(), operands); + LandingpadOp::create(builder, loc, type, lpInst->isCleanup(), operands); mapValue(inst, lpOp); return success(); } @@ -2310,7 +2310,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // IR). Build the indirect invoke by passing an empty `callee` operand and // insert into `operands` to include the indirect invoke target. FlatSymbolRefAttr calleeSym = convertCalleeName(invokeInst); - Value indirectInvokeVal = builder.create( + Value indirectInvokeVal = LLVM::AddressOfOp::create(builder, loc, LLVM::LLVMPointerType::get(context), calleeSym); operands->insert(operands->begin(), indirectInvokeVal); } else { @@ -2320,7 +2320,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Create the invoke operation. Normal destination block arguments will be // added later on to handle the case in which the operation result is // included in this list. - auto invokeOp = builder.create( + auto invokeOp = InvokeOp::create(builder, loc, *funcTy, calleeName, *operands, directNormalDest, ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs); @@ -2348,7 +2348,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // arguments (including the invoke operation's result). 
OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToStart(directNormalDest); - builder.create(loc, normalArgs, normalDest); + LLVM::BrOp::create(builder, loc, normalArgs, normalDest); } else { // If the invoke operation's result is not a block argument to the normal // destination block, just add the block arguments as usual. @@ -2382,7 +2382,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { } Type type = convertType(inst->getType()); - auto gepOp = builder.create( + auto gepOp = GEPOp::create(builder, loc, type, sourceElementType, *basePtr, indices, static_cast(gepInst->getNoWrapFlags().getRaw())); mapValue(inst, gepOp); @@ -2409,7 +2409,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { SmallVector succBlockArgsRange = llvm::to_vector_of(succBlockArgs); Location loc = translateLoc(inst->getDebugLoc()); - auto indBrOp = builder.create( + auto indBrOp = LLVM::IndirectBrOp::create(builder, loc, *basePtr, succBlockArgsRange, succBlocks); mapNoResultOp(inst, indBrOp); @@ -2854,7 +2854,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { builder.setInsertionPointToEnd(mlirModule.getBody()); Location loc = debugImporter->translateFuncLocation(func); - LLVMFuncOp funcOp = builder.create( + LLVMFuncOp funcOp = LLVMFuncOp::create(builder, loc, func->getName(), functionType, convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv); @@ -3032,11 +3032,11 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, Operation *op = llvm::TypeSwitch(dbgIntr) .Case([&](llvm::DbgDeclareInst *) { - return builder.create( + return LLVM::DbgDeclareOp::create(builder, loc, *argOperand, localVariableAttr, locationExprAttr); }) .Case([&](llvm::DbgValueInst *) { - return builder.create( + return LLVM::DbgValueOp::create(builder, loc, *argOperand, localVariableAttr, locationExprAttr); }); mapNoResultOp(dbgIntr, op); @@ -3082,7 +3082,7 @@ LogicalResult 
ModuleImport::processBasicBlock(llvm::BasicBlock *bb, if (bb->hasAddressTaken()) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(block); - builder.create(block->getParentOp()->getLoc(), + BlockTagOp::create(builder, block->getParentOp()->getLoc(), BlockTagAttr::get(context, bb->getNumber())); } return success(); diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index 55d6a380d0bff..6e14b94e87407 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -42,22 +42,22 @@ static inline spirv::Opcode extractOpcode(uint32_t word) { Value spirv::Deserializer::getValue(uint32_t id) { if (auto constInfo = getConstant(id)) { // Materialize a `spirv.Constant` op at every use site. - return opBuilder.create(unknownLoc, constInfo->second, + return spirv::ConstantOp::create(opBuilder, unknownLoc, constInfo->second, constInfo->first); } if (auto varOp = getGlobalVariable(id)) { - auto addressOfOp = opBuilder.create( + auto addressOfOp = spirv::AddressOfOp::create(opBuilder, unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation())); return addressOfOp.getPointer(); } if (auto constOp = getSpecConstant(id)) { - auto referenceOfOp = opBuilder.create( + auto referenceOfOp = spirv::ReferenceOfOp::create(opBuilder, unknownLoc, constOp.getDefaultValue().getType(), SymbolRefAttr::get(constOp.getOperation())); return referenceOfOp.getReference(); } if (auto constCompositeOp = getSpecConstantComposite(id)) { - auto referenceOfOp = opBuilder.create( + auto referenceOfOp = spirv::ReferenceOfOp::create(opBuilder, unknownLoc, constCompositeOp.getType(), SymbolRefAttr::get(constCompositeOp.getOperation())); return referenceOfOp.getReference(); @@ -69,7 +69,7 @@ Value spirv::Deserializer::getValue(uint32_t id) { specConstOperationInfo->enclosedOpOperands); } if (auto undef = getUndefType(id)) { - 
return opBuilder.create(unknownLoc, undef); + return spirv::UndefOp::create(opBuilder, unknownLoc, undef); } return valueMap.lookup(id); } @@ -369,7 +369,7 @@ Deserializer::processOp(ArrayRef words) { interface.push_back(SymbolRefAttr::get(arg.getOperation())); wordIndex++; } - opBuilder.create( + spirv::EntryPointOp::create(opBuilder, unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), opBuilder.getArrayAttr(interface)); return success(); @@ -402,7 +402,7 @@ Deserializer::processOp(ArrayRef words) { attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); } auto values = opBuilder.getArrayAttr(attrListElems); - opBuilder.create( + spirv::ExecutionModeOp::create(opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode, values); return success(); @@ -441,7 +441,7 @@ Deserializer::processOp(ArrayRef operands) { arguments.push_back(value); } - auto opFunctionCall = opBuilder.create( + auto opFunctionCall = spirv::FunctionCallOp::create(opBuilder, unknownLoc, resultType, SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); @@ -518,7 +518,7 @@ Deserializer::processOp(ArrayRef words) { } Location loc = createFileLineColLoc(opBuilder); - opBuilder.create(loc, resultTypes, operands, attributes); + spirv::CopyMemoryOp::create(opBuilder, loc, resultTypes, operands, attributes); return success(); } @@ -549,7 +549,7 @@ LogicalResult Deserializer::processOp( operands.push_back(arg); Location loc = createFileLineColLoc(opBuilder); - Operation *op = opBuilder.create( + Operation *op = spirv::GenericCastToPtrExplicitOp::create(opBuilder, loc, resultTypes, operands); valueMap[valueID] = op->getResult(0); return success(); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index b1abd8b3dffe9..7b1f8055fadce 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ 
b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -518,7 +518,7 @@ spirv::Deserializer::processFunction(ArrayRef operands) { } std::string fnName = getFunctionSymbol(fnID); - auto funcOp = opBuilder.create( + auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName, functionType, fnControl.value()); // Processing other function attributes. if (decorations.count(fnID)) { @@ -706,7 +706,7 @@ spirv::SpecConstantOp spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue) { auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); - auto op = opBuilder.create(unknownLoc, symName, + auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName, defaultValue); if (decorations.count(resultID)) { for (auto attr : decorations[resultID].getAttrs()) @@ -782,7 +782,7 @@ spirv::Deserializer::processGlobalVariable(ArrayRef operands) { << wordIndex << " of " << operands.size() << " processed"; } auto loc = createFileLineColLoc(opBuilder); - auto varOp = opBuilder.create( + auto varOp = spirv::GlobalVariableOp::create(opBuilder, loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), initializer); @@ -1581,7 +1581,7 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef operands) { elements.push_back(SymbolRefAttr::get(elementInfo)); } - auto op = opBuilder.create( + auto op = spirv::SpecConstantCompositeOp::create(opBuilder, unknownLoc, TypeAttr::get(resultType), symName, opBuilder.getArrayAttr(elements)); specConstCompositeMap[resultID] = op; @@ -1656,7 +1656,7 @@ Value spirv::Deserializer::materializeSpecConstantOperation( auto loc = createFileLineColLoc(opBuilder); auto specConstOperationOp = - opBuilder.create(loc, resultType); + spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType); Region &body = specConstOperationOp.getBody(); // Move the new block into SpecConstantOperation's body. 
@@ -1669,7 +1669,7 @@ Value spirv::Deserializer::materializeSpecConstantOperation( OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); opBuilder.setInsertionPointToEnd(&block); - opBuilder.create(loc, block.front().getResult(0)); + spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0)); return specConstOperationOp.getResult(); } @@ -1733,7 +1733,7 @@ LogicalResult spirv::Deserializer::processBranch(ArrayRef operands) { // The preceding instruction for the OpBranch instruction could be an // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have // the same OpLine information. - opBuilder.create(loc, target); + spirv::BranchOp::create(opBuilder, loc, target); clearDebugLine(); return success(); @@ -1764,7 +1764,7 @@ spirv::Deserializer::processBranchConditional(ArrayRef operands) { // an OpSelectionMerge instruction, in this case they will have the same // OpLine information. auto loc = createFileLineColLoc(opBuilder); - opBuilder.create( + spirv::BranchConditionalOp::create(opBuilder, loc, condition, trueBlock, /*trueArguments=*/ArrayRef(), falseBlock, /*falseArguments=*/ArrayRef(), weights); @@ -1947,7 +1947,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { OpBuilder builder(&mergeBlock->front()); auto control = static_cast(selectionControl); - auto selectionOp = builder.create(location, control); + auto selectionOp = spirv::SelectionOp::create(builder, location, control); selectionOp.addMergeBlock(builder); return selectionOp; @@ -1959,7 +1959,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { OpBuilder builder(&mergeBlock->front()); auto control = static_cast(loopControl); - auto loopOp = builder.create(location, control); + auto loopOp = spirv::LoopOp::create(builder, location, control); loopOp.addEntryAndMergeBlock(builder); return loopOp; @@ -2092,7 +2092,7 @@ LogicalResult ControlFlowStructurizer::structurize() { // The loop entry block should have a 
unconditional branch jumping to the // loop header block. builder.setInsertionPointToEnd(&body.front()); - builder.create(location, mapper.lookupOrNull(headerBlock), + spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock), ArrayRef(blockArgs)); } @@ -2177,11 +2177,11 @@ LogicalResult ControlFlowStructurizer::structurize() { Operation *newOp = nullptr; if (isLoop) - newOp = builder.create( + newOp = spirv::LoopOp::create(builder, location, TypeRange(ValueRange(outsideUses)), static_cast(control)); else - newOp = builder.create( + newOp = spirv::SelectionOp::create(builder, location, TypeRange(ValueRange(outsideUses)), static_cast(control)); @@ -2308,7 +2308,7 @@ LogicalResult ControlFlowStructurizer::structurize() { // but replace all ops inside with a branch to the merge block. block->clear(); builder.setInsertionPointToEnd(block); - builder.create(location, mergeBlock); + spirv::BranchOp::create(builder, location, mergeBlock); } else { LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n"); block->erase(); @@ -2362,7 +2362,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { if (auto branchOp = dyn_cast(op)) { // Replace the previous branch op with a new one with block arguments. 
- opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), + spirv::BranchOp::create(opBuilder, branchOp.getLoc(), branchOp.getTarget(), blockArgs); branchOp.erase(); } else if (auto branchCondOp = dyn_cast(op)) { @@ -2370,13 +2370,13 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { branchCondOp.getFalseBlock() == target) && "expected target to be either the true or false target"); if (target == branchCondOp.getTrueTarget()) - opBuilder.create( + spirv::BranchConditionalOp::create(opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs, branchCondOp.getFalseBlockArguments(), branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(), branchCondOp.getFalseTarget()); else - opBuilder.create( + spirv::BranchConditionalOp::create(opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(), branchCondOp.getTrueBlockArguments(), blockArgs, branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(), @@ -2437,7 +2437,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) { Block *newBlock = block->splitBlock(terminator); OpBuilder builder(block, block->end()); - builder.create(block->getParent()->getLoc(), newBlock); + spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock); // After splitting we need to update the map to use the new block as a // header. diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index 824201d17b5ab..3f4e7762fed0e 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -134,7 +134,7 @@ class CodeGen { OwningOpRef CodeGen::generate(const ast::Module &module) { OwningOpRef mlirModule = - builder.create(genLoc(module.getLoc())); + ModuleOp::create(builder, genLoc(module.getLoc())); builder.setInsertionPointToStart(mlirModule->getBody()); // Generate code for each of the decls within the module. 
@@ -205,7 +205,7 @@ static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc) { if (isa(builder.getInsertionBlock()->getParentOp())) { pdl::RewriteOp rewrite = - builder.create(loc, rootExpr, /*name=*/StringAttr(), + pdl::RewriteOp::create(builder, loc, rootExpr, /*name=*/StringAttr(), /*externalArgs=*/ValueRange()); builder.createBlock(&rewrite.getBodyRegion()); } @@ -219,7 +219,7 @@ void CodeGen::genImpl(const ast::EraseStmt *stmt) { // Make sure we are nested in a RewriteOp. OpBuilder::InsertionGuard guard(builder); checkAndNestUnderRewriteOp(builder, rootExpr, loc); - builder.create(loc, rootExpr); + pdl::EraseOp::create(builder, loc, rootExpr); } void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); } @@ -242,7 +242,7 @@ void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { bool usesReplOperation = replValues.size() == 1 && isa(replValues.front().getType()); - builder.create( + pdl::ReplaceOp::create(builder, loc, rootExpr, usesReplOperation ? replValues[0] : Value(), usesReplOperation ? ValueRange() : ValueRange(replValues)); } @@ -283,7 +283,7 @@ void CodeGen::genImpl(const ast::PatternDecl *decl) { // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it // here. - pdl::PatternOp pattern = builder.create( + pdl::PatternOp pattern = pdl::PatternOp::create(builder, genLoc(decl->getLoc()), decl->getBenefit(), name ? 
std::optional(name->getName()) : std::optional()); @@ -338,19 +338,19 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, ast::Type type = varDecl->getType(); Type mlirType = genType(type); if (isa(type)) - return builder.create(loc, mlirType, getTypeConstraint()); + return pdl::OperandOp::create(builder, loc, mlirType, getTypeConstraint()); if (isa(type)) - return builder.create(loc, mlirType, /*type=*/TypeAttr()); + return pdl::TypeOp::create(builder, loc, mlirType, /*type=*/TypeAttr()); if (isa(type)) - return builder.create(loc, getTypeConstraint()); + return pdl::AttributeOp::create(builder, loc, getTypeConstraint()); if (ast::OperationType opType = dyn_cast(type)) { - Value operands = builder.create( + Value operands = pdl::OperandsOp::create(builder, loc, pdl::RangeType::get(builder.getType()), /*type=*/Value()); - Value results = builder.create( + Value results = pdl::TypesOp::create(builder, loc, pdl::RangeType::get(builder.getType()), /*types=*/ArrayAttr()); - return builder.create(loc, opType.getName(), operands, + return pdl::OperationOp::create(builder, loc, opType.getName(), operands, ArrayRef(), ValueRange(), results); } @@ -358,10 +358,10 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, if (ast::RangeType rangeTy = dyn_cast(type)) { ast::Type eleTy = rangeTy.getElementType(); if (isa(eleTy)) - return builder.create(loc, mlirType, + return pdl::OperandsOp::create(builder, loc, mlirType, getTypeConstraint()); if (isa(eleTy)) - return builder.create(loc, mlirType, /*types=*/ArrayAttr()); + return pdl::TypesOp::create(builder, loc, mlirType, /*types=*/ArrayAttr()); } llvm_unreachable("invalid non-initialized variable type"); @@ -404,7 +404,7 @@ SmallVector CodeGen::genExpr(const ast::Expr *expr) { Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { Attribute attr = parseAttribute(expr->getValue(), builder.getContext()); assert(attr && "invalid MLIR attribute data"); - return 
builder.create(genLoc(expr->getLoc()), attr); + return pdl::AttributeOp::create(builder, genLoc(expr->getLoc()), attr); } SmallVector CodeGen::genExprImpl(const ast::CallExpr *expr) { @@ -443,9 +443,9 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { if (isa(expr)) { Type mlirType = genType(expr->getType()); if (isa(mlirType)) - return builder.create(loc, mlirType, parentExprs[0], + return pdl::ResultOp::create(builder, loc, mlirType, parentExprs[0], builder.getI32IntegerAttr(0)); - return builder.create(loc, mlirType, parentExprs[0]); + return pdl::ResultsOp::create(builder, loc, mlirType, parentExprs[0]); } const ods::Operation *odsOp = opType.getODSOperation(); @@ -455,7 +455,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { unsigned resultIndex; name.getAsInteger(/*Radix=*/10, resultIndex); IntegerAttr index = builder.getI32IntegerAttr(resultIndex); - return builder.create(loc, genType(expr->getType()), + return pdl::ResultOp::create(builder, loc, genType(expr->getType()), parentExprs[0], index); } @@ -474,7 +474,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { // Generate the result access. 
IntegerAttr index = builder.getI32IntegerAttr(resultIndex); - return builder.create(loc, genType(expr->getType()), + return pdl::ResultsOp::create(builder, loc, genType(expr->getType()), parentExprs[0], index); } @@ -518,7 +518,7 @@ Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { for (const ast::Expr *result : expr->getResultTypes()) results.push_back(genSingleExpr(result)); - return builder.create(loc, opName, operands, attrNames, + return pdl::OperationOp::create(builder, loc, opName, operands, attrNames, attrValues, results); } @@ -527,7 +527,7 @@ Value CodeGen::genExprImpl(const ast::RangeExpr *expr) { for (const ast::Expr *element : expr->getElements()) llvm::append_range(elements, genExpr(element)); - return builder.create(genLoc(expr->getLoc()), + return pdl::RangeOp::create(builder, genLoc(expr->getLoc()), genType(expr->getType()), elements); } @@ -541,7 +541,7 @@ SmallVector CodeGen::genExprImpl(const ast::TupleExpr *expr) { Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { Type type = parseType(expr->getValue(), builder.getContext()); assert(type && "invalid MLIR type data"); - return builder.create(genLoc(expr->getLoc()), + return pdl::TypeOp::create(builder, genLoc(expr->getLoc()), builder.getType(), TypeAttr::get(type)); } @@ -586,7 +586,7 @@ CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, } else { resultTypes.push_back(genType(declResultType)); } - PDLOpT pdlOp = builder.create(loc, resultTypes, + PDLOpT pdlOp = PDLOpT::create(builder, loc, resultTypes, decl->getName().getName(), inputs); if (isNegated && std::is_same_v) cast(pdlOp).setIsNegated(true); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d118fe422f2f2..68b8335ed9262 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1505,7 +1505,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( OpBuilder 
builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = - builder.create(loc, outputTypes, inputs); + UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); if (!valuesToMap.empty()) mapping.map(valuesToMap, convertOp.getResults()); if (castOp) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 535f5e9b4a15d..73366110b104e 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1192,7 +1192,7 @@ def OpFuncRef : TEST_Op<"op_funcref"> { let description = [{ The "test.op_funcref" is a test op with a reference to a function symbol. }]; - let builders = [OpBuilder<(ins "::mlir::func::FuncOp":$function)>]; + // let builders = [OpBuilder<(ins "::mlir::func::FuncOp":$function)>]; } // Pattern add the argument plus a increasing static number hidden in diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index cdf44c2959d50..9753ac00af88a 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index ee3eb9522db7e..44f117e93da6b 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -19,6 +19,7 @@ 
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" #include "llvm/Support/Debug.h" diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp index 9a5632bb99c06..3d7b6be8a78e9 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.cpp +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/RegionUtils.h" diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp index a0ff31504c691..e0a76f1b75d38 100644 --- a/mlir/test/python/lib/PythonTestDialect.cpp +++ b/mlir/test/python/lib/PythonTestDialect.cpp @@ -9,6 +9,7 @@ #include "PythonTestDialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/TypeSwitch.h" #include "PythonTestDialect.cpp.inc" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index cbb4030f3adb4..f35cfa6826388 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -230,6 +230,18 @@ static const char *const opCommentHeader = R"( )"; +static const char *const inlineCreateBody = R"( + ::mlir::OperationState __state__({0}, getOperationName()); + build(builder, __state__{1}); + auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__)); + assert(__res__ && "builder didn't return the right type"); + return __res__; +)"; + +static const char *const 
inlineCreateBodyImplicitLoc = R"( + return create(builder, builder.getLoc(){0}); +)"; + //===----------------------------------------------------------------------===// // Utility structs and functions //===----------------------------------------------------------------------===// @@ -665,6 +677,7 @@ class OpEmitter { // Generates the build() method that takes each operand/attribute // as a stand-alone parameter. void genSeparateArgParamBuilder(); + void genInlineCreateBody(const SmallVector ¶mList); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first operand's @@ -2568,6 +2581,51 @@ static bool canInferType(const Operator &op) { return op.getTrait("::mlir::InferTypeOpInterface::Trait"); } +void OpEmitter::genInlineCreateBody( + const SmallVector ¶mList) { + SmallVector createParamListOpBuilder; + SmallVector createParamListImplicitLocOpBuilder; + SmallVector nonBuilderStateArgsList; + createParamListOpBuilder.emplace_back("::mlir::OpBuilder &", "builder"); + createParamListImplicitLocOpBuilder.emplace_back( + "::mlir::ImplicitLocOpBuilder &", "builder"); + std::string locParamName = "location"; + while (llvm::find_if(paramList, [&locParamName](const MethodParameter &p) { + return p.getName() == locParamName; + }) != paramList.end()) { + locParamName += "_"; + } + createParamListOpBuilder.emplace_back("::mlir::Location", locParamName); + + for (auto ¶m : paramList) { + if (param.getType() == "::mlir::OpBuilder &" || + param.getType() == "::mlir::OperationState &") + continue; + createParamListOpBuilder.emplace_back(param.getType(), param.getName(), + param.getDefaultValue(), + param.isOptional()); + createParamListImplicitLocOpBuilder.emplace_back( + param.getType(), param.getName(), param.getDefaultValue(), + param.isOptional()); + nonBuilderStateArgsList.push_back(param.getName()); + } + auto *cWithLoc = opClass.addStaticMethod(opClass.getClassName(), "create", + 
createParamListOpBuilder); + auto *cImplicitLoc = opClass.addStaticMethod( + opClass.getClassName(), "create", createParamListImplicitLocOpBuilder); + std::string nonBuilderStateArgs = ""; + if (!nonBuilderStateArgsList.empty()) { + llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs); + interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS); + nonBuilderStateArgs = ", " + nonBuilderStateArgs; + } + cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName, + nonBuilderStateArgs, + opClass.getClassName()); + cImplicitLoc->body() << llvm::formatv(inlineCreateBodyImplicitLoc, + nonBuilderStateArgs); +} + void OpEmitter::genSeparateArgParamBuilder() { SmallVector attrBuilderType; attrBuilderType.push_back(AttrParamKind::WrappedAttr); @@ -2584,10 +2642,12 @@ void OpEmitter::genSeparateArgParamBuilder() { buildParamList(paramList, inferredAttributes, resultNames, paramKind, attrType); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", paramList); // If the builder is redundant, skip generating the method. 
if (!m) return; + genInlineCreateBody(paramList); + auto &body = m->body(); genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, /*isRawValueAttr=*/attrType == @@ -2712,10 +2772,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder( if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", paramList); // If the builder is redundant, skip generating the method if (!m) return; + genInlineCreateBody(paramList); auto &body = m->body(); // Operands @@ -2826,10 +2887,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder( if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", paramList); // If the builder is redundant, skip generating the method if (!m) return; + genInlineCreateBody(paramList); auto &body = m->body(); int numResults = op.getNumResults(); @@ -2906,10 +2968,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { buildParamList(paramList, inferredAttributes, resultNames, TypeParamKind::None, attrType); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", paramList); // If the builder is redundant, skip generating the method if (!m) return; + genInlineCreateBody(paramList); auto &body = m->body(); genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, /*isRawValueAttr=*/attrType == @@ -2948,10 +3011,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder( : "attributes"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", attributesName, "{}"); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", paramList); // If the 
builder is redundant, skip generating the method if (!m) return; + genInlineCreateBody(paramList); auto &body = m->body(); @@ -3039,8 +3103,7 @@ void OpEmitter::genBuilder() { std::optional body = builder.getBody(); auto properties = body ? Method::Static : Method::StaticDeclaration; - auto *method = - opClass.addMethod("void", "build", properties, std::move(arguments)); + auto *method = opClass.addMethod("void", "build", properties, arguments); if (body) ERROR_IF_PRUNED(method, "build", op); @@ -3052,6 +3115,7 @@ void OpEmitter::genBuilder() { fctx.addSubst("_state", builderOpState); if (body) method->body() << tgfmt(*body, &fctx); + genInlineCreateBody(arguments); } // Generate default builders that requires all result type, operands, and @@ -3114,10 +3178,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) { if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); - auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", paramList); // If the builder is redundant, skip generating the method if (!m) return; + genInlineCreateBody(paramList); auto &body = m->body(); // Operands