Skip to content

Commit 234193b

Browse files
[mlir][linalg] Vectorization support for convolution of i1 type (#109480)
Normally convolutions present with the following linalg op region ``` ^bb0(%arg14: i4, %arg15: i4, %arg16: i4): %17 = arith.muli %arg14, %arg15 : i4 %18 = arith.addi %arg16, %17 : i4 linalg.yield %18 : i4 ``` However, for i1 due to strength reduction we get something like ``` ^bb0(%arg14: i1, %arg15: i1, %arg16: i1): %17 = arith.andi %arg14, %arg15 : i1 %18 = arith.ori %arg16, %17 : i1 linalg.yield %18 : i1 ``` This PR updates the logic to support this region for i1 types.
1 parent f404207 commit 234193b

File tree

2 files changed

+81
-4
lines changed

2 files changed

+81
-4
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2987,10 +2987,15 @@ struct Conv1DGenerator
29872987
if (!setOperKind(reduceOp))
29882988
return;
29892989
auto maybeKind = getCombinerOpKind(reduceOp);
2990-
if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2990+
// Typically convolution will have an `Add` CombiningKind but for i1 type it
2991+
// can get strength reduced to `OR` which is also supported. This strength
2992+
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2993+
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2994+
*maybeKind != vector::CombiningKind::OR) &&
29912995
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
29922996
return;
29932997
}
2998+
reductionKind = maybeKind.value();
29942999

29953000
auto rhsRank = rhsShapedType.getRank();
29963001
switch (oper) {
@@ -3273,10 +3278,12 @@ struct Conv1DGenerator
32733278
bindDims(ctx, n, w, f, c);
32743279
lhs = promote(rewriter, loc, lhs, res.getType());
32753280
rhs = promote(rewriter, loc, rhs, res.getType());
3276-
return rewriter.create<vector::ContractionOp>(
3281+
auto contrationOp = rewriter.create<vector::ContractionOp>(
32773282
loc, lhs, rhs, res,
32783283
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
32793284
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3285+
contrationOp.setKind(reductionKind);
3286+
return contrationOp;
32803287
}
32813288

32823289
// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3666,6 +3673,7 @@ struct Conv1DGenerator
36663673
int strideW, dilationW;
36673674
Value lhsShaped, rhsShaped, resShaped;
36683675
ShapedType lhsShapedType, rhsShapedType, resShapedType;
3676+
vector::CombiningKind reductionKind;
36693677

36703678
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
36713679
// Returns true iff it is a valid conv/pooling op.
@@ -3681,15 +3689,19 @@ struct Conv1DGenerator
36813689
switch (numBlockArguments) {
36823690
case 1: {
36833691
// Will be convolution if feeder is a MulOp.
3684-
// Otherwise, if it can be pooling.
3692+
// A strength reduced version of MulOp for i1 type is AndOp which is also
3693+
// supported. Otherwise, it can be pooling. This strength reduction logic
3694+
// is in `buildBinaryFn` helper in the Linalg dialect.
36853695
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
36863696
llvm::IsaPred<BlockArgument>);
36873697
Operation *feedOp = (*feedValIt).getDefiningOp();
36883698
if (isCastOfBlockArgument(feedOp)) {
36893699
oper = Pool;
36903700
isPoolExt = true;
36913701
poolExtOp = feedOp->getName().getIdentifier();
3692-
} else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3702+
} else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3703+
(isa<arith::AndIOp>(feedOp) &&
3704+
feedOp->getResultTypes()[0].isInteger(1))) &&
36933705
llvm::all_of(feedOp->getOperands(), [](Value v) {
36943706
if (isa<BlockArgument>(v))
36953707
return true;

mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
3939
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
4040
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
4141
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
42+
// CHECK-SAME: kind = #vector.kind<add>
4243
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
4344
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
4445

4546
/// w == 1, kw == 0
4647
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
4748
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
4849
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
50+
// CHECK-SAME: kind = #vector.kind<add>
4951
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
5052
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
5153

@@ -61,6 +63,36 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
6163

6264
// -----
6365

66+
// This test is the same as above but for i1 type with the only difference being that
67+
// the combining kind for `vector.contract` is `OR`.
68+
func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>, %output: memref<4x2x8xi1>) {
69+
linalg.conv_1d_nwc_wcf
70+
{dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
71+
ins(%input, %filter : memref<4x6x3xi1>, memref<1x3x8xi1>)
72+
outs(%output : memref<4x2x8xi1>)
73+
return
74+
}
75+
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
76+
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
77+
// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
78+
79+
// CHECK: func @conv1d_nwc_4x2x8_memref_i1
80+
/// w == 0, kw == 0
81+
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
82+
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
83+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
84+
// CHECK-SAME: kind = #vector.kind<or>
85+
// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
86+
87+
/// w == 1, kw == 0
88+
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
89+
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
90+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
91+
// CHECK-SAME: kind = #vector.kind<or>
92+
// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
93+
94+
// -----
95+
6496
// The i8i8i32 case is similar to f32 case, so checking one case is enough for
6597
// test coverage.
6698
func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) {
@@ -299,13 +331,15 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
299331
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
300332
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
301333
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
334+
// CHECK-SAME: kind = #vector.kind<add>
302335
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
303336
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
304337

305338
/// w == 1, kw == 0
306339
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
307340
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
308341
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
342+
// CHECK-SAME: kind = #vector.kind<add>
309343
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
310344
// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
311345

@@ -324,6 +358,37 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
324358

325359
// -----
326360

361+
// This test is the same as above but for i1 type with the only difference being that
362+
// the combining kind for `vector.contract` is `OR`.
363+
func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<8x3x1xi1>, %output: memref<4x8x2xi1>) {
364+
linalg.conv_1d_ncw_fcw
365+
{dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
366+
ins(%input, %filter : memref<4x3x6xi1>, memref<8x3x1xi1>)
367+
outs(%output : memref<4x8x2xi1>)
368+
return
369+
}
370+
371+
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
372+
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
373+
// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
374+
375+
// CHECK: func @conv1d_ncw_4x8x2_memref_i1
376+
/// w == 0, kw == 0
377+
// CHECK: vector.contract {
378+
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
379+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
380+
// CHECK-SAME: kind = #vector.kind<or>
381+
// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
382+
383+
/// w == 1, kw == 0
384+
// CHECK: vector.contract {
385+
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
386+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
387+
// CHECK-SAME: kind = #vector.kind<or>
388+
// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
389+
390+
// -----
391+
327392
func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
328393
linalg.conv_1d_ncw_fcw
329394
{dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}

0 commit comments

Comments
 (0)