Skip to content

Commit a8df4b8

Browse files
committed
[mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect
This patch deletes `vector.matrix_multiply` and `vector.flat_transpose`, which are thin wrappers around the corresponding LLVM intrinsics: - `llvm.intr.matrix.multiply` - `llvm.intr.matrix.transpose` These Vector dialect ops did not provide additional semantics or abstraction beyond the LLVM intrinsics. Their removal simplifies the lowering pipeline without losing any functionality. The lowering chains: - `vector.contract` → `vector.matrix_multiply` → `llvm.intr.matrix.multiply` - `vector.transpose` → `vector.flat_transpose` → `llvm.intr.matrix.transpose` are now replaced with: - `vector.contract` → `llvm.intr.matrix.multiply` - `vector.transpose` → `llvm.intr.matrix.transpose` This was accomplished by directly replacing: - `vector::MatrixMultiplyOp` with `LLVM::MatrixMultiplyOp` - `vector::FlatTransposeOp` with `LLVM::MatrixTransposeOp` Note: This change introduces a build-time dependency from `Vector` to `LLVM`. Ideally, such dependencies should be confined to dialect conversion (`ConvertVectorToLLVM`). However, moving the lowering code there would introduce notable churn, so this patch leaves the new dependency in place for now.
1 parent 577199f commit a8df4b8

File tree

17 files changed

+28
-314
lines changed

17 files changed

+28
-314
lines changed

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
namespace mlir {
1414
class LLVMTypeConverter;
1515

16-
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
17-
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
18-
/// will be needed when invoking LLVM.
19-
void populateVectorToLLVMMatrixConversionPatterns(
20-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
21-
2216
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
2317
void populateVectorToLLVMConversionPatterns(
2418
const LLVMTypeConverter &converter, RewritePatternSet &patterns,

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2710,124 +2710,6 @@ def Vector_PrintOp :
27102710
}];
27112711
}
27122712

2713-
//===----------------------------------------------------------------------===//
2714-
// Ops used for supporting progressive lowering and conversion type changes.
2715-
// The Ops are typically not used directly by higher level dialects, but are
2716-
// used by intra-dialect rewriting rules to bring vector operations closer
2717-
// to the hardware ISA.
2718-
//===----------------------------------------------------------------------===//
2719-
2720-
/// Vector dialect matrix multiplication op that operates on flattened 1-D
2721-
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
2722-
/// This may seem redundant with vector.contract but it serves the purposes of
2723-
/// more progressive lowering and localized type conversion on the path:
2724-
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
2725-
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
2726-
PredOpTrait<"lhs operand and result have same element type",
2727-
TCresVTEtIsSameAsOpBase<0, 0>>,
2728-
PredOpTrait<"rhs operand and result have same element type",
2729-
TCresVTEtIsSameAsOpBase<0, 1>>]>,
2730-
Arguments<(
2731-
// TODO: tighten vector element types that make sense.
2732-
ins FixedVectorOfRankAndType<[1],
2733-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
2734-
FixedVectorOfRankAndType<[1],
2735-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
2736-
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
2737-
Results<(
2738-
outs FixedVectorOfRankAndType<[1],
2739-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
2740-
{
2741-
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
2742-
" MLIR vectors";
2743-
let description = [{
2744-
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
2745-
purposes of more progressive lowering and localized type conversion.
2746-
Higher levels typically lower matrix multiplications into 'vector.contract'
2747-
operations. Subsequent rewriting rule progressively lower these operations
2748-
into 'vector.matrix_multiply' operations to bring the operations closer
2749-
to the hardware ISA.
2750-
2751-
The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
2752-
and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
2753-
<rhs_columns> and multiplies them. The result matrix is returned embedded in
2754-
the result vector.
2755-
2756-
Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
2757-
support scalable vectors. Hence, this Op is only available for fixed-width
2758-
vectors. Also see:
2759-
2760-
http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
2761-
2762-
Example:
2763-
2764-
```mlir
2765-
%C = vector.matrix_multiply %A, %B
2766-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
2767-
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
2768-
```
2769-
}];
2770-
let builders = [
2771-
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
2772-
"unsigned":$lhsColumns, "unsigned":$rhsColumns),
2773-
[{
2774-
$_state.addOperands({lhs, rhs});
2775-
$_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
2776-
$_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
2777-
$_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
2778-
$_state.addTypes(VectorType::get(lhsRows * rhsColumns,
2779-
::llvm::cast<VectorType>(lhs.getType()).getElementType()));
2780-
}]>,
2781-
];
2782-
let assemblyFormat = "$lhs `,` $rhs attr-dict "
2783-
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
2784-
}
2785-
2786-
/// Vector dialect matrix transposition op that operates on flattened 1-D
2787-
/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
2788-
/// This may seem redundant with vector.transpose but it serves the purposes of
2789-
/// more progressive lowering and localized type conversion on the path:
2790-
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
2791-
def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
2792-
PredOpTrait<"source operand and result have same element type",
2793-
TCresVTEtIsSameAsOpBase<0, 0>>]>,
2794-
Arguments<(
2795-
// TODO: tighten vector element types that make sense.
2796-
ins FixedVectorOfRankAndType<[1],
2797-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
2798-
I32Attr:$rows, I32Attr:$columns)>,
2799-
Results<(
2800-
outs FixedVectorOfRankAndType<[1],
2801-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
2802-
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
2803-
let description = [{
2804-
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
2805-
the purposes of more progressive lowering and localized type conversion.
2806-
Higher levels typically lower matrix transpositions into 'vector.transpose'
2807-
operations. Subsequent rewriting rule progressively lower these operations
2808-
into 'vector.flat_transpose' operations to bring the operations closer
2809-
to the hardware ISA.
2810-
2811-
The `vector.flat_transpose` op treats the 1-D input `matrix` as
2812-
a 2-D matrix with <rows> rows and <columns> columns, and returns the
2813-
transposed matrix in flattened form in 'res'.
2814-
2815-
Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
2816-
support scalable vectors. Hence, this Op is only available for fixed-width
2817-
vectors. Also see:
2818-
2819-
http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
2820-
2821-
Example:
2822-
2823-
```mlir
2824-
%1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
2825-
: vector<16xf32> -> vector<16xf32>
2826-
```
2827-
}];
2828-
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
2829-
}
2830-
28312713
//===----------------------------------------------------------------------===//
28322714
// SplatOp
28332715
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -184,41 +184,6 @@ class VectorBitCastOpConversion
184184
}
185185
};
186186

