@@ -3757,6 +3757,91 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
37573757 return %5 : !torch.vtensor <[1 ,32 ,75 ,75 ],f32 >
37583758}
37593759
3760+
3761+ // -----
3762+
3763+ // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
3764+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
3765+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
3766+ // CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3767+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3768+ // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3769+ // CHECK: %[[VAL_5:.*]] = torch.constant.none
3770+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 2
3771+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3772+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3773+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3774+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3775+ // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3776+ // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
3777+ // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3778+ // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3779+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3780+ // CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
3781+ // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
3782+ // CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
3783+ // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
3784+ // CHECK: return %[[VAL_19]]
3785+
3786+ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch (%arg0: !torch.vtensor <[?,3 ,224 ,224 ],f32 >) -> !torch.vtensor <[?,32 ,112 ,112 ],f32 > {
3787+ %false = torch.constant.bool false
3788+ %int1 = torch.constant.int 1
3789+ %0 = torch.vtensor.literal (dense_resource <torch_tensor_32_3_3_3_torch.float32 > : tensor <32 x3 x3 x3 xf32 >) : !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >
3790+ %none = torch.constant.none
3791+ %int2 = torch.constant.int 2
3792+ %1 = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <int >
3793+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3794+ %3 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3795+ %4 = torch.prim.ListConstruct : () -> !torch.list <int >
3796+ %5 = torch.aten.convolution %arg0 , %0 , %none , %1 , %2 , %3 , %false , %4 , %int1 : !torch.vtensor <[?,3 ,224 ,224 ],f32 >, !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >, !torch.none , !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[?,32 ,112 ,112 ],f32 >
3797+ return %5 : !torch.vtensor <[?,32 ,112 ,112 ],f32 >
3798+ }
3799+
3800+
3801+ // -----
3802+
3803+ // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
3804+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
3805+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
3806+ // CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3807+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3808+ // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3809+ // CHECK: %[[VAL_5:.*]] = torch.constant.none
3810+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 3
3811+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3812+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3813+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3814+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3815+ // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3816+ // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
3817+ // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3818+ // CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
3819+ // CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
3820+ // CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
3821+ // CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
3822+ // CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
3823+ // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
3824+ // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3825+ // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3826+ // CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
3827+ // CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
3828+ // CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
3829+ // CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
3830+ // CHECK: return %[[VAL_25]]
3831+ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch (%arg0: !torch.vtensor <[?,3 ,225 ,225 ],f32 >) -> !torch.vtensor <[?,32 ,75 ,75 ],f32 > {
3832+ %false = torch.constant.bool false
3833+ %int1 = torch.constant.int 1
3834+ %0 = torch.vtensor.literal (dense_resource <torch_tensor_32_3_3_3_torch.float32 > : tensor <32 x3 x3 x3 xf32 >) : !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >
3835+ %none = torch.constant.none
3836+ %int3 = torch.constant.int 3
3837+ %1 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
3838+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3839+ %3 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3840+ %4 = torch.prim.ListConstruct : () -> !torch.list <int >
3841+ %5 = torch.aten.convolution %arg0 , %0 , %none , %1 , %2 , %3 , %false , %4 , %int1 : !torch.vtensor <[?,3 ,225 ,225 ],f32 >, !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >, !torch.none , !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[?,32 ,75 ,75 ],f32 >
3842+ return %5 : !torch.vtensor <[?,32 ,75 ,75 ],f32 >
3843+ }
3844+
37603845// -----
37613846
37623847// CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
0 commit comments