Skip to content

Commit 1ad2ef4

Browse files
committed
fixup! [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect
Revert the dependency of the Vector transforms on the LLVM Dialect.
1 parent bd5e9ad commit 1ad2ef4

File tree

9 files changed

+287
-258
lines changed

9 files changed

+287
-258
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
297297
/// n > 1.
298298
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
299299

300+
/// Adds a pattern lowering `vector.contract` to `llvm.intr.matrix.multiply`.
301+
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns);
302+
303+
/// Adds a pattern lowering 2-D `vector.transpose` to an LLVM matrix transpose.
304+
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns);
305+
300306
} // namespace vector
301307
} // namespace mlir
302308
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,13 +2000,217 @@ struct VectorScalableStepOpLowering
20002000
}
20012001
};
20022002

2003+
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
2004+
/// semantics to:
2005+
/// ```
2006+
/// %flattened_a = vector.shape_cast %a
2007+
/// %flattened_b = vector.shape_cast %b
2008+
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
2009+
/// %d = vector.shape_cast %%flattened_d
2010+
/// %e = add %c, %d
2011+
/// ```
2012+
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
2013+
//
2014+
/// This only kicks in when vectorContractLowering is set to Matmul and
2015+
/// the vector.contract op is a row-major matrix multiply.
2016+
class ContractionOpToMatmulOpLowering
2017+
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
2018+
public:
2019+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
2020+
2021+
ContractionOpToMatmulOpLowering(
2022+
vector::VectorContractLowering vectorContractLowering,
2023+
MLIRContext *context, PatternBenefit benefit = 100)
2024+
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2025+
2026+
FailureOr<Value>
2027+
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
2028+
PatternRewriter &rewriter) const override;
2029+
};
2030+
2031+
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
2032+
/// semantics to:
2033+
/// ```
2034+
/// %mta = maybe_transpose
2035+
/// %mtb = maybe_transpose
2036+
/// %flattened_a = vector.shape_cast %mta
2037+
/// %flattened_b = vector.shape_cast %mtb
2038+
/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
2039+
/// %mtd = vector.shape_cast %flattened_d
2040+
/// %d = maybe_untranspose %mtd
2041+
/// %e = add %c, %d
2042+
/// ```
2043+
//
2044+
/// This only kicks in when vectorContractLowering is set to `Matmul`.
2045+
/// vector.transpose operations are inserted if the vector.contract op is not a
2046+
/// row-major matrix multiply.
2047+
///
2048+
/// Scalable vectors are not supported.
2049+
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2050+
vector::ContractionOp op, MaskingOpInterface maskOp,
2051+
PatternRewriter &rew) const {
2052+
// TODO: Support vector.mask.
2053+
if (maskOp)
2054+
return failure();
2055+
2056+
auto iteratorTypes = op.getIteratorTypes().getValue();
2057+
if (!isParallelIterator(iteratorTypes[0]) ||
2058+
!isParallelIterator(iteratorTypes[1]) ||
2059+
!isReductionIterator(iteratorTypes[2]))
2060+
return failure();
2061+
2062+
Type opResType = op.getType();
2063+
VectorType vecType = dyn_cast<VectorType>(opResType);
2064+
if (vecType && vecType.isScalable()) {
2065+
// Note - this is sufficient to reject all cases with scalable vectors.
2066+
return failure();
2067+
}
2068+
2069+
Type elementType = op.getLhsType().getElementType();
2070+
if (!elementType.isIntOrFloat())
2071+
return failure();
2072+
2073+
Type dstElementType = vecType ? vecType.getElementType() : opResType;
2074+
if (elementType != dstElementType)
2075+
return failure();
2076+
2077+
// Perform lhs + rhs transpositions to conform to matmul row-major semantics.
2078+
// Bail out if the contraction cannot be put in this form.
2079+
MLIRContext *ctx = op.getContext();
2080+
Location loc = op.getLoc();
2081+
AffineExpr m, n, k;
2082+
bindDims(rew.getContext(), m, n, k);
2083+
// LHS must be A(m, k) or A(k, m).
2084+
Value lhs = op.getLhs();
2085+
auto lhsMap = op.getIndexingMapsArray()[0];
2086+
if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
2087+
lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
2088+
else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
2089+
return failure();
2090+
2091+
// RHS must be B(k, n) or B(n, k).
2092+
Value rhs = op.getRhs();
2093+
auto rhsMap = op.getIndexingMapsArray()[1];
2094+
if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
2095+
rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
2096+
else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
2097+
return failure();
2098+
2099+
// At this point lhs and rhs are in row-major.
2100+
VectorType lhsType = cast<VectorType>(lhs.getType());
2101+
VectorType rhsType = cast<VectorType>(rhs.getType());
2102+
int64_t lhsRows = lhsType.getDimSize(0);
2103+
int64_t lhsColumns = lhsType.getDimSize(1);
2104+
int64_t rhsColumns = rhsType.getDimSize(1);
2105+
2106+
Type flattenedLHSType =
2107+
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2108+
lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
2109+
2110+
Type flattenedRHSType =
2111+
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2112+
rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
2113+
2114+
Value mul = rew.create<LLVM::MatrixMultiplyOp>(
2115+
loc,
2116+
VectorType::get(lhsRows * rhsColumns,
2117+
cast<VectorType>(lhs.getType()).getElementType()),
2118+
lhs, rhs, lhsRows, lhsColumns, rhsColumns);
2119+
2120+
mul = rew.create<vector::ShapeCastOp>(
2121+
loc,
2122+
VectorType::get({lhsRows, rhsColumns},
2123+
getElementTypeOrSelf(op.getAcc().getType())),
2124+
mul);
2125+
2126+
// ACC must be C(m, n) or C(n, m).
2127+
auto accMap = op.getIndexingMapsArray()[2];
2128+
if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
2129+
mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
2130+
else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
2131+
llvm_unreachable("invalid contraction semantics");
2132+
2133+
Value res =
2134+
isa<IntegerType>(elementType)
2135+
? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
2136+
: static_cast<Value>(
2137+
rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
2138+
2139+
return res;
2140+
}
2141+
2142+
/// Progressive lowering of TransposeOp.
2143+
/// One:
2144+
/// %x = vector.transpose %y, [1, 0]
2145+
/// is replaced by:
2146+
/// %z = arith.constant dense<0.000000e+00>
2147+
/// %0 = vector.extract %y[0, 0]
2148+
/// %1 = vector.insert %0, %z [0, 0]
2149+
/// ..
2150+
/// %x = vector.insert .., .. [.., ..]
2151+
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
2152+
public:
2153+
using OpRewritePattern<TransposeOp>::OpRewritePattern;
2154+
2155+
LogicalResult matchAndRewrite(vector::TransposeOp op,
2156+
PatternRewriter &rewriter) const override {
2157+
auto loc = op.getLoc();
2158+
2159+
Value input = op.getVector();
2160+
VectorType inputType = op.getSourceVectorType();
2161+
VectorType resType = op.getResultVectorType();
2162+
2163+
if (inputType.isScalable())
2164+
return rewriter.notifyMatchFailure(
2165+
op, "This lowering does not support scalable vectors");
2166+
2167+
// Set up convenience transposition table.
2168+
ArrayRef<int64_t> transp = op.getPermutation();
2169+
2170+
if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2171+
return failure();
2172+
}
2173+
2174+
Type flattenedType =
2175+
VectorType::get(resType.getNumElements(), resType.getElementType());
2176+
auto matrix =
2177+
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
2178+
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
2179+
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
2180+
Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
2181+
loc, flattenedType, matrix, rows, columns);
2182+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
2183+
return success();
2184+
}
2185+
};
2186+
20032187
} // namespace
20042188

