@@ -39,7 +39,7 @@ namespace {
39
39
//
40
40
// Return success only for extensions from `i8` to `i32`.
41
41
template <typename Op>
42
- std::optional<Value> getExtOperand (Value v, Type i8Ty, Type i32Ty ) {
42
+ std::optional<Value> getExtOperand (Value v) {
43
43
44
44
static_assert (llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
45
45
" Must be instantiated with either sign- or zero- extension op" );
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
50
50
if (!extOp) {
51
51
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
52
52
auto vTy = cast<VectorType>(v.getType ());
53
- if (vTy.getElementType () != i8Ty )
53
+ if (! vTy.getElementType (). isSignlessInteger ( 8 ) )
54
54
return {};
55
55
return v;
56
56
}
@@ -61,11 +61,11 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
61
61
// operation type, check it's extended from `i8` to `i32`.
62
62
auto inOp = extOp.getIn ();
63
63
auto inTy = dyn_cast<VectorType>(inOp.getType ());
64
- if (!inTy || inTy.getElementType () != i8Ty )
64
+ if (!inTy || ! inTy.getElementType (). isSignlessInteger ( 8 ) )
65
65
return {};
66
66
67
67
auto outTy = dyn_cast<VectorType>(extOp.getType ());
68
- if (!outTy || outTy.getElementType () != i32Ty )
68
+ if (!outTy || ! outTy.getElementType (). isSignlessInteger ( 32 ) )
69
69
return {};
70
70
71
71
return inOp;
@@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern
199
199
// operands are supported, but they are lowered to different operations.
200
200
// Determine which is the appropriate operation to lower to.
201
201
MMLA mmlaOp = MMLA::Signed;
202
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(
203
- op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
202
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs ());
204
203
if (!maybeLhs) {
205
204
mmlaOp = MMLA::Unsigned;
206
- maybeLhs = getExtOperand<arith::ExtUIOp>(
207
- op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
205
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs ());
208
206
}
209
207
if (!maybeLhs)
210
208
return rewriter.notifyMatchFailure (
211
209
op, " LHS is not a sign- or zero- extended i8" );
212
210
213
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(
214
- op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
211
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs ());
215
212
if (maybeRhs) {
216
213
if (mmlaOp == MMLA::Unsigned)
217
214
mmlaOp = MMLA::Mixed;
218
215
} else {
219
216
if (mmlaOp == MMLA::Signed)
220
217
mmlaOp = MMLA::MixedSwapped;
221
- maybeRhs = getExtOperand<arith::ExtUIOp>(
222
- op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
218
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs ());
223
219
}
224
220
if (!maybeRhs)
225
221
return rewriter.notifyMatchFailure (
0 commit comments