@@ -5766,19 +5766,69 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }
 
-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim
+      // to be fully divisible by stride. We would have to modify the after
+      // pad and/or the input in order to achieve that.
+      // Note: The dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value that
+      // TOSA supports for MaxPool2d (AvgPool2d doesn't have dilation, so the
+      // value defaults to 1).
       int64_t dimSize =
           inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, to meet
+      // TOSA's divisibility requirement we remove the unused after pad and
+      // slice off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          // In cases where we have to do 2 slice operations (one for height
+          // and one for width), we need to use the new sliced shape before
+          // doing the second slice, not the original inputShape. Therefore,
+          // the shape needs to be retrieved again here.
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
       int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, to meet
+      // TOSA's divisibility requirement we remove the unused after pad, or
+      // add more after pad when the remainder exceeds the after pad.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
       return outputDim;
     }
   }
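
To make the constraint concrete: for one pooled dimension with dilation 1, PyTorch's floor mode computes `floor((in + padBefore + padAfter - kernel) / stride) + 1`, while TOSA additionally wants `(in + padBefore + padAfter - kernel)` to divide evenly by `stride`. A standalone sketch of the numbers involved (illustrative only, not part of the patch; `torchFloorOutputDim` is a made-up helper name):

```cpp
#include <cassert>
#include <cstdint>

// PyTorch's floor-mode output size for one pooled dimension (dilation = 1).
static int64_t torchFloorOutputDim(int64_t in, int64_t kernel, int64_t stride,
                                   int64_t padBefore, int64_t padAfter) {
  // Integer division already floors for the non-negative values used here.
  return (in + padBefore + padAfter - kernel) / stride + 1;
}

int main() {
  // in = 10, kernel = 3, stride = 3, no padding: PyTorch output is
  // floor((10 - 3) / 3) + 1 == 3, but (10 - 3) % 3 == 1, which TOSA rejects.
  assert(torchFloorOutputDim(10, 3, 3, 0, 0) == 3);
  // Slicing one trailing row (the unused input) restores divisibility
  // without changing the output size: (9 - 3) % 3 == 0.
  assert((9 - 3) % 3 == 0 && torchFloorOutputDim(9, 3, 3, 0, 0) == 3);
  return 0;
}
```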
@@ -6016,25 +6066,24 @@ class ConvertAtenAdaptivePoolingOp
 
 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();
 
+  // PyTorch uses xCHW, so the height dim index is rank - 2 and the width dim
+  // index is rank - 1.
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
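
Passing `padBefore`/`padAfter` by reference is what lets the two calls above adjust the four pad entries independently, replacing the deleted `padArray` recomputation. In ceil mode the adjustment can also grow the after pad; a self-contained sketch of that arithmetic with made-up values:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // One dimension in ceil mode: in = 10, kernel = 3, stride = 3, no padding.
  int64_t in = 10, kernel = 3, stride = 3, padBefore = 0, padAfter = 0;
  int64_t dimSize = in + padBefore + padAfter - kernel; // 7 (dilation = 1)
  int64_t remainder = dimSize % stride;                 // 7 % 3 == 1
  // remainder >= padAfter, so grow the after pad until the extent divides
  // evenly by the stride, mirroring the ceil-mode branch in getOutputDim.
  padAfter += stride - remainder; // padAfter == 2
  assert((in + padBefore + padAfter - kernel) % stride == 0);
  // PyTorch ceil mode yields ceil(7 / 3.0) + 1 == 4 windows; the repadded
  // extent agrees: (9 / 3) + 1 == 4.
  assert((in + padBefore + padAfter - kernel) / stride + 1 == 4);
  return 0;
}
```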
@@ -6065,7 +6114,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6138,10 +6187,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
@@ -6157,6 +6204,7 @@ class ConvertAtenMaxPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6169,14 +6217,13 @@ class ConvertAtenMaxPool2dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6210,11 +6257,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6231,14 +6282,14 @@ class ConvertAtenMaxPool1dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
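
Both 1-D pooling patterns reuse the 2-D TOSA ops by appending a unit dimension, so the reshape above turns (N, C, L) into (N, C, L, 1) and the original length is pooled as the height. A minimal sketch of just that shape step (illustrative, outside the pattern; `toRank4` is a made-up name):

```cpp
#include <cstdint>
#include <vector>

// Append a trailing unit dim so 1-D pooling can go through tosa::MaxPool2dOp:
// the 1-D length is pooled as the height and the unit dim is the width.
std::vector<int64_t> toRank4(std::vector<int64_t> selfShape) {
  selfShape.push_back(1);
  return selfShape;
}

int main() {
  auto shape = toRank4({2, 3, 16}); // (N=2, C=3, L=16) -> (2, 3, 16, 1)
  return (shape.size() == 4 && shape.back() == 1) ? 0 : 1;
}
```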
@@ -6254,6 +6305,7 @@ class ConvertAtenAvgPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
 
     // Currently, we can not represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
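
For background on the bail-out described in the comment above: in PyTorch, `avg_pool2d`'s `divisor_override` replaces the usual averaging divisor with a caller-supplied constant, while TOSA's AvgPool2d fixes its own divisor, so a non-`None` override has no faithful lowering. A tiny numeric illustration (made-up values):

```cpp
#include <cassert>

int main() {
  // A 2x2 window holding {1, 1, 2, 2} sums to 6.
  double windowSum = 6.0;
  double regular = windowSum / 4.0;    // default divisor (window size): 1.5
  double overridden = windowSum / 2.0; // divisor_override = 2: 3.0
  // A lowering that ignored the override would silently return 1.5 instead
  // of 3.0, hence the check described above.
  assert(regular != overridden);
  return 0;
}
```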
@@ -6267,14 +6319,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6308,23 +6359,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }