@@ -240,6 +240,13 @@ sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
240
240
return slicedIndices;
241
241
}
242
242
243
+ template <typename IntType>
244
+ static SmallVector<IntType, 4 > extractVector (ArrayAttr arrayAttr) {
245
+ return llvm::to_vector<4 >(llvm::map_range (
246
+ arrayAttr.getAsRange <IntegerAttr>(),
247
+ [](IntegerAttr attr) { return static_cast <IntType>(attr.getInt ()); }));
248
+ }
249
+
243
250
namespace {
244
251
245
252
struct UnrollTransferReadPattern
@@ -1114,6 +1121,193 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
1114
1121
}
1115
1122
};
1116
1123
1124
+ // / Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
1125
+ // / Ex:
1126
+ // / ```
1127
+ // / %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
1128
+ // / %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
1129
+ // / : vector<8x32x16xf32> to vector<8x16xf32>
1130
+ // / ```
1131
+ // / Gets converted to:
1132
+ // / ```
1133
+ // / %1 = vector.contract {indexing_maps = [
1134
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1135
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1136
+ // / affine_map<(d0, d1, d2) -> (d0, d1)>],
1137
+ // / iterator_types = ["parallel", "parallel", "reduction"],
1138
+ // / kind = #vector.kind<add>} %0, %arg1, %cst_f0
1139
+ // / : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1140
+ // / ```
1141
+ struct MultiReduceToContract
1142
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
1143
+ using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
1144
+
1145
+ LogicalResult matchAndRewrite (vector::MultiDimReductionOp reduceOp,
1146
+ PatternRewriter &rewriter) const override {
1147
+ if (reduceOp.kind () != vector::CombiningKind::ADD)
1148
+ return failure ();
1149
+ Operation *mulOp = reduceOp.source ().getDefiningOp ();
1150
+ if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
1151
+ return failure ();
1152
+ SmallVector<bool > reductionMask = reduceOp.getReductionMask ();
1153
+ auto srcMap = rewriter.getMultiDimIdentityMap (reductionMask.size ());
1154
+ SmallVector<AffineExpr> exprs;
1155
+ SmallVector<StringRef> iteratorTypes;
1156
+ for (auto isReduceDim : llvm::enumerate (reductionMask)) {
1157
+ if (!isReduceDim.value ()) {
1158
+ iteratorTypes.push_back (getParallelIteratorTypeName ());
1159
+ exprs.push_back (rewriter.getAffineDimExpr (isReduceDim.index ()));
1160
+ } else {
1161
+ iteratorTypes.push_back (getReductionIteratorTypeName ());
1162
+ }
1163
+ }
1164
+ auto dstMap = AffineMap::get (/* dimCount=*/ reductionMask.size (),
1165
+ /* symCount=*/ 0 , exprs, reduceOp.getContext ());
1166
+ Value zero = rewriter.create <arith::ConstantOp>(
1167
+ reduceOp.getLoc (), reduceOp.getDestType (),
1168
+ rewriter.getZeroAttr (reduceOp.getDestType ()));
1169
+ rewriter.replaceOpWithNewOp <mlir::vector::ContractionOp>(
1170
+ reduceOp, mulOp->getOperand (0 ), mulOp->getOperand (1 ), zero,
1171
+ rewriter.getAffineMapArrayAttr ({srcMap, srcMap, dstMap}),
1172
+ rewriter.getStrArrayAttr (iteratorTypes));
1173
+ return success ();
1174
+ }
1175
+ };
1176
+
1177
+ // / Merge TransposeOp into ContractionOp user.
1178
+ // / Ex:
1179
+ // / ```
1180
+ // / %0 = vector.transpose %arg0, [2, 0, 1]
1181
+ // / : vector<32x16x8xf32> to vector<8x32x16xf32>
1182
+ // / %1 = vector.contract {indexing_maps = [
1183
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1184
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1185
+ // / affine_map<(d0, d1, d2) -> (d0, d1)>],
1186
+ // / iterator_types = ["parallel", "parallel", "reduction"],
1187
+ // / kind = #vector.kind<add>} %0, %arg1, %cst_f0
1188
+ // / : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1189
+ // / ```
1190
+ // / Gets converted to:
1191
+ // / ```
1192
+ // / %1 = vector.contract {indexing_maps = [
1193
+ // / affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
1194
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1195
+ // / affine_map<(d0, d1, d2) -> (d0, d1)>],
1196
+ // / iterator_types = ["parallel", "parallel", "reduction"],
1197
+ // / kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
1198
+ // / : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1199
+ // / ```
1200
+ struct CombineContractTranspose
1201
+ : public OpRewritePattern<vector::ContractionOp> {
1202
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1203
+
1204
+ LogicalResult matchAndRewrite (vector::ContractionOp contractOp,
1205
+ PatternRewriter &rewriter) const override {
1206
+ SmallVector<AffineMap, 4 > maps =
1207
+ llvm::to_vector<4 >(contractOp.getIndexingMaps ());
1208
+ Value lhs = contractOp.lhs ();
1209
+ Value rhs = contractOp.rhs ();
1210
+ size_t index = 0 ;
1211
+ bool changed = false ;
1212
+ for (Value *operand : {&lhs, &rhs}) {
1213
+ AffineMap &map = maps[index++];
1214
+ auto transposeOp = operand->getDefiningOp <vector::TransposeOp>();
1215
+ if (!transposeOp)
1216
+ continue ;
1217
+ SmallVector<int64_t > perm;
1218
+ transposeOp.getTransp (perm);
1219
+ AffineMap permutationMap = AffineMap::getPermutationMap (
1220
+ extractVector<unsigned >(transposeOp.transp ()),
1221
+ contractOp.getContext ());
1222
+ map = inversePermutation (permutationMap).compose (map);
1223
+ *operand = transposeOp.vector ();
1224
+ changed = true ;
1225
+ }
1226
+ if (!changed)
1227
+ return failure ();
1228
+ rewriter.replaceOpWithNewOp <vector::ContractionOp>(
1229
+ contractOp, lhs, rhs, contractOp.acc (),
1230
+ rewriter.getAffineMapArrayAttr (maps), contractOp.iterator_types ());
1231
+ return success ();
1232
+ }
1233
+ };
1234
+
1235
+ // / Merge BroadcastOp into ContractionOp user.
1236
+ // / Ex:
1237
+ // / ```
1238
+ // / %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
1239
+ // / %1 = vector.contract {indexing_maps = [
1240
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1241
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1242
+ // / affine_map<(d0, d1, d2) -> (d0, d1)>],
1243
+ // / iterator_types = ["parallel", "parallel", "reduction"],
1244
+ // / kind = #vector.kind<add>} %0, %arg1, %cst_f0
1245
+ // / : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1246
+ // / ```
1247
+ // / Gets converted to:
1248
+ // / ```
1249
+ // / %1 = vector.contract {indexing_maps = [
1250
+ // / affine_map<(d0, d1, d2) -> (d1, d2)>,
1251
+ // / affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1252
+ // / affine_map<(d0, d1, d2) -> (d0, d1)>],
1253
+ // / iterator_types = ["parallel", "parallel", "reduction"],
1254
+ // / kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
1255
+ // / : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1256
+ // / ```
1257
+ struct CombineContractBroadcast
1258
+ : public OpRewritePattern<vector::ContractionOp> {
1259
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1260
+
1261
+ LogicalResult matchAndRewrite (vector::ContractionOp contractOp,
1262
+ PatternRewriter &rewriter) const override {
1263
+ SmallVector<AffineMap, 4 > maps =
1264
+ llvm::to_vector<4 >(contractOp.getIndexingMaps ());
1265
+ Value lhs = contractOp.lhs ();
1266
+ Value rhs = contractOp.rhs ();
1267
+ size_t index = 0 ;
1268
+ bool changed = false ;
1269
+ for (Value *operand : {&lhs, &rhs}) {
1270
+ AffineMap &map = maps[index++];
1271
+ auto broadcast = operand->getDefiningOp <vector::BroadcastOp>();
1272
+ if (!broadcast)
1273
+ continue ;
1274
+ // contractionOp can only take vector as operands.
1275
+ auto srcType = broadcast.getSourceType ().dyn_cast <VectorType>();
1276
+ if (!srcType || srcType.getRank () == broadcast.getVectorType ().getRank ())
1277
+ continue ;
1278
+ int64_t rankDiff =
1279
+ broadcast.getVectorType ().getRank () - srcType.getRank ();
1280
+ bool innerDimBroadcast = false ;
1281
+ SmallVector<AffineExpr> originalDims;
1282
+ for (auto dim : llvm::enumerate (srcType.getShape ())) {
1283
+ if (dim.value () !=
1284
+ broadcast.getVectorType ().getDimSize (rankDiff + dim.index ())) {
1285
+ innerDimBroadcast = true ;
1286
+ break ;
1287
+ }
1288
+ originalDims.push_back (
1289
+ rewriter.getAffineDimExpr (dim.index () + rankDiff));
1290
+ }
1291
+ // Contract doesn't support inner dimension broadcast. Once this is
1292
+ // relaxed we can remove this case.
1293
+ if (innerDimBroadcast)
1294
+ continue ;
1295
+ AffineMap broadcastMap =
1296
+ AffineMap::get (broadcast.getVectorType ().getRank (), 0 , originalDims,
1297
+ contractOp.getContext ());
1298
+ map = broadcastMap.compose (map);
1299
+ *operand = broadcast.source ();
1300
+ changed = true ;
1301
+ }
1302
+ if (!changed)
1303
+ return failure ();
1304
+ rewriter.replaceOpWithNewOp <vector::ContractionOp>(
1305
+ contractOp, lhs, rhs, contractOp.acc (),
1306
+ rewriter.getAffineMapArrayAttr (maps), contractOp.iterator_types ());
1307
+ return success ();
1308
+ }
1309
+ };
1310
+
1117
1311
} // namespace
1118
1312
1119
1313
// / Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -3668,6 +3862,12 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
3668
3862
patterns.add <TransposeOpLowering>(options, patterns.getContext ());
3669
3863
}
3670
3864
3865
+ void mlir::vector::populateVetorReductionToContractPatterns (
3866
+ RewritePatternSet &patterns) {
3867
+ patterns.add <MultiReduceToContract, CombineContractBroadcast,
3868
+ CombineContractTranspose>(patterns.getContext ());
3869
+ }
3870
+
3671
3871
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns (
3672
3872
RewritePatternSet &patterns) {
3673
3873
patterns.add <TransferReadPermutationLowering,
0 commit comments