
Commit bff2a99

[TOSA] Fix output size calculation for pool ops (#4125)
TOSA requires (inputDim + padBefore + padAfter - kernel) to be fully divisible by stride. This update modifies the pads and the input size for the pooling ops (AvgPool2d and MaxPool2d) so that they satisfy that TOSA requirement.

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
1 parent 5eae636 commit bff2a99
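
As a worked illustration of the divisibility requirement (the numbers below are chosen for exposition and do not come from the commit), consider a floor-mode pool with input height 7, kernel 2, stride 2, no padding, and dilation 1:

# Hypothetical numbers for illustration only; not from the commit.
inputDim, kernelDim, stride, padBefore, padAfter, dilation = 7, 2, 2, 0, 0, 1
dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1  # 5
print(dimSize % stride)  # 1, so TOSA would reject this pooling as-is

Floor mode never reads the last input row, so slicing the input to height 6 keeps PyTorch's output (5 // 2 + 1 == 3) while making (6 - 2) % 2 == 0. That slice-or-repad adjustment is exactly what the patch below performs.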

File tree

4 files changed: +597 −43 lines changed


lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 98 additions & 43 deletions
@@ -5766,19 +5766,69 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }
 
-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim to
+      // be fully divisible by stride. We may have to modify the after pad
+      // and/or the input in order to achieve that.
+      // Note: The dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value that
+      // TOSA supports for MaxPool2d (AvgPool2d doesn't have dilation, so the
+      // value defaults to 1).
       int64_t dimSize =
           inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, to meet
+      // TOSA's divisibility requirement we remove the unused after pad and
+      // slice off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          // In cases where we have to do 2 slice operations (one for height
+          // and one for width), we need to use the new sliced shape before
+          // doing the second slice, not the original inputShape. Therefore,
+          // the shape needs to be retrieved again here.
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
       int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, to meet
+      // TOSA's divisibility requirement we remove the unused after pad, or
+      // add more after pad when the remainder exceeds the after pad.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
       return outputDim;
     }
   }
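
For intuition, here is a rough standalone Python mirror of the adjustment logic added in the hunk above. This is an editorial sketch, not code from the commit: it omits the tosa::SliceOp that the real lowering emits when it trims the input, and it returns the mutated values instead of writing through references.

def adjusted_output_dim(input_dim, kernel_dim, stride, pad_before, pad_after,
                        dilation=1, ceil_mode=False):
    # Returns (output_dim, new_pad_after, new_input_dim) such that the padded
    # extent minus the effective kernel size is divisible by stride.
    dim_size = input_dim + pad_before + pad_after - dilation * (kernel_dim - 1) - 1
    remainder = dim_size % stride
    if not ceil_mode and remainder != 0:
        if remainder > pad_after:
            # Slice off input rows/columns that floor mode never reads
            # (the C++ emits a tosa::SliceOp at this point).
            input_dim -= remainder - pad_after
            dim_size -= pad_after
            pad_after = 0
        else:
            # Drop just enough after-pad to restore divisibility.
            dim_size -= remainder
            pad_after -= remainder
    output_dim = dim_size // stride + 1
    if ceil_mode and remainder != 0:
        if remainder < pad_after:
            pad_after -= remainder           # remove the unused after-pad
        else:
            pad_after += stride - remainder  # pad up to the next multiple
        if output_dim * stride < input_dim + pad_before:
            output_dim += 1
    return output_dim, pad_after, input_dim

# E.g. floor mode, input 7, kernel 2, stride 2, no pad:
# adjusted_output_dim(7, 2, 2, 0, 0) == (3, 0, 6), i.e. the input is sliced 7 -> 6.

Because padBefore and padAfter are now taken by reference and mutated inside getOutputDim, the callers below pass four independent pad entries (padArr[0..1] for height before/after, padArr[2..3] for width) instead of duplicating a symmetric pad, and the old post-hoc padArray recomputation is dropped.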
@@ -6016,25 +6066,24 @@ class ConvertAtenAdaptivePoolingOp
 
 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();
 
+  // PyTorch uses xCHW, so the Height dim index is rank - 2 and the Width dim
+  // index is rank - 1.
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
@@ -6065,7 +6114,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6138,10 +6187,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
@@ -6157,6 +6204,7 @@ class ConvertAtenMaxPool2dOp
                          DenseI64ArrayAttr &kernel,
                          DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                          Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6169,14 +6217,13 @@ class ConvertAtenMaxPool2dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6210,11 +6257,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6231,14 +6282,14 @@ class ConvertAtenMaxPool1dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
@@ -6254,6 +6305,7 @@ class ConvertAtenAvgPool2dOp
                          DenseI64ArrayAttr &kernel,
                          DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                          Type &outputTy) const override {
+    auto self = adaptor.getSelf();
 
     // Currently, we can not represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
@@ -6267,14 +6319,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6308,23 +6359,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
    rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 5 additions & 0 deletions
@@ -964,6 +964,8 @@
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
     "NativeGroupNormModule_basic",
+    "AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
+    "MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3300,6 +3302,9 @@
     "Aten_AssertScalar_basic",
     # JIT session error: Symbols not found: [ memrefCopy ]
     "SplitWithSizes_Module_basic",
+    # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
+    "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
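
The e2e test modules named in these xfail entries are defined in the two changed test files that this view does not render. As a hedged sketch of the kind of case they exercise (the class name, shapes, and pooling parameters below are assumptions for illustration, not the commit's actual test code), a pool whose padded extent is indivisible by the stride can be built like this:

import torch

class AvgPool2dCeilModeIndivisibleByStrideSketch(torch.nn.Module):
    # Hypothetical example: for a 56x56 input with kernel 3, stride 2, pad 1,
    # dimSize = 56 + 1 + 1 - 3 = 55 and 55 % 2 == 1, so the TOSA lowering has
    # to adjust the after-pad (ceil mode) or slice the input (floor mode).
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AvgPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=True)

    def forward(self, x):  # expects x of shape [N, C, 56, 56]
        return self.pool(x)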
