@@ -2000,13 +2000,217 @@ struct VectorScalableStepOpLowering
 }
 };

+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
+//
+/// This only kicks in when vectorContractLowering is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
+public:
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorContractLowering vectorContractLowering,
+      MLIRContext *context, PatternBenefit benefit = 100)
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
+};
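+
+// Example (an illustrative sketch; the shapes and element type are assumed,
+// not taken from this patch): a contraction in the row-major form this
+// pattern matches:
+//
+//   %d = vector.contract {
+//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+//          iterator_types = ["parallel", "parallel", "reduction"],
+//          kind = #vector.kind<add>}
+//        %a, %b, %c : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>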
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %mta = maybe_transpose
+///    %mtb = maybe_transpose
+///    %flattened_a = vector.shape_cast %mta
+///    %flattened_b = vector.shape_cast %mtb
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
+///    %mtd = vector.shape_cast %flattened_d
+///    %d = maybe_untranspose %mtd
+///    %e = add %c, %d
+/// ```
+//
+/// This only kicks in when vectorContractLowering is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rew) const {
+  // TODO: Support vector.mask.
+  if (maskOp)
+    return failure();
+
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isParallelIterator(iteratorTypes[1]) ||
+      !isReductionIterator(iteratorTypes[2]))
+    return failure();
+
+  Type opResType = op.getType();
+  VectorType vecType = dyn_cast<VectorType>(opResType);
+  if (vecType && vecType.isScalable()) {
+    // Note - this is sufficient to reject all cases with scalable vectors.
+    return failure();
+  }
+
+  Type elementType = op.getLhsType().getElementType();
+  if (!elementType.isIntOrFloat())
+    return failure();
+
+  Type dstElementType = vecType ? vecType.getElementType() : opResType;
+  if (elementType != dstElementType)
+    return failure();
+
+  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
+  // Bail out if the contraction cannot be put in this form.
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+  AffineExpr m, n, k;
+  bindDims(rew.getContext(), m, n, k);
+  // LHS must be A(m, k) or A(k, m).
+  Value lhs = op.getLhs();
+  auto lhsMap = op.getIndexingMapsArray()[0];
+  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
+    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+    return failure();
+
+  // RHS must be B(k, n) or B(n, k).
+  Value rhs = op.getRhs();
+  auto rhsMap = op.getIndexingMapsArray()[1];
+  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
+    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+    return failure();
+
+  // At this point lhs and rhs are in row-major.
+  VectorType lhsType = cast<VectorType>(lhs.getType());
+  VectorType rhsType = cast<VectorType>(rhs.getType());
+  int64_t lhsRows = lhsType.getDimSize(0);
+  int64_t lhsColumns = lhsType.getDimSize(1);
+  int64_t rhsColumns = rhsType.getDimSize(1);
+
+  Type flattenedLHSType =
+      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
+  Type flattenedRHSType =
+      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+  Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+      loc,
+      VectorType::get(lhsRows * rhsColumns,
+                      cast<VectorType>(lhs.getType()).getElementType()),
+      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
+  mul = rew.create<vector::ShapeCastOp>(
+      loc,
+      VectorType::get({lhsRows, rhsColumns},
+                      getElementTypeOrSelf(op.getAcc().getType())),
+      mul);
+
+  // ACC must be C(m, n) or C(n, m).
+  auto accMap = op.getIndexingMapsArray()[2];
+  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
+    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
+    llvm_unreachable("invalid contraction semantics");
+
+  Value res =
+      isa<IntegerType>(elementType)
+          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+          : static_cast<Value>(
+                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+
+  return res;
+}
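+
+// An illustrative trace of this rewrite (a sketch; the 2x4 * 4x3 f32 shapes
+// are assumptions): for the row-major example above, the pattern produces
+// roughly:
+//
+//   %0 = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
+//   %1 = vector.shape_cast %b : vector<4x3xf32> to vector<12xf32>
+//   %2 = llvm.intr.matrix.multiply %0, %1
+//          {lhs_rows = 2 : i32, lhs_columns = 4 : i32, rhs_columns = 3 : i32}
+//        : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+//   %3 = vector.shape_cast %2 : vector<6xf32> to vector<2x3xf32>
+//   %d = arith.addf %c, %3 : vector<2x3xf32>
+//
+// A column-major operand (e.g. an LHS indexed by (d0, d1, d2) -> (d2, d0))
+// additionally gets a leading vector.transpose, as described above.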
+
+/// Progressive lowering of 2-D TransposeOp.
+/// One:
+///   %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+///   %z = vector.shape_cast %y
+///   %t = llvm.intr.matrix.transpose %z
+///   %x = vector.shape_cast %t
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
+
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
+    // Only 2-D transposes with the permutation [1, 0] are handled here.
+    ArrayRef<int64_t> transp = op.getPermutation();
+    if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0)
+      return failure();
+
+    Type flattenedType =
+        VectorType::get(resType.getNumElements(), resType.getElementType());
+    auto matrix =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+    auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+    auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+    Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
+        loc, flattenedType, matrix, rows, columns);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+    return success();
+  }
+};
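+
+// An illustrative example (a sketch; the 2x3 shape is an assumption): a 2-D
+// transpose such as
+//
+//   %t = vector.transpose %v, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+//
+// is rewritten to roughly:
+//
+//   %0 = vector.shape_cast %v : vector<2x3xf32> to vector<6xf32>
+//   %1 = llvm.intr.matrix.transpose %0 {rows = 3 : i32, columns = 2 : i32}
+//        : vector<6xf32> into vector<6xf32>
+//   %t = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>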
+
 } // namespace

 void mlir::vector::populateVectorRankReducingFMAPattern(
     RewritePatternSet &patterns) {
   patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
 }

+/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
+///
+/// Given the high benefit, this will be prioritised over other
+/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void mlir::vector::populateVectorContractToMatrixMultiply(
+    RewritePatternSet &patterns) {
+  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), 100);
+}
+
+/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.transpose`.
+///
+/// Given the high benefit, this will be prioritised over other
+/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
+/// only run this registration conditionally.
+void mlir::vector::populateVectorTransposeToFlatTranspose(
+    RewritePatternSet &patterns) {
+  patterns.add<TransposeOpLowering>(patterns.getContext(), 100);
+}
+
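+// Usage sketch (illustrative; the option handling below is an assumption, not
+// part of this patch): a conversion pass would register these high-benefit
+// patterns conditionally, alongside the generic Vector-to-LLVM patterns:
+//
+//   RewritePatternSet patterns(&getContext());
+//   if (vectorContractLowering == vector::VectorContractLowering::Matmul)
+//     vector::populateVectorContractToMatrixMultiply(patterns);
+//   if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat)
+//     vector::populateVectorTransposeToFlatTranspose(patterns);
+//   populateVectorToLLVMConversionPatterns(converter, patterns);
+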
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,