Commit 93d0ade

[mlir][linalg] Remove special case for contraction vectorization

Handle contraction ops like all the other generic op reductions. This
simplifies the code; we now rely on contractionOp canonicalization to keep
the same code quality.

Differential Revision: https://reviews.llvm.org/D112171

1 parent 1d8cc45 · commit 93d0ade
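To make the change concrete, here is a minimal sketch of the generic-reduction form the vectorizer now emits for a small matmul, mirroring the updated CHECK lines in vectorization.mlir below (the SSA names, the constants, and the elision of the broadcasting permutation maps on the reads are illustrative, not part of the commit):

```mlir
// Sketch: 8x16 * 16x32 matmul after vectorization. Both inputs are broadcast
// to the common 8x32x16 iteration-space shape (permutation maps elided),
// multiplied elementwise, then the k dimension (size 16) is reduced.
%lhs = vector.transfer_read %A[%c0, %c0], %f0 : memref<8x16xf32>, vector<8x32x16xf32>
%rhs = vector.transfer_read %B[%c0, %c0], %f0 : memref<16x32xf32>, vector<8x32x16xf32>
%acc = vector.transfer_read %C[%c0, %c0], %f0 : memref<8x32xf32>, vector<8x32xf32>
%mul = arith.mulf %lhs, %rhs : vector<8x32x16xf32>
%red = vector.multi_reduction #vector.kind<add>, %mul [2] : vector<8x32x16xf32> to vector<8x32xf32>
%res = arith.addf %red, %acc : vector<8x32xf32>
vector.transfer_write %res, %C[%c0, %c0] : vector<8x32xf32>, memref<8x32xf32>
```

The canonicalization patterns registered below are then expected to fold this back into the 2-D vector.contract that the removed special case used to build directly.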

File tree

5 files changed: +60 −121 lines changed

mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Lines changed: 3 additions & 0 deletions

@@ -189,6 +189,9 @@ struct LinalgStrategyVectorizePass
     vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
                                                           filter, options);
   }
+  vector::populateVectorTransferPermutationMapLoweringPatterns(
+      vectorizationPatterns);
+  vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
   vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                             linalg::LinalgCopyVTWForwardingPattern>(
       funcOp.getContext(), /*benefit=*/2);
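These two pattern populations are what preserve code quality once the special case is gone: the transfer-permutation-map patterns lower the broadcasting/transposing reads produced by the generic path, and populateVetorReductionToContractPatterns (the misspelling is the actual upstream identifier) folds the elementwise multiply plus vector.multi_reduction back into a vector.contract. A minimal sketch of the intended folding, with illustrative shapes and names:

```mlir
// Reduction form produced by the generic vectorization path:
%mul = arith.mulf %lhs, %rhs : vector<8x32x16xf32>
%red = vector.multi_reduction #vector.kind<add>, %mul [2] : vector<8x32x16xf32> to vector<8x32xf32>
// ...which the reduction-to-contract patterns should rewrite into roughly
// (indexing maps abbreviated; %lhs2d, %rhs2d, %zero are hypothetical):
//   %red = vector.contract {..., kind = #vector.kind<add>} %lhs2d, %rhs2d, %zero
//       : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
```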

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 6 additions & 71 deletions

@@ -45,9 +45,6 @@ using llvm::dbgs;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-// Forward declarations.
-static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
-                                          SmallVectorImpl<Value> &newResults);
 static FailureOr<Operation *>
 vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
 
@@ -495,10 +492,9 @@ static bool isElementwise(Operation *op) {
 /// the absence of good canonicalizations, the amount of work increases.
 /// This is not deemed a problem as we expect canonicalizations and foldings to
 /// aggressively clean up the useless work.
-LogicalResult vectorizeAsLinalgGeneric(
-    OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
-    bool broadcastToMaximalCommonShape = false,
-    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
+static LogicalResult
+vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
+                         SmallVectorImpl<Value> &newResults) {
   Block *block = linalgOp.getBlock();
 
   // 2. Values defined above the region can only be broadcast for now. Make them
@@ -530,8 +526,7 @@ LogicalResult vectorizeAsLinalgGeneric(
     if (linalgOp.getShape(opOperand).empty()) {
       readType = bbarg.getType();
     } else {
-      if (broadcastToMaximalCommonShape &&
-          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+      if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
         map = inverseAndBroadcastProjectedPermuation(
             linalgOp.getTiedIndexingMap(opOperand));
         readType = VectorType::get(commonVectorShape,
@@ -549,7 +544,7 @@ LogicalResult vectorizeAsLinalgGeneric(
     bvm.map(opOperand->get(), readValue);
   }
 
-  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+  SmallVector<CustomVectorizationHook> hooks;
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
@@ -587,61 +582,6 @@ LogicalResult vectorizeAsLinalgGeneric(
 /// This helper is needed atm because the truly generic implementation requires
 /// good vector.multi_reduce folding patterns that are currently NYI.
 // TODO: drop reliance on a specific pattern.
-static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
-                                          SmallVectorImpl<Value> &newResults) {
-  assert(isaContractionOpInterface(linalgOp) &&
-         "expected vectorizeContraction preconditions to be met");
-  Location loc = linalgOp.getLoc();
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LDBG(""
-       << "Rewrite linalg op as vector.contract: ";
-       linalgOp.dump());
-  // Special function that describes how to vectorize the multiplication op in a
-  // linalg contraction.
-  CustomVectorizationHook vectorizeContraction =
-      [&](Operation *op,
-          const BlockAndValueMapping &bvm) -> VectorizationResult {
-    if (!isa<arith::MulIOp, arith::MulFOp>(op))
-      return VectorizationResult{VectorizationStatus::Failure, nullptr};
-    ArrayRef<int64_t> outShape =
-        linalgOp.getShape(linalgOp.getOutputOperand(0));
-    Type vType;
-    if (outShape.empty()) {
-      vType = op->getResult(0).getType();
-    } else {
-      SmallVector<int64_t> resultShape = applyPermutationMap(
-          inversePermutation(reindexIndexingMap(
-              linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))),
-          outShape);
-      vType = VectorType::get(resultShape, op->getResult(0).getType());
-    }
-    auto zero = b.create<arith::ConstantOp>(loc, vType, b.getZeroAttr(vType));
-    // Indexing maps at the time of vector.transfer_read are adjusted to order
-    // vector dimensions in the same order as the canonical linalg op iteration
-    // space order.
-    // The indexings for the contraction therefore need to be adjusted.
-    // TODO: consider dropping contraction special casing altogether, this will
-    // require more advanced canonicalizations involving vector.multi_reduction
-    // that are not yet available.
-    SmallVector<AffineMap> indexingMaps;
-    indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
-    llvm::transform(linalgOp.getIndexingMaps(),
-                    std::back_inserter(indexingMaps),
-                    [](AffineMap indexingMap) {
-                      return inversePermutation(reindexIndexingMap(indexingMap))
-                          .compose(indexingMap);
-                    });
-    Operation *contract = b.create<vector::ContractionOp>(
-        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
-        b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
-    return VectorizationResult{VectorizationStatus::NewOp, contract};
-  };
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
-                                  /*broadcastToMaximalCommonShape=*/false,
-                                  {vectorizeContraction});
-}
-
 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
   return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
     return m.isProjectedPermutation(/*allowZerosInResults=*/true);
@@ -674,8 +614,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   }
   if (isElementwise(op))
     return success();
-  if (isaContractionOpInterface(linalgOp))
-    return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
@@ -702,8 +640,6 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
     return failure();
 
   auto linalgOp = cast<LinalgOp>(op);
-  if (isaContractionOpInterface(linalgOp))
-    return vectorizeContraction(b, linalgOp, newResults);
 
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
@@ -721,8 +657,7 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
            << "Vectorize linalg op as a generic by broadcasting to "
               "maximal common shape: "
            << *op);
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
-                                  /*broadcastToMaximalCommonShape=*/true);
+  return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
 }
 
 //----------------------------------------------------------------------------//
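With the broadcastToMaximalCommonShape flag gone, every input read in vectorizeAsLinalgGeneric is broadcast to the maximal common vector shape via inverseAndBroadcastProjectedPermuation. As a hedged sketch, assuming a matmul iteration space (d0, d1, d2) = (8, 32, 16) and the LHS indexing map (d0, d1, d2) -> (d0, d2), the read broadcasts along the missing d1 dimension:

```mlir
// Hypothetical broadcasting read for the matmul LHS: a zero result in the
// permutation map broadcasts the missing d1 dimension, so the 2-D memref is
// read directly into the 3-D common shape.
%lhs = vector.transfer_read %A[%c0, %c0], %f0
    {permutation_map = affine_map<(d0, d1) -> (d0, 0, d1)>}
    : memref<8x16xf32>, vector<8x32x16xf32>
```

Such broadcast maps still need lowering before further conversion, which is why callers now also register populateVectorTransferPermutationMapLoweringPatterns (see LinalgStrategyPasses.cpp above).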

mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 //
 // CHECK-1D: vector.contract
 // CHECK-1D-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-1D-SAME: : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
+// CHECK-1D-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32, #{{.*}}>, vector<8x12xf32>
 // CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32, #{{.*}}>

@@ -41,6 +41,6 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 //
 // CHECK-2D: vector.contract
 // CHECK-2D-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-2D-SAME: : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
+// CHECK-2D-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
 // CHECK-2D: linalg.copy
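The B operand flips from vector<12x16xf32> to vector<16x12xf32> because the contract recovered by the reduction-to-contract folding keeps B indexed as (k, n), whereas the removed special case re-indexed it into a transposed (n, k) read. A sketch of the recovered contraction with the tile sizes this test uses (SSA names and the exact map spelling are illustrative; the test only checks the iterator types and operand types):

```mlir
%r = vector.contract {
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
    %a, %b, %c : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
```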

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 41 additions & 48 deletions

@@ -4,8 +4,10 @@
 
 // CHECK-LABEL: contraction_dot
 func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
-// CHECK: vector.contract
-// CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32
+
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
+// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
   linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
             outs(%C: memref<f32>)
   return
@@ -15,8 +17,10 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
 
 // CHECK-LABEL: contraction_matvec
 func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
-// CHECK: vector.contract
-// CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
   linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
                outs(%C: memref<1584xf32>)
   return
@@ -26,8 +30,9 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
 
 // CHECK-LABEL: contraction_matmul
 func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
-// CHECK: vector.contract
-// CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
   linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
               outs(%C: memref<1584x1584xf32>)
   return
@@ -37,8 +42,9 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
 
 // CHECK-LABEL: contraction_batch_matmul
 func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
-// CHECK: vector.contract
-// CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
   linalg.batch_matmul
     ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
     outs(%C: memref<1584x1584x1584xf32>)
@@ -58,19 +64,15 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @vectorization_test
 func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
 // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
-// CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
 // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -96,19 +98,15 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @generic_output_transpose
 func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                                %C: memref<32x8xf32>) {
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
 // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
-// CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
 // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
   linalg.generic #matmul_transpose_out_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -134,19 +132,16 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
 // CHECK-LABEL: func @vectorization_test_integer
 func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
                                  %C: memref<8x32xi32>) {
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
 // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]],
-// CHECK-SAME: vector<8x16xi32>, vector<32x16xi32> into vector<8x32xi32>
+// CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
+// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
+// CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
+
 // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -164,8 +159,9 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
 // CHECK-LABEL: func @vectorization_test_2
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                            %C: memref<8x32xf32>) {
-// CHECK: vector.contract {{.*}} :
-//   vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
   linalg.matmul
     ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
     outs(%C: memref<8x32xf32>)
@@ -520,19 +516,16 @@ func @matmul_tensors(
   %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32> {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VEC_C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x12xf32>
-// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32>
+// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
+// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
 // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
 //
-// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
-// a later canonicalization fuses the add into vector.contract.
-// CHECK: %[[C:.*]] = vector.contract
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
-// CHECK-SAME: %[[V0]], %[[V1]], %[[VEC_C0]] :
-// CHECK-SAME: vector<8x4xf32>, vector<12x4xf32> into vector<8x12xf32>
-// CHECK: %[[C2:.*]] = arith.addf %[[V2]], %[[C]] : vector<8x12xf32>
-// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+// The linalg matmul gets expanded to a 3-D reduction; a later canonicalization
+// converts it to a 2-D contract.
+// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
+// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+// CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
+// CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
                      outs(%arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32>

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 8 additions & 0 deletions

@@ -531,6 +531,14 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                           stage1Patterns);
   }
+  {
+    // Canonicalization patterns
+    RewritePatternSet canonicalizationPatterns(funcOp.getContext());
+    vector::populateVectorTransferPermutationMapLoweringPatterns(
+        canonicalizationPatterns);
+    vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
+    stage1Patterns.push_back(std::move(canonicalizationPatterns));
+  }
   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
   FrozenRewritePatternSet stage2Patterns =