@@ -3757,6 +3757,91 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
3757
3757
return %5 : !torch.vtensor <[1 ,32 ,75 ,75 ],f32 >
3758
3758
}
3759
3759
3760
+
3761
+ // -----
3762
+
3763
+ // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
3764
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
3765
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
3766
+ // CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3767
+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3768
+ // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3769
+ // CHECK: %[[VAL_5:.*]] = torch.constant.none
3770
+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 2
3771
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3772
+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3773
+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3774
+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3775
+ // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3776
+ // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
3777
+ // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3778
+ // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3779
+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3780
+ // CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
3781
+ // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
3782
+ // CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
3783
+ // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
3784
+ // CHECK: return %[[VAL_19]]
3785
+
3786
+ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch (%arg0: !torch.vtensor <[?,3 ,224 ,224 ],f32 >) -> !torch.vtensor <[?,32 ,112 ,112 ],f32 > {
3787
+ %false = torch.constant.bool false
3788
+ %int1 = torch.constant.int 1
3789
+ %0 = torch.vtensor.literal (dense_resource <torch_tensor_32_3_3_3_torch.float32 > : tensor <32 x3 x3 x3 xf32 >) : !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >
3790
+ %none = torch.constant.none
3791
+ %int2 = torch.constant.int 2
3792
+ %1 = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <int >
3793
+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3794
+ %3 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3795
+ %4 = torch.prim.ListConstruct : () -> !torch.list <int >
3796
+ %5 = torch.aten.convolution %arg0 , %0 , %none , %1 , %2 , %3 , %false , %4 , %int1 : !torch.vtensor <[?,3 ,224 ,224 ],f32 >, !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >, !torch.none , !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[?,32 ,112 ,112 ],f32 >
3797
+ return %5 : !torch.vtensor <[?,32 ,112 ,112 ],f32 >
3798
+ }
3799
+
3800
+
3801
+ // -----
3802
+
3803
+ // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
3804
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
3805
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
3806
+ // CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3807
+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3808
+ // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3809
+ // CHECK: %[[VAL_5:.*]] = torch.constant.none
3810
+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 3
3811
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3812
+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3813
+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3814
+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3815
+ // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3816
+ // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
3817
+ // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3818
+ // CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
3819
+ // CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
3820
+ // CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
3821
+ // CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
3822
+ // CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
3823
+ // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
3824
+ // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3825
+ // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3826
+ // CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
3827
+ // CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
3828
+ // CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
3829
+ // CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
3830
+ // CHECK: return %[[VAL_25]]
3831
+ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch (%arg0: !torch.vtensor <[?,3 ,225 ,225 ],f32 >) -> !torch.vtensor <[?,32 ,75 ,75 ],f32 > {
3832
+ %false = torch.constant.bool false
3833
+ %int1 = torch.constant.int 1
3834
+ %0 = torch.vtensor.literal (dense_resource <torch_tensor_32_3_3_3_torch.float32 > : tensor <32 x3 x3 x3 xf32 >) : !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >
3835
+ %none = torch.constant.none
3836
+ %int3 = torch.constant.int 3
3837
+ %1 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
3838
+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3839
+ %3 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3840
+ %4 = torch.prim.ListConstruct : () -> !torch.list <int >
3841
+ %5 = torch.aten.convolution %arg0 , %0 , %none , %1 , %2 , %3 , %false , %4 , %int1 : !torch.vtensor <[?,3 ,225 ,225 ],f32 >, !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >, !torch.none , !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[?,32 ,75 ,75 ],f32 >
3842
+ return %5 : !torch.vtensor <[?,32 ,75 ,75 ],f32 >
3843
+ }
3844
+
3760
3845
// -----
3761
3846
3762
3847
// CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
0 commit comments