-//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA --*- C++ -*-===//
+//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements lowering patterns from vector.contract to
-// arm_neon.intr.smmla
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM extension.
 //
-//===---
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -37,12 +42,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }

+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero-extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check that it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+}
+
 /// Lowering from a vector::contractOp to an Arm Neon smmla intrinsic. This
 /// will tile any vector.contract into multiple smmla instructions with
 /// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
 /// vecmats with dimM = 1 (either explicitly or inferred if LHS has only dimK).
 /// If no unrolling is necessary, a single smmla instruction is emitted.
-class LowerContractionToSMMLAPattern
+class LowerContractionToNeonI8MMPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
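
To illustrate the two cases `getExtOperand` accepts, here is a minimal sketch; the value names and shapes are illustrative, not part of this patch. The explicit case returns the value from before the extension, the implicit case returns the operand itself:

    // Explicit extension: getExtOperand<arith::ExtSIOp>(%lhs) returns %a.
    %lhs = arith.extsi %a : vector<2x8xi8> to vector<2x8xi32>

    // Implicit extension: an iN (N <= 8) operand fed directly into
    // vector.contract is treated as sign-extended (see the vector.contract
    // documentation), so getExtOperand<arith::ExtSIOp>(%b) returns %b itself.
    %res = vector.contract {...} %b, %c, %acc
             : vector<2x8xi8>, vector<8x2xi8> into vector<2x2xi32>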
@@ -88,39 +168,64 @@ class LowerContractionToSMMLAPattern
       return failure();
     }

-    // Check two extsi inputs Rhs Lhs for contract.
-    arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
-    arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
-    if (!origLhsExtOp || !origRhsExtOp) {
+    // Check that the inputs are sign-/zero-extensions from iN (N <= 8) to i32
+    // and get the values from before the extension. All four signed/unsigned
+    // combinations of input operands are supported, but they are lowered to
+    // different operations. Determine the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
       return failure();
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }
+    if (!maybeRhs)
+      return failure();
+
+    Value origLhs = *maybeLhs;
+    Value origRhs = *maybeRhs;

     // Match any iX to i32 for X <= 8, then turn into an i8 output. Feed into
     // the following Neon instruction. Check that the extension inputs are <= i8.
-    Value extsiLhs;
-    Value extsiRhs;
-    if (auto lhsExtInType =
-            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
+    Value extLhs;
+    Value extRhs;
+    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetLhsExtTy =
             matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhsExtOp.getIn());
+        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::MixedSwapped)
+          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
+        else
+          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
       }
     }
-    if (auto rhsExtInType =
-            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
+    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetRhsExtTy =
             matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhsExtOp.getIn());
+        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::MixedSwapped)
+          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
+        else
+          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
       }
     }

-    if (!extsiLhs || !extsiRhs) {
+    if (!extLhs || !extRhs) {
       return failure();
     }
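
A worked example of the selection logic above (illustrative values, not from this patch): a sign-extended LHS combined with a zero-extended RHS.

    %lhs = arith.extsi %a : vector<2x8xi8> to vector<2x8xi32> // signed LHS
    %rhs = arith.extui %b : vector<8x2xi8> to vector<8x2xi32> // unsigned RHS

`getExtOperand<arith::ExtSIOp>` succeeds on the LHS, so `mmlaOp` stays `MMLA::Signed`; it fails on the RHS, which then matches as a zero-extension, so `mmlaOp` becomes `MMLA::MixedSwapped`. The swapped path exists because `usmmla` only implements unsigned-by-signed multiplication.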
@@ -155,11 +260,11 @@ class LowerContractionToSMMLAPattern
       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +296,13 @@ class LowerContractionToSMMLAPattern
         tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
       }

+      // Transpose ACC if doing signed by unsigned multiplication, because
+      // we're using the instruction for unsigned by signed multiplication with
+      // reversed operands.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledAcc = rewriter.create<vector::TransposeOp>(
+            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType =
           VectorType::get(inputExpandedType.getNumElements(), inputElementType);
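
The transposes around the `MMLA::MixedSwapped` case follow from a standard identity. Assuming the *mmla operations compute, per sub-tile, C + L * R^T for 2x8 operand tiles L and R (my notation, not from the patch), then for a signed LHS A_s and unsigned RHS B_u:

    C^T + B_u * A_s^T = (C + A_s * B_u^T)^T

Feeding `usmmla` the transposed accumulator together with the swapped operands therefore produces the desired signed-by-unsigned update, just transposed; the transpose of the result in the next hunk undoes it.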
@@ -211,15 +323,21 @@ class LowerContractionToSMMLAPattern
       }

       // Insert contract op
-      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
-          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-          collapsedRhs);
+      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
+                        collapsedRes, collapsedLhs, collapsedRhs);

       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           kAcc.getLoc(), tiledAcc.getType(), kAcc);

-      // With vecmat, only one row of tiled ACC can be inserted into file result
+      // Because of the reversed operands the result is obtained transposed.
+      // Transpose it back.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledRes = rewriter.create<vector::TransposeOp>(
+            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result
       if (isVecmat) {
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
       }
@@ -239,8 +357,8 @@ class LowerContractionToSMMLAPattern

 } // namespace

-void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
+  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }
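
For reference, a rough before/after sketch of what the renamed pattern emits for the smallest signed case (indexing maps and iterator types elided, SSA names illustrative; the op syntax is my best approximation, not output copied from a test):

    // Before: %lhs/%rhs are sign-extended from vector<2x8xi8> (%lhs8/%rhs8).
    %0 = vector.contract {...} %lhs, %rhs, %acc
           : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>

    // After: tiles are collapsed to 1-D vectors and a single smmla is emitted.
    %l = vector.shape_cast %lhs8 : vector<2x8xi8> to vector<16xi8>
    %r = vector.shape_cast %rhs8 : vector<2x8xi8> to vector<16xi8>
    %a = vector.shape_cast %acc : vector<2x2xi32> to vector<4xi32>
    %m = arm_neon.intr.smmla %a, %l, %r : vector<16xi8> to vector<4xi32>
    %1 = vector.shape_cast %m : vector<4xi32> to vector<2x2xi32>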