20052189
void mlir::vector::populateVectorRankReducingFMAPattern(
20062190
RewritePatternSet &patterns) {
20072191
patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
20082192
}
20092193

2194+
/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
2195+
///
2196+
/// Given the high benefit, this will be prioriotised over other
2197+
/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
2198+
/// only run this registration conditionally.
2199+
void mlir::vector::populateVectorContractToMatrixMultiply(
2200+
RewritePatternSet &patterns) {
2201+
patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), 100);
2202+
}
2203+
2204+
/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
2205+
///
2206+
/// Given the high benefit, this will be prioriotised over other
2207+
/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
2208+
/// only run this registration conditionally.
2209+
void mlir::vector::populateVectorTransposeToFlatTranspose(
2210+
RewritePatternSet &patterns) {
2211+
patterns.add<TransposeOpLowering>(patterns.getContext(), 100);
2212+
}
2213+
20102214
/// Populate the given list with patterns that convert from Vector to LLVM.
20112215
void mlir::populateVectorToLLVMConversionPatterns(
20122216
const LLVMTypeConverter &converter, RewritePatternSet &patterns,

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,17 @@ void ConvertVectorToLLVMPass::runOnOperation() {
7070
populateVectorBitCastLoweringPatterns(patterns);
7171
populateVectorBroadcastLoweringPatterns(patterns);
7272
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
73+
if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
74+
populateVectorContractToMatrixMultiply(patterns);
75+
}
7376
populateVectorMaskOpLoweringPatterns(patterns);
7477
populateVectorShapeCastLoweringPatterns(patterns);
7578
populateVectorInterleaveLoweringPatterns(patterns);
7679
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
80+
if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
81+
populateVectorTransposeToFlatTranspose(patterns);
82+
}
83+
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
7784
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
7885
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
7986
populateVectorMaskMaterializationPatterns(patterns,

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
4949
MLIRTensorDialect
5050
MLIRTransforms
5151
MLIRVectorDialect
52-
MLIRLLVMDialect
5352
MLIRVectorInterfaces
5453
MLIRVectorUtils
5554
)

0 commit comments

Comments
 (0)