
Commit 16b7e33

justin-ngo-arm authored and sahas3 committed
[TOSA] Fix output size calculation for pool ops (llvm#4125)
TOSA requires (inputDim + padBefore + padAfter - kernel) to be fully divisible by stride. This update modifies the padding and input sizes for the pooling ops (AvgPool2d and MaxPool2d) to satisfy that TOSA requirement.

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
1 parent 4d0e12c commit 16b7e33
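
To make the constraint concrete, here is a standalone floor-mode sketch of the adjustment this commit describes. It is illustrative only: fixFloorMode is not a function in the patch, dilation is assumed to be 1 (the only value TOSA supports for MaxPool2d), and the dimSize bookkeeping is simplified relative to the patch (the resulting output size is the same under floor division).

#include <cassert>
#include <cstdint>

// Illustrative helper (not from the patch): applies the floor-mode fix for
// one spatial dim. Returns the output size; padAfter and sliced report how
// the pad and input must change to satisfy TOSA's divisibility requirement.
int64_t fixFloorMode(int64_t inputDim, int64_t kernel, int64_t stride,
                     int64_t padBefore, int64_t &padAfter, int64_t &sliced) {
  int64_t dimSize = inputDim + padBefore + padAfter - kernel;
  int64_t rem = dimSize % stride;
  sliced = 0;
  if (rem != 0) {
    if (rem > padAfter) {
      sliced = rem - padAfter; // after-pad is not enough: slice the input too
      padAfter = 0;
    } else {
      padAfter -= rem; // dropping unused after-pad is sufficient
    }
    dimSize -= rem;
  }
  return dimSize / stride + 1;
}

int main() {
  // inputDim = 7, kernel = 2, stride = 2, no padding:
  // 7 - 2 = 5 and 5 % 2 = 1, so TOSA rejects it; slicing one unused input
  // row restores divisibility while keeping PyTorch's floor-mode output of 3.
  int64_t padAfter = 0, sliced = 0;
  assert(fixFloorMode(7, 2, 2, /*padBefore=*/0, padAfter, sliced) == 3);
  assert(sliced == 1 && padAfter == 0);
  return 0;
}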

File tree

4 files changed: +597 −43 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 98 additions & 43 deletions
@@ -5828,19 +5828,69 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }
 
-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim
+      // to be fully divisible by stride. We would have to modify the after
+      // pad and/or the input in order to achieve that.
+      // Note: The dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value that
+      // TOSA supports for MaxPool2d (AvgPool2d doesn't have dilation, so the
+      // value defaults to 1).
       int64_t dimSize =
           inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, to
+      // achieve TOSA's divisibility requirement, we remove the unused after
+      // pad and slice off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          // In cases where we have to do 2 slice operations (one for height
+          // and one for width), we need to use the new sliced shape before
+          // doing the second slice, not the original inputShape. Therefore,
+          // the shape needs to be retrieved again here.
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
       int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, to
+      // achieve TOSA's divisibility requirement, we remove the unused after
+      // pad, or add more after pad in case the remainder is more than the
+      // after pad.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
       return outputDim;
     }
   }
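
To see the ceil-mode branch in action, here is a standalone sketch with illustrative names (fixCeilMode is not in the patch; dilation is assumed to be 1, as TOSA requires for pooling):

#include <cassert>
#include <cstdint>

// Illustrative helper (not from the patch): adjusts padAfter so that
// (inputDim + padBefore + padAfter - kernel) % stride == 0, mirroring the
// ceil-mode logic in getOutputDim above.
int64_t fixCeilMode(int64_t inputDim, int64_t kernel, int64_t stride,
                    int64_t padBefore, int64_t &padAfter) {
  int64_t dimSize = inputDim + padBefore + padAfter - kernel;
  int64_t rem = dimSize % stride;
  int64_t outputDim = dimSize / stride + 1;
  if (rem != 0) {
    if (rem < padAfter)
      padAfter -= rem; // drop the unused after-pad
    else
      padAfter += stride - rem; // grow after-pad to the next multiple
    if (outputDim * stride < inputDim + padBefore)
      outputDim++; // PyTorch only counts windows starting in-bounds
  }
  return outputDim;
}

int main() {
  // inputDim = 6, kernel = 3, stride = 2, pad = 1 on both sides:
  // dimSize = 6 + 1 + 1 - 3 = 5 and 5 % 2 = 1, so padAfter grows from 1 to 2
  // and the ceil-mode output is 4, matching PyTorch's ceil((6 + 2 - 3)/2) + 1.
  int64_t padAfter = 1;
  assert(fixCeilMode(6, 3, 2, /*padBefore=*/1, padAfter) == 4);
  assert(padAfter == 2);
  return 0;
}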
@@ -6078,25 +6128,24 @@ class ConvertAtenAdaptivePoolingOp
 
 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();
 
+  // PyTorch uses xCHW, so Height dim index is rank-2 and Width dim index is
+  // rank-1
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
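
Note how the refactor changes the meaning of padArray: getOutputDim now receives the before and after pads of each spatial dim separately (by reference, so the after pad can be adjusted in place), instead of the same symmetric pad twice. A minimal sketch of the assumed 4-entry layout, matching TOSA's {top, bottom, left, right} pad ordering for 2-D pooling (illustrative, not from the patch):

#include <array>
#include <cstdint>

// Assumed layout of the 4-entry pad array used by the calls above:
//   padArr[0] = pad before H, padArr[1] = pad after H  (dimIndex = rank - 2)
//   padArr[2] = pad before W, padArr[3] = pad after W  (dimIndex = rank - 1)
using PoolPadArray = std::array<int64_t, 4>;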
@@ -6127,7 +6176,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6200,10 +6249,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
@@ -6219,6 +6266,7 @@ class ConvertAtenMaxPool2dOp
                            DenseI64ArrayAttr &kernel,
                            DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                            Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6231,14 +6279,13 @@ class ConvertAtenMaxPool2dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6272,11 +6319,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6293,14 +6344,14 @@ class ConvertAtenMaxPool1dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
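
Here reshapedSelf is now bound as a plain Value (via .getResult()) and passed by reference, so the parameter helper can swap in the sliced tensor when the divisibility fix applies. The unsqueeze itself just appends a unit dim so the 1-D pool can reuse the 2-D TOSA op; a minimal sketch of that shape change (illustrative, not from the patch):

#include <cstdint>
#include <vector>

// e.g. {N, C, W} -> {N, C, W, 1}: W takes the H slot of the 2-D pool and the
// trailing unit dim takes the W slot.
std::vector<int64_t> unsqueezeTo4d(std::vector<int64_t> shape) {
  shape.push_back(1);
  return shape;
}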
@@ -6316,6 +6367,7 @@ class ConvertAtenAvgPool2dOp
                            DenseI64ArrayAttr &kernel,
                            DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                            Type &outputTy) const override {
+    auto self = adaptor.getSelf();
 
     // Currently, we can not represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
@@ -6329,14 +6381,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
      return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6370,23 +6421,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 5 additions & 0 deletions

@@ -966,6 +966,8 @@
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
     "NativeGroupNormModule_basic",
+    "AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
+    "MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3303,6 +3305,9 @@
     "Aten_AssertScalar_basic",
     # JIT session error: Symbols not found: [ memrefCopy ]
     "SplitWithSizes_Module_basic",
+    # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
+    "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
