Skip to content

[mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect #144307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@
namespace mlir {
class LLVMTypeConverter;

/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
void populateVectorToLLVMMatrixConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
Expand Down
118 changes: 0 additions & 118 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2788,124 +2788,6 @@ def Vector_PrintOp :
}];
}

//===----------------------------------------------------------------------===//
// Ops used for supporting progressive lowering and conversion type changes.
// The Ops are typically not used directly by higher level dialects, but are
// used by intra-dialect rewriting rules to bring vector operations closer
// to the hardware ISA.
//===----------------------------------------------------------------------===//

/// Vector dialect matrix multiplication op that operates on flattened 1-D
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
/// This may seem redundant with vector.contract but it serves the purposes of
/// more progressive lowering and localized type conversion on the path:
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
PredOpTrait<"lhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"rhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
Arguments<(
// TODO: tighten vector element types that make sense.
ins FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
Results<(
outs FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
{
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
" MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
purposes of more progressive lowering and localized type conversion.
Higher levels typically lower matrix multiplications into 'vector.contract'
operations. Subsequent rewriting rule progressively lower these operations
into 'vector.matrix_multiply' operations to bring the operations closer
to the hardware ISA.

The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
<rhs_columns> and multiplies them. The result matrix is returned embedded in
the result vector.

Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
support scalable vectors. Hence, this Op is only available for fixed-width
vectors. Also see:

http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic

Example:

```mlir
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
```
}];
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
"unsigned":$lhsColumns, "unsigned":$rhsColumns),
[{
$_state.addOperands({lhs, rhs});
$_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
$_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
$_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
$_state.addTypes(VectorType::get(lhsRows * rhsColumns,
::llvm::cast<VectorType>(lhs.getType()).getElementType()));
}]>,
];
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
}

/// Vector dialect matrix transposition op that operates on flattened 1-D
/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
/// This may seem redundant with vector.transpose but it serves the purposes of
/// more progressive lowering and localized type conversion on the path:
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(
// TODO: tighten vector element types that make sense.
ins FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
I32Attr:$rows, I32Attr:$columns)>,
Results<(
outs FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
the purposes of more progressive lowering and localized type conversion.
Higher levels typically lower matrix transpositions into 'vector.transpose'
operations. Subsequent rewriting rule progressively lower these operations
into 'vector.flat_transpose' operations to bring the operations closer
to the hardware ISA.

The `vector.flat_transpose` op treats the 1-D input `matrix` as
a 2-D matrix with <rows> rows and <columns> columns, and returns the
transposed matrix in flattened form in 'res'.

Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
support scalable vectors. Hence, this Op is only available for fixed-width
vectors. Also see:

http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic

Example:

```mlir
%1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
: vector<16xf32> -> vector<16xf32>
```
}];
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);

/// TODO
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns);

/// TODO
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
Loading
Loading