Skip to content

Commit 1d8cc45

Browse files
committed
[mlir][vector] Add patterns to convert multidimreduce to vector.contract
add several patterns that will simplify contraction vectorization in the future. With those canonicalizationns we will be able to remove the special case for contration during vectorization and rely on those transformations to avoid materizalizing broadcast ops. Differential Revision: https://reviews.llvm.org/D112121
1 parent 5dc339d commit 1d8cc45

File tree

4 files changed

+310
-0
lines changed

4 files changed

+310
-0
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ void populateVectorTransposeLoweringPatterns(
215215
RewritePatternSet &patterns,
216216
VectorTransformsOptions options = VectorTransformsOptions());
217217

218+
/// Collect patterns to convert reduction op to vector.contract and fold
219+
/// transpose/broadcast ops into the contract.
220+
void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
221+
218222
/// Returns the integer type required for subscripts in the vector dialect.
219223
IntegerType getVectorSubscriptType(Builder &builder);
220224

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
240240
return slicedIndices;
241241
}
242242

243+
template <typename IntType>
244+
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
245+
return llvm::to_vector<4>(llvm::map_range(
246+
arrayAttr.getAsRange<IntegerAttr>(),
247+
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
248+
}
249+
243250
namespace {
244251

245252
struct UnrollTransferReadPattern
@@ -1114,6 +1121,193 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
11141121
}
11151122
};
11161123

1124+
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
1125+
/// Ex:
1126+
/// ```
1127+
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
1128+
/// %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
1129+
/// : vector<8x32x16xf32> to vector<8x16xf32>
1130+
/// ```
1131+
/// Gets converted to:
1132+
/// ```
1133+
/// %1 = vector.contract {indexing_maps = [
1134+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1135+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1136+
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
1137+
/// iterator_types = ["parallel", "parallel", "reduction"],
1138+
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
1139+
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1140+
/// ```
1141+
struct MultiReduceToContract
1142+
: public OpRewritePattern<vector::MultiDimReductionOp> {
1143+
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
1144+
1145+
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
1146+
PatternRewriter &rewriter) const override {
1147+
if (reduceOp.kind() != vector::CombiningKind::ADD)
1148+
return failure();
1149+
Operation *mulOp = reduceOp.source().getDefiningOp();
1150+
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
1151+
return failure();
1152+
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
1153+
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
1154+
SmallVector<AffineExpr> exprs;
1155+
SmallVector<StringRef> iteratorTypes;
1156+
for (auto isReduceDim : llvm::enumerate(reductionMask)) {
1157+
if (!isReduceDim.value()) {
1158+
iteratorTypes.push_back(getParallelIteratorTypeName());
1159+
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
1160+
} else {
1161+
iteratorTypes.push_back(getReductionIteratorTypeName());
1162+
}
1163+
}
1164+
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
1165+
/*symCount=*/0, exprs, reduceOp.getContext());
1166+
Value zero = rewriter.create<arith::ConstantOp>(
1167+
reduceOp.getLoc(), reduceOp.getDestType(),
1168+
rewriter.getZeroAttr(reduceOp.getDestType()));
1169+
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
1170+
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
1171+
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
1172+
rewriter.getStrArrayAttr(iteratorTypes));
1173+
return success();
1174+
}
1175+
};
1176+
1177+
/// Merge TransposeOp into ContractionOp user.
1178+
/// Ex:
1179+
/// ```
1180+
/// %0 = vector.transpose %arg0, [2, 0, 1]
1181+
/// : vector<32x16x8xf32> to vector<8x32x16xf32>
1182+
/// %1 = vector.contract {indexing_maps = [
1183+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1184+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1185+
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
1186+
/// iterator_types = ["parallel", "parallel", "reduction"],
1187+
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
1188+
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1189+
/// ```
1190+
/// Gets converted to:
1191+
/// ```
1192+
/// %1 = vector.contract {indexing_maps = [
1193+
/// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
1194+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1195+
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
1196+
/// iterator_types = ["parallel", "parallel", "reduction"],
1197+
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
1198+
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1199+
/// ```
1200+
struct CombineContractTranspose
1201+
: public OpRewritePattern<vector::ContractionOp> {
1202+
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1203+
1204+
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1205+
PatternRewriter &rewriter) const override {
1206+
SmallVector<AffineMap, 4> maps =
1207+
llvm::to_vector<4>(contractOp.getIndexingMaps());
1208+
Value lhs = contractOp.lhs();
1209+
Value rhs = contractOp.rhs();
1210+
size_t index = 0;
1211+
bool changed = false;
1212+
for (Value *operand : {&lhs, &rhs}) {
1213+
AffineMap &map = maps[index++];
1214+
auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
1215+
if (!transposeOp)
1216+
continue;
1217+
SmallVector<int64_t> perm;
1218+
transposeOp.getTransp(perm);
1219+
AffineMap permutationMap = AffineMap::getPermutationMap(
1220+
extractVector<unsigned>(transposeOp.transp()),
1221+
contractOp.getContext());
1222+
map = inversePermutation(permutationMap).compose(map);
1223+
*operand = transposeOp.vector();
1224+
changed = true;
1225+
}
1226+
if (!changed)
1227+
return failure();
1228+
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1229+
contractOp, lhs, rhs, contractOp.acc(),
1230+
rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
1231+
return success();
1232+
}
1233+
};
1234+
1235+
/// Merge BroadcastOp into ContractionOp user.
1236+
/// Ex:
1237+
/// ```
1238+
/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
1239+
/// %1 = vector.contract {indexing_maps = [
1240+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1241+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1242+
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
1243+
/// iterator_types = ["parallel", "parallel", "reduction"],
1244+
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
1245+
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1246+
/// ```
1247+
/// Gets converted to:
1248+
/// ```
1249+
/// %1 = vector.contract {indexing_maps = [
1250+
/// affine_map<(d0, d1, d2) -> (d1, d2)>,
1251+
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1252+
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
1253+
/// iterator_types = ["parallel", "parallel", "reduction"],
1254+
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
1255+
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1256+
/// ```
1257+
struct CombineContractBroadcast
1258+
: public OpRewritePattern<vector::ContractionOp> {
1259+
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1260+
1261+
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1262+
PatternRewriter &rewriter) const override {
1263+
SmallVector<AffineMap, 4> maps =
1264+
llvm::to_vector<4>(contractOp.getIndexingMaps());
1265+
Value lhs = contractOp.lhs();
1266+
Value rhs = contractOp.rhs();
1267+
size_t index = 0;
1268+
bool changed = false;
1269+
for (Value *operand : {&lhs, &rhs}) {
1270+
AffineMap &map = maps[index++];
1271+
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
1272+
if (!broadcast)
1273+
continue;
1274+
// contractionOp can only take vector as operands.
1275+
auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
1276+
if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
1277+
continue;
1278+
int64_t rankDiff =
1279+
broadcast.getVectorType().getRank() - srcType.getRank();
1280+
bool innerDimBroadcast = false;
1281+
SmallVector<AffineExpr> originalDims;
1282+
for (auto dim : llvm::enumerate(srcType.getShape())) {
1283+
if (dim.value() !=
1284+
broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
1285+
innerDimBroadcast = true;
1286+
break;
1287+
}
1288+
originalDims.push_back(
1289+
rewriter.getAffineDimExpr(dim.index() + rankDiff));
1290+
}
1291+
// Contract doesn't support inner dimension broadcast. Once this is
1292+
// relaxed we can remove this case.
1293+
if (innerDimBroadcast)
1294+
continue;
1295+
AffineMap broadcastMap =
1296+
AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1297+
contractOp.getContext());
1298+
map = broadcastMap.compose(map);
1299+
*operand = broadcast.source();
1300+
changed = true;
1301+
}
1302+
if (!changed)
1303+
return failure();
1304+
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1305+
contractOp, lhs, rhs, contractOp.acc(),
1306+
rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
1307+
return success();
1308+
}
1309+
};
1310+
11171311
} // namespace
11181312