187-
/// Conversion pattern for a vector.matrix_multiply.
188-
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
189-
class VectorMatmulOpConversion
190-
: public ConvertOpToLLVMPattern<vector::MatmulOp> {
191-
public:
192-
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
193-
194-
LogicalResult
195-
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
196-
ConversionPatternRewriter &rewriter) const override {
197-
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
198-
matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
199-
adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
200-
matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
201-
return success();
202-
}
203-
};
204-
205-
/// Conversion pattern for a vector.flat_transpose.
206-
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
207-
class VectorFlatTransposeOpConversion
208-
: public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
209-
public:
210-
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
211-
212-
LogicalResult
213-
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
214-
ConversionPatternRewriter &rewriter) const override {
215-
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
216-
transOp, typeConverter->convertType(transOp.getRes().getType()),
217-
adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
218-
return success();
219-
}
220-
};
221-
222187
/// Overloaded utility that replaces a vector.load, vector.store,
223188
/// vector.maskedload and vector.maskedstore with their respective LLVM
224189
/// couterparts.
@@ -2026,12 +1991,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
20261991
VectorScalableStepOpLowering>(converter);
20271992
}
20281993

2029-
void mlir::populateVectorToLLVMMatrixConversionPatterns(
2030-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
2031-
patterns.add<VectorMatmulOpConversion>(converter);
2032-
patterns.add<VectorFlatTransposeOpConversion>(converter);
2033-
}
2034-
20351994
namespace {
20361995
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
20371996
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9797
LLVMTypeConverter converter(&getContext(), options);
9898
RewritePatternSet patterns(&getContext());
9999
populateVectorTransferLoweringPatterns(patterns);
100-
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
101100
populateVectorToLLVMConversionPatterns(
102101
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
103102
useVectorAlignment);
104-
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
105103

106104
// Architecture specific augmentations.
107105
LLVMConversionTarget target(getContext());

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
118118
return converter.isLegal(op);
119119
});
120120
// Manually mark arithmetic-performing vector instructions.
121-
target.addDynamicallyLegalOp<
122-
vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
123-
vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
121+
target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
122+
vector::MultiDimReductionOp, vector::FMAOp,
123+
vector::OuterProductOp, vector::ScanOp>(
124124
[&](Operation *op) { return converter.isLegal(op); });
125125
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126126
arith::ConstantOp, vector::SplatOp>();

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
4949
MLIRTensorDialect
5050
MLIRTransforms
5151
MLIRVectorDialect
52+
MLIRLLVMDialect
5253
MLIRVectorInterfaces
5354
MLIRVectorUtils
5455
)

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Arith/Utils/Utils.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1718
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1819
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1920
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -1280,12 +1281,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
12801281
/// %mtb = maybe_transpose
12811282
/// %flattened_a = vector.shape_cast %mta
12821283
/// %flattened_b = vector.shape_cast %mtb
1283-
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
1284+
/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
12841285
/// %mtd = vector.shape_cast %flattened_d
12851286
/// %d = maybe_untranspose %mtd
12861287
/// %e = add %c, %d
12871288
/// ```
1288-
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
12891289
//
12901290
/// This only kicks in when vectorContractLowering is set to `Matmul`.
12911291
/// vector.transpose operations are inserted if the vector.contract op is not a
@@ -1362,8 +1362,12 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13621362
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
13631363
rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
13641364

