@@ -158,29 +158,41 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
158
158
return *this ;
159
159
}
160
160
161
- // / Pad `opOperand` using the provided `paddingValues`. Exit early for scalar
162
- // / operands, if `paddingValues` contains no value for the `opOperand `, or if
163
- // / `opOperand` is not defined by an ExtractSliceOp. Otherwise, try to pad the
164
- // / operand even if it already has a static shape. Set `result` to the result of
165
- // / the created tensor::PadOp or and return success if the operand either has
166
- // / been padded to a static shape or already had a static shape and failure
167
- // / otherwise.
168
- static LogicalResult padOperandToSmallestStaticBoundingBox (
161
+ // / Pad the `opOperand` in the `paddingDimensions` using the padding value and
162
+ // / the nofold flag found in `paddingValues` and `packPaddings `, respectively.
163
+ // / Exit early and return the `opOperand` value if the shape dimensions that
164
+ // / match `paddingDimensions` have a static size and the nofold flag is not set.
165
+ // / Otherwise, try to pad the shape dimensions that match the iterator
166
+ // / dimensions `paddingDimensions` and return the tensor::PadOp result if
167
+ // / padding succeeds or failure otherwise.
168
+ static FailureOr<Value> padOperandToSmallestStaticBoundingBox (
169
169
OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
170
- ArrayRef<Attribute> paddingValues, ArrayRef<bool > packPaddings,
171
- Value &result) {
172
- // Get the shape of the operand and check if it has a dynamic shape. Only
173
- // return failure if the operand is not a scalar and has a dynamic shape.
170
+ ArrayRef<int64_t > paddingDimensions, ArrayRef<Attribute> paddingValues,
171
+ ArrayRef<bool > packPaddings) {
172
+ AffineMap indexingMap = opToPad.getTiedIndexingMap (opOperand);
174
173
ArrayRef<int64_t > shape = opToPad.getShape (opOperand);
175
- bool hasDynamicShape = llvm::is_contained (shape, ShapedType::kDynamicSize );
176
174
177
- // Cannot pad scalar operands.
178
- if (shape.empty ())
179
- return success ();
175
+ // Collect the shape dimension that are a function of the `paddingDimensions`.
176
+ llvm::SmallDenseSet<int64_t > shapeDimsToPad;
177
+ for (int64_t dim : paddingDimensions)
178
+ for (const auto &en : enumerate(indexingMap.getResults ()))
179
+ if (en.value ().isFunctionOfDim (dim))
180
+ shapeDimsToPad.insert (en.index ());
180
181
181
- // Cannot pad if the padding value is unknown.
182
+ // Return the unpadded operand if padding to a static shape is not needed and
183
+ // if the nofold flag is not set.
184
+ bool nofold = opOperand->getOperandNumber () < packPaddings.size ()
185
+ ? packPaddings[opOperand->getOperandNumber ()]
186
+ : false ;
187
+ bool hasStaticShape = llvm::none_of (shapeDimsToPad, [&](int64_t dim) {
188
+ return ShapedType::isDynamic (shape[dim]);
189
+ });
190
+ if (!nofold && hasStaticShape)
191
+ return opOperand->get ();
192
+
193
+ // Fail if `paddingValues` specifies no padding value.
182
194
if (opOperand->getOperandNumber () >= paddingValues.size ())
183
- return failure (hasDynamicShape );
195
+ return failure ();
184
196
Attribute paddingAttr = paddingValues[opOperand->getOperandNumber ()];
185
197
Value paddingValue = b.create <arith::ConstantOp>(
186
198
opToPad.getLoc (), paddingAttr.getType (), paddingAttr);
@@ -192,27 +204,31 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
192
204
currOpOperand = linalgOp.getOutputOperand (result.getResultNumber ());
193
205
}
194
206
195
- // Cannot construct a static bounding box if the `currOpOperand` is not
196
- // defined by an ExtractSliceOp.
207
+ // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
197
208
auto sliceOp = currOpOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
198
209
if (!sliceOp)
199
- return failure (hasDynamicShape );
210
+ return failure ();
200
211
201
212
// Compute the dropped dimensions if `sliceOp` is ranke-reducing.
202
213
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims ();
214
+ OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
203
215
204
216
// Upper bound the `sliceOp` sizes to obtain a static bounding box.
205
- SmallVector<int64_t > staticSizes;
206
- staticSizes.reserve (shape.size ());
207
- auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation ());
217
+ SmallVector<int64_t > paddedShape (shape.begin (), shape.end ());
218
+ int64_t shapeIdx = 0 ;
208
219
for (const auto &en : enumerate(shapedOp.getMixedSizes ())) {
209
220
// Skip dropped dimensions.
210
221
if (droppedDims.test (en.index ()))
211
222
continue ;
212
- // If the size is an attribute add it directly to `staticSizes`.
223
+ // Skip dimensions that do not require padding.
224
+ if (!shapeDimsToPad.contains (shapeIdx)) {
225
+ shapeIdx++;
226
+ continue ;
227
+ }
228
+ // If the size is an attribute add it directly to `paddedShape`.
213
229
if (en.value ().is <Attribute>()) {
214
- staticSizes. push_back (
215
- en.value ().get <Attribute>().dyn_cast <IntegerAttr>().getInt ()) ;
230
+ paddedShape[shapeIdx++] =
231
+ en.value ().get <Attribute>().dyn_cast <IntegerAttr>().getInt ();
216
232
continue ;
217
233
}
218
234
// Otherwise, try to compute a constant upper bound for the size value.
@@ -222,24 +238,21 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
222
238
LLVM_DEBUG (DBGS () << " No constant bounding box can be found for padding" );
223
239
return failure ();
224
240
}
225
- staticSizes. push_back ( upperBound.getValue () );
241
+ paddedShape[shapeIdx++] = upperBound.getValue ();
226
242
}
227
- assert (staticSizes. size () == shape.size () &&
243
+ assert (shapeIdx == static_cast < int64_t >( shape.size () ) &&
228
244
" expect the dynamic and static ranks to match" );
229
245
230
- // Pad the operand to the bounding box defined by `staticSizes`.
231
- auto staticTensorType = RankedTensorType::get (
232
- staticSizes, getElementTypeOrSelf (opOperand->get ()));
233
- bool nofold = opOperand->getOperandNumber () < packPaddings.size ()
234
- ? packPaddings[opOperand->getOperandNumber ()]
235
- : false ;
236
- result = makeComposedPadHighOp (b, opToPad->getLoc (), staticTensorType,
237
- opOperand->get (), paddingValue, nofold);
238
- return success ();
246
+ // Pad the operand to the bounding box defined by `paddedShape`.
247
+ auto paddedTensorType = RankedTensorType::get (
248
+ paddedShape, getElementTypeOrSelf (opOperand->get ()));
249
+ return makeComposedPadHighOp (b, opToPad->getLoc (), paddedTensorType,
250
+ opOperand->get (), paddingValue, nofold);
239
251
}
240
252
241
253
FailureOr<SmallVector<Value>>
242
254
linalg::rewriteAsPaddedOp (OpBuilder &b, LinalgOp opToPad,
255
+ ArrayRef<int64_t > paddingDimensions,
243
256
ArrayRef<Attribute> paddingValues,
244
257
ArrayRef<bool > packPaddings, LinalgOp &paddedOp) {
245
258
Location loc = opToPad->getLoc ();
@@ -255,13 +268,12 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
255
268
SmallVector<Value> newOperands;
256
269
newOperands.reserve (opToPad.getNumInputsAndOutputs ());
257
270
for (OpOperand *opOperand : opToPad.getInputAndOutputOperands ()) {
258
- Value paddedOperand;
259
- // If padding was requested but the shape cannot be bounded statically then
260
- // the pattern fails to apply.
261
- if (failed (padOperandToSmallestStaticBoundingBox (
262
- b, opToPad, opOperand, paddingValues, packPaddings, paddedOperand)))
271
+ FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox (
272
+ b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
273
+ // Exit if `paddingDimensions` cannot be bounded statically.
274
+ if (failed (paddedOperand))
263
275
return failure ();
264
- newOperands.push_back (paddedOperand ? paddedOperand : opOperand-> get () );
276
+ newOperands.push_back (* paddedOperand);
265
277
}
266
278
267
279
SmallVector<SmallVector<Value>> reifiedResultShapes;
@@ -502,19 +514,25 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
502
514
// Pad the operation.
503
515
LinalgOp paddedOp;
504
516
FailureOr<SmallVector<Value>> newResults =
505
- rewriteAsPaddedOp (rewriter, linalgOp, options.paddingValues ,
506
- options.packPaddings , paddedOp);
517
+ rewriteAsPaddedOp (rewriter, linalgOp, options.paddingDimensions ,
518
+ options.paddingValues , options. packPaddings , paddedOp);
507
519
if (failed (newResults))
508
520
return failure ();
509
521
510
522
// Hoist the padding.
511
523
for (const auto &en : enumerate(options.hoistPaddings )) {
512
524
if (static_cast <int64_t >(en.index ()) >= paddedOp.getNumInputsAndOutputs ())
513
525
break ;
514
- OpOperand & opOperand = paddedOp->getOpOperand (en.index ());
515
- auto padOp = opOperand. get ().getDefiningOp <tensor::PadOp>();
526
+ OpOperand * opOperand = & paddedOp->getOpOperand (en.index ());
527
+ auto padOp = opOperand-> get ().getDefiningOp <tensor::PadOp>();
516
528
if (!padOp || en.value () == 0 )
517
529
continue ;
530
+
531
+ // Fail hoisting if the operand shape is not fully static.
532
+ if (llvm::any_of (paddedOp.getShape (opOperand),
533
+ [](int64_t size) { return ShapedType::isDynamic (size); }))
534
+ return failure ();
535
+
518
536
tensor::PadOp hoistedOp;
519
537
SmallVector<GenericOp> transposeOps;
520
538
SmallVector<int64_t > transposeVector =
0 commit comments