Skip to content

Commit 40e12d6

Browse files
committed
[Tosa] : Slice conv inputs for dynamic batch as long as spatial dims are static.
1 parent e60c192 commit 40e12d6

File tree

2 files changed: +96 −7 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+11-7
Original file line numberDiff line numberDiff line change
@@ -2453,9 +2453,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   }

   int64_t outputHDim, outputWDim;
-  if (inputTy.hasStaticShape()) {
-    int64_t inputHDim = inputShape[2];
-    int64_t inputWDim = inputShape[3];
+  int64_t inputHDim = inputShape[2];
+  int64_t inputWDim = inputShape[3];
+
+  bool isStaticSpatialDims =
+      !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim);
+  if (isStaticSpatialDims) {
+
     int64_t weightHDim = weightShape[2];
     int64_t weightWDim = weightShape[3];

@@ -2473,8 +2477,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
       SmallVector<int64_t> sizeHSlice(transposedInputShape);
       // TOSA uses NHWC, so we will slice dim 1 for Height value
       sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
-      transposedInput = rewriter.create<tosa::SliceOp>(
-          op->getLoc(), RankedTensorType::get(sizeHSlice, inputElemTy),
+      transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
+          rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy),
           transposedInput,
           tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice),
           tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice));
@@ -2498,8 +2502,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           dyn_cast<RankedTensorType>(transposedInput.getType()).getShape());
       // TOSA uses NHWC, so we will slice dim 2 for Width value
       sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]);
-      transposedInput = rewriter.create<tosa::SliceOp>(
-          op->getLoc(), RankedTensorType::get(sizeWSlice, inputElemTy),
+      transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
+          rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy),
           transposedInput,
           tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice),
           tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice));

test/Conversion/TorchToTosa/basic.mlir

+85
Original file line numberDiff line numberDiff line change
@@ -3757,6 +3757,91 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
   return %5 : !torch.vtensor<[1,32,75,75],f32>
 }

+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
+// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
+// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
+// CHECK: return %[[VAL_19]]
+
+func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
+  %none = torch.constant.none
+  %int2 = torch.constant.int 2
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,112,112],f32>
+  return %5 : !torch.vtensor<[?,32,112,112],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
+// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
+// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
+// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
+// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
+// CHECK: return %[[VAL_25]]
+func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
+  %none = torch.constant.none
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,75,75],f32>
+  return %5 : !torch.vtensor<[?,32,75,75],f32>
+}
+
 // -----

 // CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(

0 commit comments

Comments (0)