Skip to content

Commit ea1e181

Browse files
[MLIR][AArch64] Lower vector.contract with mixed signed/unsigned arguments to Neon FEAT_I8MM (#144698)
1 parent 36819ea commit ea1e181

File tree

7 files changed

+225
-48
lines changed

7 files changed

+225
-48
lines changed

mlir/include/mlir/Dialect/ArmNeon/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace mlir {
1313
class RewritePatternSet;
1414

1515
namespace arm_neon {
16-
void populateLowerContractionToSMMLAPatternPatterns(
16+
void populateLowerContractionToNeonI8MMPatternPatterns(
1717
RewritePatternSet &patterns);
1818
} // namespace arm_neon
1919

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8585
populateVectorGatherLoweringPatterns(patterns);
8686
if (armI8MM) {
8787
if (armNeon)
88-
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
88+
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
8989
if (armSVE)
9090
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
9191
}

mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using namespace mlir;
2020

2121
void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
2222
RewritePatternSet &patterns) {
23-
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
23+
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
2424
}
2525

2626
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
add_mlir_dialect_library(MLIRArmNeonTransforms
2-
LowerContractionToSMMLAPattern.cpp
2+
LowerContractionToNeonI8MMPattern.cpp
33

44
DEPENDS
55
MLIRArmNeonIncGen

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp renamed to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp

Lines changed: 148 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
1+
//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements lowering patterns from vector.contract to
10-
// arm_neon.intr.smmla
9+
// This file implements lowering patterns from vector.contract to operations
10+
// that map to instructions from the Neon FEAT_I8MM extension.
1111
//
12-
//===---
12+
// TODO: There may be opportunities to unify this with a similar pattern
13+
// for SVE. See:
14+
// https://github.com/llvm/llvm-project/issues/145559
15+
// LowerContractionToSVEI8MMPattern.cpp
16+
//
17+
//===----------------------------------------------------------------------===//
1318

1419
#include "mlir/Dialect/Arith/IR/Arith.h"
1520
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -37,12 +42,87 @@ static Type matchContainerType(Type element, Type container) {
3742
return element;
3843
}
3944

45+
// Get the operand of a `vector.contract`. This function is intended to abstract
46+
// away from the particular way a value is extended before feeding it into the
47+
// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
48+
// (for implicit sign-extension see `vector.contract` documentation).
49+
//
50+
// The template parameter `Op` indicates the extension operation (explicit or
51+
// implicit) for which we are checking.
52+
//
53+
// Return success only for extensions from `iN` (N <= 8) to `i32`.
54+
template <typename Op>
55+
std::optional<Value> getExtOperand(Value v) {
56+
57+
static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
58+
"Must be instantiated with either sign- or zero- extension op");
59+
60+
// If the operand is not defined by an explicit extend operation of the
61+
// accepted operation type allow for an implicit sign-extension.
62+
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
63+
if (!extOp) {
64+
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
65+
auto eltTy = cast<VectorType>(v.getType()).getElementType();
66+
if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
67+
return {};
68+
return v;
69+
}
70+
return {};
71+
}
72+
73+
// If the operand is defined by an explicit extend operation of the accepted
74+
// operation type, check it's extended from `iN` (N <= 8) to `i32`.
75+
auto inOp = extOp.getIn();
76+
auto inTy = dyn_cast<VectorType>(inOp.getType());
77+
if (!inTy)
78+
return {};
79+
auto inEltTy = inTy.getElementType();
80+
if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
81+
return {};
82+
83+
auto outTy = dyn_cast<VectorType>(extOp.getType());
84+
if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
85+
return {};
86+
87+
return inOp;
88+
}
89+
90+
// Designate the operation (resp. instruction) used to do sub-tile matrix
91+
// multiplications.
92+
enum class MMLA {
93+
Signed, // smmla
94+
Unsigned, // ummla
95+
Mixed, // usmmla
96+
MixedSwapped // usmmla with LHS and RHS swapped
97+
};
98+
99+
// Create the matrix mulitply and accumulate operation according to `op`.
100+
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
101+
mlir::Type accType, Value acc, Value lhs, Value rhs) {
102+
switch (op) {
103+
case MMLA::Signed:
104+
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
105+
rhs);
106+
case MMLA::Unsigned:
107+
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
108+
rhs);
109+
case MMLA::Mixed:
110+
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
111+
rhs);
112+
case MMLA::MixedSwapped:
113+
// The accumulator comes transposed and the result will be transposed
114+
// later, so all we have to do here is swap the operands.
115+
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
116+
lhs);
117+
}
118+
}
119+
40120
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41121
/// any vector.contract into multiple smmla instructions with unrolling so long
42122
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
43123
/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
44124
/// necessary, a single smmla instruction is emitted.
45-
class LowerContractionToSMMLAPattern
125+
class LowerContractionToNeonI8MMPattern
46126
: public OpRewritePattern<vector::ContractionOp> {
47127
public:
48128
using OpRewritePattern::OpRewritePattern;
@@ -88,39 +168,64 @@ class LowerContractionToSMMLAPattern
88168
return failure();
89169
}
90170

