Skip to content

Commit 4d0e12c

Browse files
committed
[Tosa]: Slice conv inputs for dynamic batch as long as spatial dims are static.
1 parent 8e19615 commit 4d0e12c

File tree

2 files changed: +96 additions, -7 deletions

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2459,9 +2459,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24592459
}
24602460

24612461
int64_t outputHDim, outputWDim;
2462-
if (inputTy.hasStaticShape()) {
2463-
int64_t inputHDim = inputShape[2];
2464-
int64_t inputWDim = inputShape[3];
2462+
int64_t inputHDim = inputShape[2];
2463+
int64_t inputWDim = inputShape[3];
2464+
2465+
bool isStaticSpatialDims =
2466+
!ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim);
2467+
if (isStaticSpatialDims) {
2468+
24652469
int64_t weightHDim = weightShape[2];
24662470
int64_t weightWDim = weightShape[3];
24672471

@@ -2479,8 +2483,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24792483
SmallVector<int64_t> sizeHSlice(transposedInputShape);
24802484
// TOSA uses NHWC, so we will slice dim 1 for Height value
24812485
sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
2482-
transposedInput = rewriter.create<tosa::SliceOp>(
2483-
op->getLoc(), RankedTensorType::get(sizeHSlice, inputElemTy),
2486+
transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
2487+
rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy),
24842488
transposedInput,
24852489
tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice),
24862490
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice));
@@ -2504,8 +2508,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25042508
dyn_cast<RankedTensorType>(transposedInput.getType()).getShape());
25052509
// TOSA uses NHWC, so we will slice dim 2 for Width value
25062510
sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]);
2507-
transposedInput = rewriter.create<tosa::SliceOp>(
2508-
op->getLoc(), RankedTensorType::get(sizeWSlice, inputElemTy),
2511+
transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
2512+
rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy),
25092513
transposedInput,
25102514
tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice),
25112515
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice));

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 85 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3862,4 +3862,89 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
38623862
return %5 : !torch.vtensor<[1,32,75,75],f32>
38633863
}
38643864

3865+
3866+
// -----
3867+
3868+
// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
3869+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
3870+
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
3871+
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3872+
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3873+
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3874+
// CHECK: %[[VAL_5:.*]] = torch.constant.none
3875+
// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
3876+
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3877+
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3878+
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3879+
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3880+
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3881+
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
3882+
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3883+
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3884+
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3885+
// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
3886+
// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
3887+
// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
3888+
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
3889+
// CHECK: return %[[VAL_19]]
3890+
3891+
func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
3892+
%false = torch.constant.bool false
3893+
%int1 = torch.constant.int 1
3894+
%0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
3895+
%none = torch.constant.none
3896+
%int2 = torch.constant.int 2
3897+
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
3898+
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
3899+
%3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
3900+
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
3901+
%5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,112,112],f32>
3902+
return %5 : !torch.vtensor<[?,32,112,112],f32>
3903+
}
3904+
3905+
3906+
// -----
3907+
3908+
// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
3909+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
3910+
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
3911+
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3912+
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3913+
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3914+
// CHECK: %[[VAL_5:.*]] = torch.constant.none
3915+
// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
3916+
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3917+
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3918+
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3919+
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3920+
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3921+
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
3922+
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3923+
// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
3924+
// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
3925+
// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
3926+
// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
3927+
// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
3928+
// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
3929+
// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3930+
// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3931+
// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
3932+
// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
3933+
// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
3934+
// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
3935+
// CHECK: return %[[VAL_25]]
3936+
func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
3937+
%false = torch.constant.bool false
3938+
%int1 = torch.constant.int 1
3939+
%0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
3940+
%none = torch.constant.none
3941+
%int3 = torch.constant.int 3
3942+
%1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
3943+
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
3944+
%3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
3945+
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
3946+
%5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,75,75],f32>
3947+
return %5 : !torch.vtensor<[?,32,75,75],f32>
3948+
}
3949+
38653950
// -----

0 commit comments

Comments (0)