@@ -3862,4 +3862,89 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
3862
3862
return %5 : !torch.vtensor <[1 ,32 ,75 ,75 ],f32 >
3863
3863
}
3864
3864
3865
+
3866
+ // -----
3867
+
3868
+ // CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
3869
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
3870
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
3871
+ // CHECK: %[[VAL_2:.*]] = torch.constant.bool false
3872
+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
3873
+ // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
3874
+ // CHECK: %[[VAL_5:.*]] = torch.constant.none
3875
+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 2
3876
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3877
+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3878
+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
3879
+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
3880
+ // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
3881
+ // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
3882
+ // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
3883
+ // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3884
+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
3885
+ // CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
3886
+ // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
3887
+ // CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
3888
+ // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
3889
+ // CHECK: return %[[VAL_19]]
3890
+
3891
+ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch (%arg0: !torch.vtensor <[?,3 ,224 ,224 ],f32 >) -> !torch.vtensor <[?,32 ,112 ,112 ],f32 > {
3892
+ %false = torch.constant.bool false
3893
+ %int1 = torch.constant.int 1
3894
+ %0 = torch.vtensor.literal (dense_resource <torch_tensor_32_3_3_3_torch.float32 > : tensor <32 x3 x3 x3 xf32 >) : !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >
3895
+ %none = torch.constant.none
3896
+ %int2 = torch.constant.int 2
3897
+ %1 = torch.prim.ListConstruct %int2 , %int2 : (!torch.int , !torch.int ) -> !torch.list <int >
3898
+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3899
+ %3 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
3900
+ %4 = torch.prim.ListConstruct : () -> !torch.list <int >
3901
+ %5 = torch.aten.convolution %arg0 , %0 , %none , %1 , %2 , %3 , %false , %4 , %int1 : !torch.vtensor <[?,3 ,224 ,224 ],f32 >, !torch.vtensor <[32 ,3 ,3 ,3 ],f32 >, !torch.none , !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[?,32 ,112 ,112 ],f32 >
3902
+ return %5 : !torch.vtensor <[?,32 ,112 ,112 ],f32 >
3903
+ }
3904
+

// -----

// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
// CHECK: return %[[VAL_25]]
// Dynamic batch, 225x225 input with stride 3: here pad adjustment alone is
// not enough, so the lowering also slices the (already NHWC-transposed)
// input down to 224x224 — dim H first (225 -> 224), then dim W — with a
// dynamic (-1) batch extent in each tosa.slice size, before the conv2d.
func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
  %none = torch.constant.none
  %int3 = torch.constant.int 3
  %1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,75,75],f32>
  return %5 : !torch.vtensor<[?,32,75,75],f32>
}

// -----