Commit 58d0da8

Author: gysit
[mlir][linalg] Use arrays to pass padding options.
Pass the padding options using arrays instead of lambdas. In particular, pass the padding value as a string and use the argument parser to create the padding value. Arrays are a more natural choice that matches the current use cases and avoids converting arrays to lambdas.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D122309
1 parent: 4df69c1

File tree: 7 files changed, +110 -161 lines changed
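For context, a minimal sketch of configuring the new array-based options in C++, mirroring the padding-values=0.:f32,0.:f32,0.:f32, pack-paddings=1,1,0, and hoist-paddings=3,3,0 flags exercised by the updated test below. The helper name and the three-operand matmul layout are assumptions, not part of this commit:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Pad all three matmul operands with 0.0 : f32, mark the two input pads
// as nofold so they can be packed, and hoist them out of three loops.
static linalg::LinalgPaddingOptions makePaddingOptions(OpBuilder &b) {
  SmallVector<Attribute> zeros(3, b.getF32FloatAttr(0.f));
  linalg::LinalgPaddingOptions options;
  options.setPaddingValues(zeros)
      .setPackPaddings({true, true, false})
      .setHoistPaddings({3, 3, 0});
  return options;
}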

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
23 additions, 61 deletions

@@ -553,69 +553,32 @@ void transformIndexOps(RewriterBase &b, LinalgOp op,
                        SmallVectorImpl<Value> &ivs,
                        const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
 
-/// Callback returning the padding value to use for a given OpOperand or failure
-/// for no padding. This should be a function of both the operation and the
-/// operand type.
-using PaddingValueComputationFunction =
-    std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
-
-/// Callback returning true if the PadOp defining the given OpOperand shall be
-/// marked as nofold to enable packing.
-using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
-
-/// Callback returning the number of loops to hoist the PadOp defining the given
-/// OpOperand.
-using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;
-
-/// Callback returning the transpose vector used to permute the result tensor
-/// dimensions of the PadOp defining the given OpOperand.
-using PaddingTransposeComputationFunction =
-    std::function<SmallVector<int64_t>(OpOperand &)>;
-
 struct LinalgPaddingOptions {
-  /// Callback returning the padding value to use for a given OpOperand or
-  /// failure for no padding. Padding operations are introduced if
-  /// `paddingValueComputationFunction` is set and does not return failure.
-  /// Padding all operands guarantees the operation is statically shaped and
-  /// thus can be vectorized.
-  PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
-
-  LinalgPaddingOptions &
-  setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
-    paddingValueComputationFunction = std::move(fun);
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
     return *this;
   }
-
-  /// Callback returning true if the PadOp defining the given OpOperand shall be
-  /// marked as nofold to enable packing. A padding operation is only marked
-  /// nofold if `paddingNoFoldComputationFunction` is set and returns true.
-  /// Otherwise, the nofold attribute is set to false.
-  PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
-
-  LinalgPaddingOptions &
-  setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
-    paddingNoFoldComputationFunction = std::move(fun);
+  /// A flag for every operand to mark the PadOp as nofold which enables packing
+  /// for statically shaped operands.
+  SmallVector<bool> packPaddings;
+  LinalgPaddingOptions &setPackPaddings(ArrayRef<bool> pp) {
+    packPaddings.assign(pp.begin(), pp.end());
     return *this;
   }
-
-  /// Callback returning the number of loops to hoist the PadOp defining the
-  /// given OpOperand.
-  PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;
-
-  LinalgPaddingOptions &
-  setPaddingHoistComputationFunction(PaddingHoistComputationFunction fun) {
-    paddingHoistComputationFunction = std::move(fun);
+  /// A number of loops to hoist the PadOp out for every operand.
+  SmallVector<int64_t> hoistPaddings;
+  LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
+    hoistPaddings.assign(hp.begin(), hp.end());
     return *this;
   }
-
-  /// Callback returning the transpose vector used to permute the result tensor
-  /// dimensions of the PadOp defining the given OpOperand.
-  PaddingTransposeComputationFunction paddingTransposeComputationFunction =
-      nullptr;
-
-  LinalgPaddingOptions &setPaddingTransposeComputationFunction(
-      PaddingTransposeComputationFunction fun) {
-    paddingTransposeComputationFunction = std::move(fun);
+  /// A permutation vector for every operand used to transpose the packed PadOp
+  /// results.
+  SmallVector<SmallVector<int64_t>> transposePaddings;
+  LinalgPaddingOptions &
+  setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
+    transposePaddings.assign(tp.begin(), tp.end());
     return *this;
   }
 };

@@ -1254,16 +1217,15 @@ struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
                                  PatternRewriter &rewriter) const override;
 };
 
-/// Pad the operands of `opToPad` to a static bounding box. Use `paddingFunc`
-/// and `nofoldFunc` to set the padding value and the nofold attribute of the
+/// Pad the operands of `opToPad` to a static bounding box. Use `paddingValues`
+/// and `packPaddings` to set the padding value and the nofold attribute of the
 /// introduced tensor::PadOps, respectively. Update `paddedOp` to the cloned
 /// statically shaped operation and return the extracted dynamically shaped
 /// results. If padding fails, return failure.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
-                  const PaddingValueComputationFunction &paddingFunc,
-                  const PaddingNoFoldComputationFunction &nofoldFunc,
-                  LinalgOp &paddedOp);
+                  ArrayRef<Attribute> paddingValues,
+                  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
 
 using OptimizeCopyFn =
     std::function<LogicalResult(PatternRewriter &, tensor::PadOp, Value)>;
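The transpose option follows the same per-operand convention. Continuing the hypothetical sketch above, one could transpose the packed pad of the first operand and keep the others in their original layout; the permutations chosen here are illustrative assumptions:

// Permutation per operand: transpose the packed PadOp result of operand 0
// with (1, 0); (0, 1) is the identity for the two remaining operands.
SmallVector<SmallVector<int64_t>> transposes = {{1, 0}, {0, 1}, {0, 1}};
options.setTransposePaddings(transposes);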

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
30 additions, 30 deletions

@@ -158,16 +158,17 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
   return *this;
 }
 
-/// Helper function that tries to pad `opOperand`. Exit early for scalar
-/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
-/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
-/// has a static shape. Set `result` to the result of the created tensor::PadOp
-/// or and return success if the operand either has been padded to a static
-/// shape or already had a static shape and failure otherwise.
+/// Pad `opOperand` using the provided `paddingValues`. Exit early for scalar
+/// operands, if `paddingValues` contains no value for the `opOperand`, or if
+/// `opOperand` is not defined by an ExtractSliceOp. Otherwise, try to pad the
+/// operand even if it already has a static shape. Set `result` to the result of
+/// the created tensor::PadOp or and return success if the operand either has
+/// been padded to a static shape or already had a static shape and failure
+/// otherwise.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    const PaddingValueComputationFunction &paddingFunc,
-    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
+    Value &result) {
   // Get the shape of the operand and check if it has a dynamic shape. Only
   // return failure if the operand is not a scalar and has a dynamic shape.
   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

@@ -178,9 +179,11 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
     return success();
 
   // Cannot pad if the padding value is unknown.
-  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
-  if (failed(paddingValue))
+  if (opOperand->getOperandNumber() >= paddingValues.size())
     return failure(hasDynamicShape);
+  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
+  Value paddingValue = b.create<arith::ConstantOp>(
+      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);
 
   // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
   OpOperand *currOpOperand = opOperand;

@@ -227,18 +230,18 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   // Pad the operand to the bounding box defined by `staticSizes`.
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
-  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
-  result =
-      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
-                            opOperand->get(), paddingValue.getValue(), nofold);
+  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
+                    ? packPaddings[opOperand->getOperandNumber()]
+                    : false;
+  result = makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
+                                 opOperand->get(), paddingValue, nofold);
   return success();
 }
 
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
-                          const PaddingValueComputationFunction &paddingFunc,
-                          const PaddingNoFoldComputationFunction &nofoldFunc,
-                          LinalgOp &paddedOp) {
+                          ArrayRef<Attribute> paddingValues,
+                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
   Location loc = opToPad->getLoc();
 
   // TODO: there are cases where we may still want to pad to larger sizes.

@@ -256,7 +259,7 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
     if (failed(padOperandToSmallestStaticBoundingBox(
-            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
+            b, opToPad, opOperand, paddingValues, packPaddings, paddedOperand)))
       return failure();
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }

@@ -498,29 +501,26 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
 
   // Pad the operation.
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
-      rewriter, linalgOp, options.paddingValueComputationFunction,
-      options.paddingNoFoldComputationFunction, paddedOp);
+  FailureOr<SmallVector<Value>> newResults =
+      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingValues,
+                        options.packPaddings, paddedOp);
   if (failed(newResults))
     return failure();
 
-  // Compute the desired hoisting depths.
-  SmallVector<int64_t> depths;
-  if (options.paddingHoistComputationFunction) {
-    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
-      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
-  }
-
   // Hoist the padding.
-  for (const auto &en : enumerate(depths)) {
+  for (const auto &en : enumerate(options.hoistPaddings)) {
+    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
+      break;
     OpOperand &opOperand = paddedOp->getOpOperand(en.index());
     auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
     if (!padOp || en.value() == 0)
       continue;
     tensor::PadOp hoistedOp;
     SmallVector<GenericOp> transposeOps;
     SmallVector<int64_t> transposeVector =
-        options.paddingTransposeComputationFunction(opOperand);
+        en.index() < options.transposePaddings.size()
+            ? options.transposePaddings[en.index()]
+            : SmallVector<int64_t>{};
 
     FailureOr<Value> newResult = hoistPaddingOnTensors(
         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
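The commit message notes that the padding value now travels as a string (for example "0.:f32") and is materialized through the argument parser; the pass plumbing that does this is among the seven changed files but not shown in the diffs above. A minimal sketch of the idea, assuming mlir::parseAttribute (whose header location varies across MLIR versions) and a hypothetical helper name:

#include <string>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h" // provides parseAttribute in this MLIR era

using namespace mlir;

// Turn pass-option strings such as "0.:f32" into typed attributes suitable
// for LinalgPaddingOptions::paddingValues.
static SmallVector<Attribute> parsePaddingValues(ArrayRef<std::string> strs,
                                                 MLIRContext *context) {
  SmallVector<Attribute> attributes;
  for (const std::string &str : strs)
    // parseAttribute returns a null Attribute on syntax errors; a real
    // implementation would diagnose that instead of storing it.
    attributes.push_back(parseAttribute(str, context));
  return attributes;
}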

mlir/test/Dialect/Linalg/codegen-strategy.mlir
13 additions, 13 deletions

@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" -split-input-file | FileCheck %s --check-prefix=CHECK-INTRINSIC
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" -split-input-file | FileCheck %s --check-prefix=CHECK-OUTER
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 tile-interchange=1,2,0 generalize iterator-interchange=0,2,1" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad padding-values=0.:f32,0.:f32,0.:f32 pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad padding-values=0.:f32,0.:f32,0.:f32 vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad padding-values=0.:f32,0.:f32,0.:f32 decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP
 
 // CHECK-INTRINSIC: func @matmul(
 // CHECK-OUTER: func @matmul(
-func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<72x72xf32>) {
+func.func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<72x72xf32>) {
 
   // Check the matrix intrinsic lowering is triggered.
   // CHECK-INTRINSIC: vector.matrix_multiply

@@ -17,13 +17,13 @@ func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<7
   // Check the outer product lowering is triggered.
   // CHECK-OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
   linalg.matmul ins(%arg0, %arg1: memref<72x72xf32>, memref<72x72xf32>) outs(%arg2: memref<72x72xf32>)
-  return
+  func.return
 }
 
 // -----
 
 // CHECK-INTERCHANGE: func @matmul(
-func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+func.func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
   // CHECK-INTERCHANGE-DAG: %[[C16:.*]] = arith.constant 16
   // CHECK-INTERCHANGE-DAG: %[[C32:.*]] = arith.constant 32
   // CHECK-INTERCHANGE-DAG: %[[C64:.*]] = arith.constant 64

@@ -37,15 +37,15 @@ func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<7
   // CHECK-INTERCHANGE: linalg.generic
   // CHECK-INTERCHANGE-SAME: iterator_types = ["parallel", "reduction", "parallel"]
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
-  return %0 : tensor<72x72xf32>
+  func.return %0 : tensor<72x72xf32>
 }
 
 // -----
 
 // CHECK-PAD-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (-d0 + 72, 16)>
 
 // CHECK-PAD: func @matmul(
-func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+func.func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
 
   // Check the padding of the input operands has been hoisted out of the tile loop nest.
   // CHECK-PAD-COUNT=2: tensor.pad %{{.*}} nofold

@@ -56,13 +56,13 @@ func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<7
   // CHECK-PAD-COUNT=2: scf.for
   // CHECK-PAD: linalg.matmul
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
-  return %0 : tensor<72x72xf32>
+  func.return %0 : tensor<72x72xf32>
 }
 
 // -----
 
 // CHECK-FUSE: func @matmul(
-func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+func.func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
 
   // Check the padding and vectorization applies to the fill operation due to the empty anchor op string.
   // CHECK-FUSE: %[[CST:.*]] = arith.constant dense<0.000000e+00>

@@ -73,13 +73,13 @@ func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<7
   // Check the matmul is padded and vectorized despite the empty anchor op string.
   // CHECK-FUSE: vector.outerproduct
   %1 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%0: tensor<72x72xf32>) -> tensor<72x72xf32>
-  return %1 : tensor<72x72xf32>
+  func.return %1 : tensor<72x72xf32>
 }
 
 // -----
 
 // CHECK-DECOMP: func @conv(
-func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> {
+func.func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32>

@@ -88,5 +88,5 @@ func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: t
   // CHECK-DECOMP: vector.outerproduct
   // CHECK-DECOMP: vector.transfer_write {{.*}}: vector<1x8x32xf32>, tensor<1x1x?x32xf32>
   %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<8x18x17x32xf32>, tensor<3x3x32x64xf32>) outs(%0 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32>
-  return %1 : tensor<8x16x15x64xf32>
+  func.return %1 : tensor<8x16x15x64xf32>
 }
