Commit d26c42a

Author: gysit

[mlir][linalg] Control dimensions to pad.

This revision supports padding only a subset of the iteration dimensions via an
additional padding-dimensions parameter. This control allows us to pad an
operation in multiple steps. For example, one may want to pad only the output
dimensions of a producer matmul fused into a consumer loop nest, before tiling
and padding its reduction dimension.

Depends On D122309

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D122560
1 parent a8c2770 commit d26c42a
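
As an illustration of the multi-step padding this commit enables, here is a minimal sketch (not part of the commit) that assumes the LinalgPaddingOptions API added below and a matmul with iteration dimensions (d0, d1, d2) = (M, N, K), where K is the reduction dimension; the helper name padInTwoSteps is illustrative only:

// A minimal sketch, assuming the setPaddingDimensions API added in this
// commit; padInTwoSteps is a hypothetical helper, not upstream code.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir::linalg;

void padInTwoSteps() {
  // Step 1: pad only the parallel output dimensions M and N of the fused
  // producer matmul.
  LinalgPaddingOptions outputPadding;
  outputPadding.setPaddingDimensions({0, 1});

  // Step 2: after tiling, pad the remaining reduction dimension K.
  LinalgPaddingOptions reductionPadding;
  reductionPadding.setPaddingDimensions({2});

  // Each options struct would drive one application of LinalgPaddingPattern;
  // pattern registration is elided here.
}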

8 files changed (+144, -79 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (13 additions, 5 deletions)

@@ -560,6 +560,12 @@ struct LinalgPaddingOptions {
     paddingValues.assign(pv.begin(), pv.end());
     return *this;
   }
+  /// A list of iterator dimensions to pad.
+  SmallVector<int64_t> paddingDimensions;
+  LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
+    paddingDimensions.assign(pd.begin(), pd.end());
+    return *this;
+  }
   /// A flag for every operand to mark the PadOp as nofold which enables packing
   /// for statically shaped operands.
   SmallVector<bool> packPaddings;

@@ -1217,13 +1223,15 @@ struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Pad the operands of `opToPad` to a static bounding box. Use `paddingValues`
-/// and `packPaddings` to set the padding value and the nofold attribute of the
-/// introduced tensor::PadOps, respectively. Update `paddedOp` to the cloned
-/// statically shaped operation and return the extracted dynamically shaped
-/// results. If padding fails, return failure.
+/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands to
+/// a static bounding box. Use `paddingValues` and `packPaddings` to set the
+/// padding value and the nofold attribute of the created tensor::PadOps,
+/// respectively. Update `paddedOp` to the cloned operation with statically
+/// shaped `paddingDimensions` and return the extracted dynamically shaped
+/// results. If padding fails, return failure.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
+                  ArrayRef<int64_t> paddingDimensions,
                   ArrayRef<Attribute> paddingValues,
                   ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
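
A hedged sketch of a call site for the updated rewriteAsPaddedOp signature follows; the builder b, the linalgOp, and the concrete option values are assumptions for illustration, not part of this commit:

// Illustrative only: pad all three matmul iteration dimensions and pack the
// two input operands. Assumes `b` is an OpBuilder and `linalgOp` a matmul.
SmallVector<int64_t> paddingDimensions = {0, 1, 2};
SmallVector<Attribute> paddingValues = {b.getF32FloatAttr(0.f),
                                        b.getF32FloatAttr(0.f),
                                        b.getF32FloatAttr(0.f)};
SmallVector<bool> packPaddings = {true, true, false};
LinalgOp paddedOp;
FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
    b, linalgOp, paddingDimensions, paddingValues, packPaddings, paddedOp);
if (failed(newResults))
  return failure(); // A padding dimension could not be bounded statically.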

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h (5 additions, 5 deletions)

@@ -15,11 +15,11 @@ namespace mlir {
 namespace tensor {
 
 // Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. The op performs
-// "high" padding (i.e. it adds trailing padding values until the desired
-// size is met).
-PadOp createPadHighOp(Type type, Value source, Value pad, bool nofold,
-                      Location loc, OpBuilder &builder);
+// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
+// dimensions the padding width is set to zero. The op performs "high" padding
+// (i.e. it adds trailing padding values until the desired size is met).
+PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
+                      bool nofold, Location loc, OpBuilder &builder);
 
 // Return a PadOp that pads `source` to `type` size with `pad` value.
 // I.e., a block will be created and the `pad` value will be yielded
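
To illustrate the relaxed contract, a hedged usage sketch; b, loc, and source are assumed to exist in the surrounding code and are not part of this commit:

// Sketch: pad a tensor<?x12xf32> source to tensor<?x16xf32>. Under the new
// contract the dynamic dimension 0 receives zero padding width; only the
// static dimension 1 is padded up to 16.
auto paddedType =
    RankedTensorType::get({ShapedType::kDynamicSize, 16}, b.getF32Type());
Value zero = b.create<arith::ConstantOp>(loc, b.getF32Type(),
                                         b.getF32FloatAttr(0.f));
tensor::PadOp padOp =
    tensor::createPadHighOp(paddedType, source, zero, /*nofold=*/false, loc, b);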

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (66 additions, 48 deletions)

@@ -158,29 +158,41 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
   return *this;
 }
 
-/// Pad `opOperand` using the provided `paddingValues`. Exit early for scalar
-/// operands, if `paddingValues` contains no value for the `opOperand`, or if
-/// `opOperand` is not defined by an ExtractSliceOp. Otherwise, try to pad the
-/// operand even if it already has a static shape. Set `result` to the result of
-/// the created tensor::PadOp or and return success if the operand either has
-/// been padded to a static shape or already had a static shape and failure
-/// otherwise.
-static LogicalResult padOperandToSmallestStaticBoundingBox(
+/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
+/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
+/// Exit early and return the `opOperand` value if the shape dimensions that
+/// match `paddingDimensions` have a static size and the nofold flag is not set.
+/// Otherwise, try to pad the shape dimensions that match the iterator
+/// dimensions `paddingDimensions` and return the tensor::PadOp result if
+/// padding succeeds or failure otherwise.
+static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
-    Value &result) {
-  // Get the shape of the operand and check if it has a dynamic shape. Only
-  // return failure if the operand is not a scalar and has a dynamic shape.
+    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
+    ArrayRef<bool> packPaddings) {
+  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
-  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
 
-  // Cannot pad scalar operands.
-  if (shape.empty())
-    return success();
+  // Collect the shape dimensions that are a function of the `paddingDimensions`.
+  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
+  for (int64_t dim : paddingDimensions)
+    for (const auto &en : enumerate(indexingMap.getResults()))
+      if (en.value().isFunctionOfDim(dim))
+        shapeDimsToPad.insert(en.index());
 
-  // Cannot pad if the padding value is unknown.
+  // Return the unpadded operand if padding to a static shape is not needed and
+  // if the nofold flag is not set.
+  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
+                    ? packPaddings[opOperand->getOperandNumber()]
+                    : false;
+  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
+    return ShapedType::isDynamic(shape[dim]);
+  });
+  if (!nofold && hasStaticShape)
+    return opOperand->get();
+
+  // Fail if `paddingValues` specifies no padding value.
   if (opOperand->getOperandNumber() >= paddingValues.size())
-    return failure(hasDynamicShape);
+    return failure();
   Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
   Value paddingValue = b.create<arith::ConstantOp>(
       opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

@@ -192,27 +204,31 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
     currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
   }
 
-  // Cannot construct a static bounding box if the `currOpOperand` is not
-  // defined by an ExtractSliceOp.
+  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
   auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
-    return failure(hasDynamicShape);
+    return failure();
 
   // Compute the dropped dimensions if `sliceOp` is rank-reducing.
   llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
 
   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
-  SmallVector<int64_t> staticSizes;
-  staticSizes.reserve(shape.size());
-  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
+  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
+  int64_t shapeIdx = 0;
   for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
     // Skip dropped dimensions.
     if (droppedDims.test(en.index()))
      continue;
-    // If the size is an attribute add it directly to `staticSizes`.
+    // Skip dimensions that do not require padding.
+    if (!shapeDimsToPad.contains(shapeIdx)) {
+      shapeIdx++;
+      continue;
+    }
+    // If the size is an attribute add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
-      staticSizes.push_back(
-          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
+      paddedShape[shapeIdx++] =
+          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
       continue;
     }
     // Otherwise, try to compute a constant upper bound for the size value.

@@ -222,24 +238,21 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
       return failure();
     }
-    staticSizes.push_back(upperBound.getValue());
+    paddedShape[shapeIdx++] = upperBound.getValue();
   }
-  assert(staticSizes.size() == shape.size() &&
+  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
          "expect the dynamic and static ranks to match");
 
-  // Pad the operand to the bounding box defined by `staticSizes`.
-  auto staticTensorType = RankedTensorType::get(
-      staticSizes, getElementTypeOrSelf(opOperand->get()));
-  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
-                    ? packPaddings[opOperand->getOperandNumber()]
-                    : false;
-  result = makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
-                                 opOperand->get(), paddingValue, nofold);
-  return success();
+  // Pad the operand to the bounding box defined by `paddedShape`.
+  auto paddedTensorType = RankedTensorType::get(
+      paddedShape, getElementTypeOrSelf(opOperand->get()));
+  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
+                               opOperand->get(), paddingValue, nofold);
 }
 
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
+                          ArrayRef<int64_t> paddingDimensions,
                           ArrayRef<Attribute> paddingValues,
                           ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
   Location loc = opToPad->getLoc();

@@ -255,13 +268,12 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
   SmallVector<Value> newOperands;
   newOperands.reserve(opToPad.getNumInputsAndOutputs());
   for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
-    Value paddedOperand;
-    // If padding was requested but the shape cannot be bounded statically then
-    // the pattern fails to apply.
-    if (failed(padOperandToSmallestStaticBoundingBox(
-            b, opToPad, opOperand, paddingValues, packPaddings, paddedOperand)))
+    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
+        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
+    // Exit if `paddingDimensions` cannot be bounded statically.
+    if (failed(paddedOperand))
       return failure();
-    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
+    newOperands.push_back(*paddedOperand);
   }
 
   SmallVector<SmallVector<Value>> reifiedResultShapes;

@@ -502,19 +514,25 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
   // Pad the operation.
   LinalgOp paddedOp;
   FailureOr<SmallVector<Value>> newResults =
-      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingValues,
-                        options.packPaddings, paddedOp);
+      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
+                        options.paddingValues, options.packPaddings, paddedOp);
   if (failed(newResults))
     return failure();
 
   // Hoist the padding.
   for (const auto &en : enumerate(options.hoistPaddings)) {
     if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
       break;
-    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
-    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
+    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
+    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
     if (!padOp || en.value() == 0)
       continue;
+
+    // Fail hoisting if the operand shape is not fully static.
+    if (llvm::any_of(paddedOp.getShape(opOperand),
+                     [](int64_t size) { return ShapedType::isDynamic(size); }))
+      return failure();
+
     tensor::PadOp hoistedOp;
     SmallVector<GenericOp> transposeOps;
     SmallVector<int64_t> transposeVector =
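
As a worked example of the shapeDimsToPad computation introduced above: matmul has indexing maps (d0, d1, d2) -> (d0, d2) for the LHS, (d2, d1) for the RHS, and (d0, d1) for the output, so padding only the reduction dimension d2 selects shape dimension 1 of the LHS, shape dimension 0 of the RHS, and no dimension of the output. The standalone sketch below restates that loop in isolation; the helper name collectShapeDimsToPad is hypothetical:

#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"

// For each padded iteration dimension, record every operand shape dimension
// whose indexing expression depends on it, mirroring the loop in
// padOperandToSmallestStaticBoundingBox.
llvm::SmallDenseSet<int64_t>
collectShapeDimsToPad(mlir::AffineMap indexingMap,
                      llvm::ArrayRef<int64_t> paddingDimensions) {
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : llvm::enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());
  return shapeDimsToPad;
}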

mlir/lib/Dialect/Linalg/Utils/Utils.cpp (0 additions, 2 deletions)

@@ -322,8 +322,6 @@ tensor::ExtractSliceOp makeComposedExtractSliceOp(
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold) {
-  assert(type.hasStaticShape() && "expect tensor type to have static shape");
-
   // Exit if `source` is not defined by an ExtractSliceOp.
   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)

mlir/lib/Dialect/Tensor/Utils/Utils.cpp (15 additions, 12 deletions)

@@ -25,8 +25,8 @@ PadOp mlir::tensor::createPadScalarOp(Type type, Value source, Value pad,
   auto padTensorOp =
       builder.create<PadOp>(loc, type, source, low, high, nofold);
   int rank = padTensorOp.getResultType().getRank();
-  SmallVector<Type, 4> blockArgTypes(rank, builder.getIndexType());
-  SmallVector<Location, 4> blockArgLocs(rank, loc);
+  SmallVector<Type> blockArgTypes(rank, builder.getIndexType());
+  SmallVector<Location> blockArgLocs(rank, loc);
   auto &region = padTensorOp.region();
   // `builder.createBlock` changes the insertion point within the block. Create
   // a guard to reset the insertion point of the builder after it is destroyed.

@@ -36,19 +36,22 @@ PadOp mlir::tensor::createPadScalarOp(Type type, Value source, Value pad,
   return padTensorOp;
 }
 
-PadOp mlir::tensor::createPadHighOp(Type type, Value source, Value pad,
-                                    bool nofold, Location loc, OpBuilder &b) {
-  SmallVector<OpFoldResult, 4> low, high;
-  auto rankedTensorType = type.cast<RankedTensorType>();
-  assert(rankedTensorType.hasStaticShape());
-  for (const auto &en : enumerate(rankedTensorType.getShape())) {
+PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
+                                    Value pad, bool nofold, Location loc,
+                                    OpBuilder &b) {
+  auto zero = b.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<OpFoldResult> low(type.getRank(), zero);
+  SmallVector<OpFoldResult> high(type.getRank(), zero);
+  for (const auto &en : enumerate(type.getShape())) {
+    // Pad only the static dimensions of the result tensor type.
+    if (ShapedType::isDynamic(en.value()))
+      continue;
+    // Compute the padding width.
     AffineExpr d0;
     bindDims(b.getContext(), d0);
     auto dimOp = b.createOrFold<tensor::DimOp>(loc, source, en.index());
-    Value paddingWidth =
-        makeComposedAffineApply(b, loc, en.value() - d0, {dimOp});
-    high.push_back(paddingWidth);
-    low.push_back(b.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    high[en.index()] =
+        makeComposedAffineApply(b, loc, en.value() - d0, {dimOp}).getResult();
   }
   return createPadScalarOp(type, source, pad, low, high, nofold, loc, b);
 }
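
For each padded static dimension, the loop above emits a high padding equal to the target size minus the dynamic extent, i.e. high[i] = shape[i] - dim(source, i), so the padded extent is exactly shape[i]. A hedged sketch of that step in isolation; b, loc, and source are assumed from context:

// Illustrative only: compute the high padding for dimension 1, padded to a
// target size of 16.
int64_t targetSize = 16;
AffineExpr d0;
bindDims(b.getContext(), d0);
Value dynamicExtent = b.createOrFold<tensor::DimOp>(loc, source, 1);
OpFoldResult highPad =
    makeComposedAffineApply(b, loc, targetSize - d0, {dynamicExtent})
        .getResult();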

mlir/test/Dialect/Linalg/codegen-strategy.mlir (3 additions, 3 deletions)

@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" -split-input-file | FileCheck %s --check-prefix=CHECK-INTRINSIC
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" -split-input-file | FileCheck %s --check-prefix=CHECK-OUTER
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 tile-interchange=1,2,0 generalize iterator-interchange=0,2,1" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad padding-values=0.:f32,0.:f32,0.:f32 pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad padding-values=0.:f32,0.:f32,0.:f32 vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad padding-values=0.:f32,0.:f32,0.:f32 decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad padding-values=0.:f32,0.:f32,0.:f32 padding-dimensions=0,1,2 pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad padding-values=0.:f32,0.:f32,0.:f32 padding-dimensions=0,1,2 vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad padding-values=0.:f32,0.:f32,0.:f32 padding-dimensions=0,1,2 decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP
 
 // CHECK-INTRINSIC: func @matmul(
 // CHECK-OUTER: func @matmul(