91-
// Check two extsi inputs Rhs Lhs for contract.
92-
arith::ExtSIOp origLhsExtOp =
93-
dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
94-
arith::ExtSIOp origRhsExtOp =
95-
dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
96-
if (!origLhsExtOp || !origRhsExtOp) {
171+
// Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
172+
// values before the extension. All four signed/unsigned combinations for
173+
// input operands are supported, but they are lowered to different
174+
// operations. Determine which is the appropriate operation to lower to.
175+
MMLA mmlaOp = MMLA::Signed;
176+
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
177+
if (!maybeLhs) {
178+
mmlaOp = MMLA::Unsigned;
179+
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
180+
}
181+
if (!maybeLhs)
97182
return failure();
183+
184+
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
185+
if (maybeRhs) {
186+
if (mmlaOp == MMLA::Unsigned)
187+
mmlaOp = MMLA::Mixed;
188+
} else {
189+
if (mmlaOp == MMLA::Signed)
190+
mmlaOp = MMLA::MixedSwapped;
191+
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
98192
}
193+
if (!maybeRhs)
194+
return failure();
195+
196+
Value origLhs = *maybeLhs;
197+
Value origRhs = *maybeRhs;
99198

100199
// Match any iX to i32 for X<8 then turn into an i8 output. Feed into
101200
// following neon instruction. Check inputs for extsi are <=i8
102-
Value extsiLhs;
103-
Value extsiRhs;
104-
if (auto lhsExtInType =
105-
dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
201+
Value extLhs;
202+
Value extRhs;
203+
if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
106204
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
107205
Type targetLhsExtTy =
108206
matchContainerType(rewriter.getI8Type(), lhsExtInType);
109-
extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
110-
origLhsExtOp.getIn());
207+
if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
208+
extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
209+
origLhs);
210+
else
211+
extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
212+
origLhs);
111213
}
112214
}
113-
if (auto rhsExtInType =
114-
dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
215+
if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
115216
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
116217
Type targetRhsExtTy =
117218
matchContainerType(rewriter.getI8Type(), rhsExtInType);
118-
extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
119-
origRhsExtOp.getIn());
219+
if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
220+
extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
221+
origRhs);
222+
else
223+
extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
224+
origRhs);
120225
}
121226
}
122227

123-
if (!extsiLhs || !extsiRhs) {
228+
if (!extLhs || !extRhs) {
124229
return failure();
125230
}
126231

@@ -155,11 +260,11 @@ class LowerContractionToSMMLAPattern
155260
AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
156261
SmallVector<int64_t> lhsOffsets =
157262
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
158-
Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
263+
Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
159264
AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
160265
SmallVector<int64_t> rhsOffsets =
161266
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
162-
Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
267+
Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
163268
AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
164269
SmallVector<int64_t> accOffsets =
165270
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +296,13 @@ class LowerContractionToSMMLAPattern
191296
tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
192297
}
193298

299+
// Transpose ACC if doing signed by unsigned multiplication, because we're
300+
// using the instruction for unsigned by signed multiplication with
301+
// reversed operands.
302+
if (mmlaOp == MMLA::MixedSwapped)
303+
tiledAcc = rewriter.create<vector::TransposeOp>(
304+
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
305+
194306
// Collapse tiled operands to 1D vectors required by smmla intrinsic
195307
auto collapsedInputType =
196308
VectorType::get(inputExpandedType.getNumElements(), inputElementType);
@@ -211,15 +323,21 @@ class LowerContractionToSMMLAPattern
211323
}
212324

213325
// Insert contract op
214-
kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
215-
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
216-
collapsedRhs);
326+
kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
327+
collapsedRes, collapsedLhs, collapsedRhs);
217328

218329
// Reshape output back to 2D
219330
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
220331
kAcc.getLoc(), tiledAcc.getType(), kAcc);
221332

222-
// With vecmat, only one row of tiled ACC can be inserted into file result
333+
// Because of the reversed operands the result is obtained transposed.
334+
// Transpose it back,
335+
if (mmlaOp == MMLA::MixedSwapped)
336+
tiledRes = rewriter.create<vector::TransposeOp>(
337+
loc, tiledRes, ArrayRef<int64_t>({1, 0}));
338+
339+
// With vecmat, only one row of tiled ACC can be inserted into the final
340+
// result
223341
if (isVecmat) {
224342
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
225343
}
@@ -239,8 +357,8 @@ class LowerContractionToSMMLAPattern
239357

240358
} // namespace
241359

242-
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
360+
void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
243361
RewritePatternSet &patterns) {
244362
MLIRContext *context = patterns.getContext();
245-
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
363+
patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
246364
}

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
1+
//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,6 +9,11 @@
99
// This file implements lowering patterns from vector.contract to operations
1010
// that map to instructions from the SVE FEAT_I8MM extension.
1111
//
12+
// TODO: There may be opportunities to unify this with a similar pattern
13+
// for Neon. See:
14+
// https://github.com/llvm/llvm-project/issues/145559
15+
// LowerContractionToNeonI8MMPattern.cpp
16+
//
1217
//===----------------------------------------------------------------------===//
1318

1419
#include "mlir/Dialect/Arith/IR/Arith.h"

0 commit comments

Comments
 (0)