@@ -5828,19 +5828,69 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }

-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim
+      // to be fully divisible by stride. We would have to modify the after
+      // pad and/or the input in order to achieve that.
+      // Note: The dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value
+      // TOSA supports for MaxPool2d (AvgPool2d doesn't have dilation, so
+      // the value defaults to 1).
      int64_t dimSize =
          inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, we
+      // meet TOSA's divisibility requirement by removing the unused after
+      // pad and slicing off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          // In cases where we have to do 2 slice operations (one for height
+          // and one for width), we need to use the new sliced shape before
+          // doing the second slice, not the original inputShape. Therefore,
+          // the shape needs to be retrieved again here.
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
      int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, we meet
+      // TOSA's divisibility requirement by removing the unused after pad,
+      // or by adding more after pad when the remainder exceeds it.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
      return outputDim;
    }
  }
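
Note: the remainder handling above can be hard to follow in diff form. Below is a minimal standalone sketch (illustration only; floorOrCeilOutputDim is a hypothetical helper, not part of the patch) that reproduces the floor- and ceil-mode arithmetic under the stated dilation = 1 assumption, omitting the tosa::SliceOp side effect:

    #include <cassert>
    #include <cstdint>

    // Illustration-only distillation of getOutputDim's remainder handling,
    // assuming dilation = 1 and omitting the tosa::SliceOp side effect.
    static int64_t floorOrCeilOutputDim(int64_t inputDim, int64_t kernelDim,
                                        int64_t stride, int64_t padBefore,
                                        int64_t padAfter, bool ceilMode) {
      int64_t dimSize = inputDim + padBefore + padAfter - (kernelDim - 1) - 1;
      int64_t remainderDim = dimSize % stride;
      if (!ceilMode && remainderDim != 0) {
        // Floor mode: drop the unused after pad; the real pattern also
        // slices remainderDim - padAfter trailing rows/columns off the input.
        dimSize -= (remainderDim > padAfter) ? padAfter : remainderDim;
      }
      int64_t outputDim = dimSize / stride + 1;
      if (ceilMode && remainderDim != 0 &&
          outputDim * stride < inputDim + padBefore)
        outputDim++;
      return outputDim;
    }

    int main() {
      // inputDim = 7, kernel = 2, stride = 2, no padding.
      // Floor mode: dimSize = 5, remainder = 1, one input row gets sliced;
      // 5 / 2 + 1 = 3 matches PyTorch's floor((7 - 2) / 2) + 1 = 3.
      assert(floorOrCeilOutputDim(7, 2, 2, 0, 0, false) == 3);
      // Ceil mode: padAfter grows by stride - remainder = 1, and since
      // 3 * 2 = 6 < 7 the output dim bumps to 4, matching
      // ceil((7 - 2) / 2.0) + 1 = 4.
      assert(floorOrCeilOutputDim(7, 2, 2, 0, 0, true) == 4);
      return 0;
    }

Both asserts agree with PyTorch's output-size formulas, which is the invariant the patch preserves while also satisfying TOSA's divisibility constraint.
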
@@ -6078,25 +6128,24 @@ class ConvertAtenAdaptivePoolingOp

 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();

+  // PyTorch uses xCHW, so the height dim index is rank - 2 and the width dim
+  // index is rank - 1.
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
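
Note: the re-indexing above (padArray[0]/[1] for height, padArray[2]/[3] for width) assumes padArray now carries four entries in TOSA's {top, bottom, left, right} order, matching the {padArr[0], padArr[1], padArr[2], padArr[3]} attribute assembled in a later hunk; the before/after pads can then diverge per dimension once getOutputDim adjusts them by reference. A tiny sketch of the assumed layout (hypothetical values):

    #include <array>
    #include <cstdint>

    int main() {
      // Hypothetical values; only the index mapping matters here.
      std::array<int64_t, 4> padArray{/*top=*/1, /*bottom=*/1,
                                      /*left=*/2, /*right=*/2};
      // Height (dimIndex = rank - 2) passes padArray[0] / padArray[1] as
      // padBefore / padAfter; width (dimIndex = rank - 1) passes
      // padArray[2] / padArray[3], matching the two calls above.
      (void)padArray;
      return 0;
    }
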
@@ -6127,7 +6176,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6200,10 +6249,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(

   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
@@ -6219,6 +6266,7 @@ class ConvertAtenMaxPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6231,14 +6279,13 @@ class ConvertAtenMaxPool2dOp

     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");

     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);

     return success();
   }
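
Note: hoisting adaptor.getSelf() into the mutable local `self` here (and taking Value &inputXchw by reference in the earlier hunk) matters because getOutputDim may replace the input with a sliced tensor, and the later transpose must see that updated value. A minimal sketch of the pattern (hypothetical stand-in types, not MLIR):

    #include <cassert>
    #include <string>

    // Illustration of why the input is threaded through by reference: the
    // helper may replace the value (here, a stand-in string) with a sliced
    // version, and the caller must keep using the updated value afterwards.
    static void maybeSlice(std::string &input) {
      input = "sliced(" + input + ")";
    }

    int main() {
      std::string self = "input";
      maybeSlice(self); // stands in for getOutputTypeAndPoolingParameters
      assert(self == "sliced(input)"); // the transpose sees the sliced value
      return 0;
    }
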
@@ -6272,11 +6319,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();

     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6293,14 +6344,14 @@ class ConvertAtenMaxPool1dOp

     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");

     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);

     return success();
   }
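
Note: a shape-level sketch of the unsqueeze trick used above (hypothetical sizes). The trailing unit dimension lets tosa::MaxPool2dOp pool the 1-D length as height, with the 1-D kernel/stride lists expanded to two entries (the expandPoolParams(op, ..., 1) call in an earlier hunk suggests the extra entry is 1):

    #include <cstdio>
    #include <vector>

    int main() {
      // AtenMaxPool1d input is (N, C, L); hypothetical sizes below.
      std::vector<long> selfShape{2, 3, 10};
      std::vector<long> rank4Shape(selfShape);
      rank4Shape.push_back(1); // (2, 3, 10, 1): L acts as H, W = 1
      // tosa::MaxPool2dOp then pools only the length dimension.
      for (long d : rank4Shape)
        std::printf("%ld ", d); // prints: 2 3 10 1
      std::printf("\n");
      return 0;
    }
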
@@ -6316,6 +6367,7 @@ class ConvertAtenAvgPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();

     // Currently, we cannot represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
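
Note on the comment above (truncated here at the hunk boundary): divisor_override replaces the averaging divisor that would otherwise be derived from the kernel window, and TOSA's avg_pool2d derives its divisor itself with no equivalent knob, so the pattern bails out rather than emit a silently wrong average. A hypothetical numeric illustration:

    #include <cstdio>

    int main() {
      // A 1x2 averaging window over {4.0, 6.0}; hypothetical values.
      double window[2] = {4.0, 6.0};
      double sum = window[0] + window[1];
      std::printf("%f\n", sum / 2); // default divisor (kernel count): 5.0
      std::printf("%f\n", sum / 4); // divisor_override = 4 gives 2.5, which
                                    // tosa::AvgPool2dOp cannot express
      return 0;
    }
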
@@ -6329,14 +6381,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");

     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);

     return success();
   }
@@ -6370,23 +6421,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();

     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");

     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);

     return success();
   }