@@ -39,13 +39,15 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
39
39
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
40
40
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
41
41
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
42
+ // CHECK-SAME: kind = #vector.kind<add>
42
43
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
43
44
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
44
45
45
46
/// w == 1, kw == 0
46
47
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
47
48
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
48
49
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
50
+ // CHECK-SAME: kind = #vector.kind<add>
49
51
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
50
52
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
51
53
@@ -61,6 +63,36 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
61
63
62
64
// -----
63
65
66
+ // This test is same as above but for i1 type with the only difference being that
67
+ // the combining kind for `vector.contract` is `OR`.
68
+ func.func @conv1d_nwc_4x2x8_memref_i1 (%input: memref <4 x6 x3 xi1 >, %filter: memref <1 x3 x8 xi1 >, %output: memref <4 x2 x8 xi1 >) {
69
+ linalg.conv_1d_nwc_wcf
70
+ {dilations = dense <1 > : tensor <1 xi64 >, strides = dense <3 > : tensor <1 xi64 >}
71
+ ins (%input , %filter : memref <4 x6 x3 xi1 >, memref <1 x3 x8 xi1 >)
72
+ outs (%output : memref <4 x2 x8 xi1 >)
73
+ return
74
+ }
75
+ // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
76
+ // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
77
+ // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
78
+
79
+ // CHECK: func @conv1d_nwc_4x2x8_memref_i1
80
+ /// w == 0, kw == 0
81
+ // CHECK: %[[CONTRACT_0:.+]] = vector.contract {
82
+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
83
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
84
+ // CHECK-SAME: kind = #vector.kind<or>
85
+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
86
+
87
+ /// w == 1, kw == 0
88
+ // CHECK: %[[CONTRACT_1:.+]] = vector.contract {
89
+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
90
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
91
+ // CHECK-SAME: kind = #vector.kind<or>
92
+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
93
+
94
+ // -----
95
+
64
96
// The i8i8i32 case is similar to f32 case, so checking one case is enough for
65
97
// test coverage.
66
98
func.func @conv1d_nwc_4x2x8_i8i8i32_memref (%input: memref <4 x6 x3 xi8 >, %filter: memref <1 x3 x8 xi8 >, %output: memref <4 x2 x8 xi32 >) {
@@ -299,13 +331,15 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
299
331
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
300
332
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
301
333
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
334
+ // CHECK-SAME: kind = #vector.kind<add>
302
335
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
303
336
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
304
337
305
338
/// w == 1, kw == 0
306
339
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
307
340
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
308
341
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
342
+ // CHECK-SAME: kind = #vector.kind<add>
309
343
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
310
344
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
311
345
@@ -324,6 +358,37 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
324
358
325
359
// -----
326
360
361
+ // This test is same as above but for i1 type with the only difference being that
362
+ // the combining kind for `vector.contract` is `OR`.
363
+ func.func @conv1d_ncw_4x8x2_memref_i1 (%input: memref <4 x3 x6 xi1 >, %filter: memref <8 x3 x1 xi1 >, %output: memref <4 x8 x2 xi1 >) {
364
+ linalg.conv_1d_ncw_fcw
365
+ {dilations = dense <1 > : tensor <1 xi64 >, strides = dense <3 > : tensor <1 xi64 >}
366
+ ins (%input , %filter : memref <4 x3 x6 xi1 >, memref <8 x3 x1 xi1 >)
367
+ outs (%output : memref <4 x8 x2 xi1 >)
368
+ return
369
+ }
370
+
371
+ // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
372
+ // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
373
+ // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
374
+
375
+ // CHECK: func @conv1d_ncw_4x8x2_memref_i1
376
+ /// w == 0, kw == 0
377
+ // CHECK: vector.contract {
378
+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
379
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
380
+ // CHECK-SAME: kind = #vector.kind<or>
381
+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
382
+
383
+ /// w == 1, kw == 0
384
+ // CHECK: vector.contract {
385
+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
386
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
387
+ // CHECK-SAME: kind = #vector.kind<or>
388
+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
389
+
390
+ // -----
391
+
327
392
func.func @conv1d_ncw_4x8x2_memref (%input: memref <4 x3 x6 xf32 >, %filter: memref <8 x3 x2 xf32 >, %output: memref <4 x8 x2 xf32 >) {
328
393
linalg.conv_1d_ncw_fcw
329
394
{dilations = dense <2 > : tensor <1 xi64 >, strides = dense <3 > : tensor <1 xi64 >}
0 commit comments