11191313
/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -3668,6 +3862,12 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
36683862
patterns.add<TransposeOpLowering>(options, patterns.getContext());
36693863
}
36703864

3865+
void mlir::vector::populateVetorReductionToContractPatterns(
3866+
RewritePatternSet &patterns) {
3867+
patterns.add<MultiReduceToContract, CombineContractBroadcast,
3868+
CombineContractTranspose>(patterns.getContext());
3869+
}
3870+
36713871
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
36723872
RewritePatternSet &patterns) {
36733873
patterns.add<TransferReadPermutationLowering,
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
2+
3+
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4+
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
5+
6+
// CHECK-LABEL: multidimreduction_contract
7+
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
8+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
9+
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
10+
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
11+
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
12+
func @multidimreduction_contract(
13+
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
14+
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
15+
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
16+
return %1 : vector<8x16xf32>
17+
}
18+
19+
// -----
20+
21+
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
22+
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
23+
24+
// CHECK-LABEL: multidimreduction_contract_int
25+
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
26+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
27+
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
28+
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
29+
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
30+
func @multidimreduction_contract_int(
31+
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
32+
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
33+
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
34+
return %1 : vector<8x16xi32>
35+
}
36+
37+
// -----
38+
39+
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
40+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
41+
42+
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
43+
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
44+
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
45+
46+
// CHECK-LABEL: contract_transpose
47+
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
48+
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
49+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
50+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
51+
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
52+
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
53+
func @contract_transpose(
54+
%arg0: vector<32x16x8xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
55+
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
56+
%0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
57+
%1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
58+
iterator_types = ["parallel", "parallel", "reduction"],
59+
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
60+
return %1 : vector<8x32xf32>
61+
}
62+
63+
// -----
64+
65+
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
66+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
67+
68+
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
69+
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
70+
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
71+
72+
// CHECK-LABEL: contract_broadcast
73+
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
74+
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
75+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
76+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
77+
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
78+
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
79+
func @contract_broadcast(
80+
%arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
81+
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
82+
%0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
83+
%1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
84+
iterator_types = ["parallel", "parallel", "reduction"],
85+
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
86+
return %1 : vector<8x32xf32>
87+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,23 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
493493
}
494494
};
495495

496+
struct TestVectorReduceToContractPatternsPatterns
497+
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
498+
FunctionPass> {
499+
StringRef getArgument() const final {
500+
return "test-vector-reduction-to-contract-patterns";
501+
}
502+
StringRef getDescription() const final {
503+
return "Test patterns to convert multireduce op to contract and combine "
504+
"broadcast/transpose to contract";
505+
}
506+
void runOnFunction() override {
507+
RewritePatternSet patterns(&getContext());
508+
populateVetorReductionToContractPatterns(patterns);
509+
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
510+
}
511+
};
512+
496513
} // end anonymous namespace
497514

498515
namespace mlir {
@@ -519,6 +536,8 @@ void registerTestVectorConversions() {
519536
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
520537

521538
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
539+
540+
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
522541
}
523542
} // namespace test
524543
} // namespace mlir

0 commit comments

Comments
 (0)