1365-
Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1366-
rhsColumns);
1365+
Value mul = rew.create<LLVM::MatrixMultiplyOp>(
1366+
loc,
1367+
VectorType::get(lhsRows * rhsColumns,
1368+
cast<VectorType>(lhs.getType()).getElementType()),
1369+
lhs, rhs, lhsRows, lhsColumns, rhsColumns);
1370+
13671371
mul = rew.create<vector::ShapeCastOp>(
13681372
loc,
13691373
VectorType::get({lhsRows, rhsColumns},

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1617
#include "mlir/Dialect/UB/IR/UBOps.h"
1718
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -338,7 +339,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
338339
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
339340
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
340341
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
341-
Value trans = rewriter.create<vector::FlatTransposeOp>(
342+
Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
342343
loc, flattenedType, matrix, rows, columns);
343344
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
344345
return success();

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
14241424

14251425
return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
14261426
}
1427-
// -----
1428-
1429-
//===----------------------------------------------------------------------===//
1430-
// vector.matrix_multiply
1431-
//===----------------------------------------------------------------------===//
1432-
1433-
// 4x16 16x3 4x3
1434-
func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
1435-
%C = vector.matrix_multiply %A, %B
1436-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
1437-
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
1438-
return %C: vector<12xf64>
1439-
}
1440-
// CHECK-LABEL: @matrix_ops
1441-
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
1442-
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
1443-
// CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
1444-
1445-
// -----
1446-
1447-
func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
1448-
%C = vector.matrix_multiply %A, %B
1449-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
1450-
(vector<64xindex>, vector<48xindex>) -> vector<12xindex>
1451-
return %C: vector<12xindex>
1452-
}
1453-
// CHECK-LABEL: @matrix_ops_index
1454-
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
1455-
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
1456-
// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>
14571427

14581428
// -----
14591429

@@ -1602,56 +1572,6 @@ func.func @create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> {
16021572

16031573
// -----
16041574

1605-
//===----------------------------------------------------------------------===//
1606-
// vector.flat_transpose
1607-
//===----------------------------------------------------------------------===//
1608-
1609-
func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
1610-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1611-
: vector<16xf32> -> vector<16xf32>
1612-
return %0 : vector<16xf32>
1613-
}
1614-
1615-
// CHECK-LABEL: func @flat_transpose
1616-
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
1617-
// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
1618-
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
1619-
// CHECK-SAME: vector<16xf32> into vector<16xf32>
1620-
// CHECK: return %[[T]] : vector<16xf32>
1621-
1622-
// -----
1623-
1624-
func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
1625-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1626-
: vector<16xindex> -> vector<16xindex>
1627-
return %0 : vector<16xindex>
1628-
}
1629-
// CHECK-LABEL: func @flat_transpose_index
1630-
// CHECK-SAME: %[[A:.*]]: vector<16xindex>
1631-
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64>
1632-
// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
1633-
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
1634-
// CHECK-SAME: vector<16xi64> into vector<16xi64>
1635-
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex>
1636-
// CHECK: return %[[T2]] : vector<16xindex>
1637-
1638-
// -----
1639-
1640-
func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
1641-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1642-
: vector<16xf32> -> vector<16xf32>
1643-
return %0 : vector<16xf32>
1644-
}
1645-
1646-
// CHECK-LABEL: func @flat_transpose
1647-
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
1648-
// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
1649-
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
1650-
// CHECK-SAME: vector<16xf32> into vector<16xf32>
1651-
// CHECK: return %[[T]] : vector<16xf32>
1652-
1653-
// -----
1654-
16551575
//===----------------------------------------------------------------------===//
16561576
// vector.gather
16571577
//

0 commit comments

Comments
 (0)