Skip to content

Commit b31413a

Browse files
[MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) (#144909)
Just recently learned about `isSignlessInteger`, use that instead of comparing to types obtained via `rewriter.getI<N>Type()`. It also makes it closer to a similar function in `LowerContractionToNeonI8MMPattern.cpp` (formerly `LowerContractionToSMMLAPattern.cpp`) which would help a potential effort to unify these patterns.
1 parent 4af96a9 commit b31413a

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace {
3939
//
4040
// Return success only for extensions from `i8` to `i32`.
4141
template <typename Op>
42-
std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
42+
std::optional<Value> getExtOperand(Value v) {
4343

4444
static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
4545
"Must be instantiated with either sign- or zero- extension op");
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
5050
if (!extOp) {
5151
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
5252
auto vTy = cast<VectorType>(v.getType());
53-
if (vTy.getElementType() != i8Ty)
53+
if (!vTy.getElementType().isSignlessInteger(8))
5454
return {};
5555
return v;
5656
}
@@ -61,11 +61,11 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
6161
// operation type, check it's extended from `i8` to `i32`.
6262
auto inOp = extOp.getIn();
6363
auto inTy = dyn_cast<VectorType>(inOp.getType());
64-
if (!inTy || inTy.getElementType() != i8Ty)
64+
if (!inTy || !inTy.getElementType().isSignlessInteger(8))
6565
return {};
6666

6767
auto outTy = dyn_cast<VectorType>(extOp.getType());
68-
if (!outTy || outTy.getElementType() != i32Ty)
68+
if (!outTy || !outTy.getElementType().isSignlessInteger(32))
6969
return {};
7070

7171
return inOp;
@@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern
199199
// operands are supported, but they are lowered to different operations.
200200
// Determine which is the appropriate operation to lower to.
201201
MMLA mmlaOp = MMLA::Signed;
202-
auto maybeLhs = getExtOperand<arith::ExtSIOp>(
203-
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
202+
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
204203
if (!maybeLhs) {
205204
mmlaOp = MMLA::Unsigned;
206-
maybeLhs = getExtOperand<arith::ExtUIOp>(
207-
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
205+
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
208206
}
209207
if (!maybeLhs)
210208
return rewriter.notifyMatchFailure(
211209
op, "LHS is not a sign- or zero- extended i8");
212210

213-
auto maybeRhs = getExtOperand<arith::ExtSIOp>(
214-
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
211+
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
215212
if (maybeRhs) {
216213
if (mmlaOp == MMLA::Unsigned)
217214
mmlaOp = MMLA::Mixed;
218215
} else {
219216
if (mmlaOp == MMLA::Signed)
220217
mmlaOp = MMLA::MixedSwapped;
221-
maybeRhs = getExtOperand<arith::ExtUIOp>(
222-
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
218+
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
223219
}
224220
if (!maybeRhs)
225221
return rewriter.notifyMatchFailure(

0 commit comments

Comments
 (0)