
Commit 3a85fa8

[ONNX] Add support for Onnx.FusedMatMul op (#4147)
This commit adds the Onnx->Torch lowering for the [Onnx.FusedMatMul](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul) op.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 3dcf188 commit 3a85fa8
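For reference, FusedMatMul computes Y = alpha * op(A) * op(B), where op() optionally transposes (transA/transB) or batch-transposes (transBatchA/transBatchB) its argument. Below is a minimal 2-D, row-major sketch of those semantics in plain C++; the function fusedMatMulRef and its signature are illustrative only, not taken from onnxruntime or this commit.

#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics for FusedMatMul on 2-D, row-major matrices:
//   Y = alpha * op(A) * op(B), where op() optionally transposes its input.
// Illustrative sketch only; not code from onnxruntime or this commit.
std::vector<float> fusedMatMulRef(const std::vector<float> &A, int64_t aRows,
                                  int64_t aCols, const std::vector<float> &B,
                                  int64_t bRows, int64_t bCols, float alpha,
                                  bool transA, bool transB) {
  int64_t m = transA ? aCols : aRows; // rows of op(A)
  int64_t k = transA ? aRows : aCols; // shared inner dimension
  int64_t n = transB ? bRows : bCols; // cols of op(B)
  assert(k == (transB ? bCols : bRows) && "inner dimensions must match");
  std::vector<float> Y(m * n, 0.0f);
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        float a = transA ? A[p * aCols + i] : A[i * aCols + p];
        float b = transB ? B[j * bCols + p] : B[p * bCols + j];
        acc += a * b;
      }
      Y[i * n + j] = alpha * acc;
    }
  }
  return Y;
}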

2 files changed (+70, -0 lines)

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 56 additions & 0 deletions
@@ -1009,4 +1009,60 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
             averagePool);
         return success();
       });
+  patterns.onOp(
+      "FusedMatMul", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value lhs, rhs;
+        int64_t transA, transB, transBatchA, transBatchB;
+        if (binder.tensorOperands(lhs, rhs) ||
+            binder.s64IntegerAttr(transA, "transA", 0) ||
+            binder.s64IntegerAttr(transB, "transB", 0) ||
+            binder.s64IntegerAttr(transBatchA, "transBatchA", 0) ||
+            binder.s64IntegerAttr(transBatchB, "transBatchB", 0) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // Transposing the LHS argument.
+        Value transposedLhs = lhs;
+        if (transA) {
+          // Determine the rank of lhs tensor.
+          std::optional<unsigned> maybeRank = Torch::getTensorRank(lhs);
+          if (!maybeRank)
+            return rewriter.notifyMatchFailure(
+                binder.op, "Unimplemented: unranked lhs tensor");
+          unsigned lhsRank = *maybeRank;
+          if (failed(createTorchTransposeOp(
+                  rewriter, binder.getLoc(), lhs,
+                  /*dimA=*/lhsRank - 2, /*dimB=*/lhsRank - 1, transposedLhs)))
+            return rewriter.notifyMatchFailure(
+                binder.op, "Failed to create TorchTranspose op for lhs");
+        }
+
+        // Transposing the RHS argument.
+        Value transposedRhs = rhs;
+        if (transB) {
+          std::optional<unsigned> maybeRank = Torch::getTensorRank(rhs);
+          if (!maybeRank)
+            return rewriter.notifyMatchFailure(
+                binder.op, "Unimplemented: unranked rhs tensor");
+          unsigned rhsRank = *maybeRank;
+          if (failed(createTorchTransposeOp(
+                  rewriter, binder.getLoc(), rhs,
+                  /*dimA=*/rhsRank - 2, /*dimB=*/rhsRank - 1, transposedRhs)))
+            return rewriter.notifyMatchFailure(
+                binder.op, "Failed to create TorchTranspose op for rhs");
+        }
+
+        // TODO: Add support for `transBatchA` and `transBatchB`
+        // attribute.
+        if (transBatchA || transBatchB)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unimplemented: support not present for "
+                         "transBatchA and transBatchB attribute");
+
+        rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(
+            binder.op, resultType, transposedLhs, transposedRhs);
+        return success();
+      });
 }
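Note that the pattern binds transA and transB but never reads FusedMatMul's optional alpha scaling attribute, so the matmul result is emitted unscaled (the test below passes alpha = 0.125 and checks only a transpose plus matmul). A minimal sketch of how the scale could be folded in, written in the same binder/rewriter style, follows; it is not part of this commit, and OpBinder::f32FloatAttr, Torch::ConstantFloatOp, and Torch::AtenMulScalarOp are assumed from the surrounding torch-mlir API.

// Sketch only: consume FusedMatMul's `alpha` attribute by scaling the
// matmul result. NOT part of commit 3a85fa8; OpBinder::f32FloatAttr,
// Torch::ConstantFloatOp, and Torch::AtenMulScalarOp are assumed from the
// surrounding torch-mlir API.
float alpha;
if (binder.f32FloatAttr(alpha, "alpha", 1.0f))
  return failure();
Value matmul = rewriter.create<Torch::AtenMatmulOp>(
    binder.getLoc(), resultType, transposedLhs, transposedRhs);
if (alpha == 1.0f) {
  // No scaling needed; replace the ONNX op with the bare matmul.
  rewriter.replaceOp(binder.op, matmul);
} else {
  // Materialize alpha as a !torch.float constant and scale the result.
  Value alphaConst = rewriter.create<Torch::ConstantFloatOp>(
      binder.getLoc(), rewriter.getF64FloatAttr(alpha));
  rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
      binder.op, resultType, matmul, alphaConst);
}
return success();

Folding the scale into a single mul-scalar op would keep the lowering a two- or three-op sequence even when alpha != 1.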

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 14 additions & 0 deletions
@@ -2921,3 +2921,17 @@ func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !t
   %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) {torch.onnx.inverse = 1 : si64} : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32>
   return %0 : !torch.vtensor<[10,10,2],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fusedMatmul(
+// CHECK-SAME:    %[[LHS:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,12,256,64],f32>,
+// CHECK-SAME:    %[[RHS:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,12,256,64],f32>) -> !torch.vtensor<[?,12,256,256],f32>
+func.func @test_fusedMatmul(%arg0: !torch.vtensor<[?,12,256,64],f32>, %arg1: !torch.vtensor<[?,12,256,64],f32>) -> !torch.vtensor<[?,12,256,256],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {com.microsoft = 1 : si64}} {
+  %0 = torch.operator "onnx.FusedMatMul"(%arg0, %arg1) {torch.onnx.alpha = 1.250000e-01 : f32, torch.onnx.transA = 0 : si64, torch.onnx.transB = 1 : si64} : (!torch.vtensor<[?,12,256,64],f32>, !torch.vtensor<[?,12,256,64],f32>) -> !torch.vtensor<[?,12,256,256],f32>
+  // CHECK: %[[DIMA:.*]] = torch.constant.int 2
+  // CHECK: %[[DIMB:.*]] = torch.constant.int 3
+  // CHECK: %[[TRANSPOSED_RHS:.*]] = torch.aten.transpose.int %[[RHS]], %[[DIMA]], %[[DIMB]] : !torch.vtensor<[?,12,256,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12,64,256],f32>
+  // CHECK: torch.aten.matmul %[[LHS]], %[[TRANSPOSED_RHS]] : !torch.vtensor<[?,12,256,64],f32>, !torch.vtensor<[?,12,64,256],f32> -> !torch.vtensor<[?,12,256,256],f32>
+  return %0 : !torch.vtensor<[?,12,256,256],f32>
+}
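As with the other cases in this file, the new test runs under lit and FileCheck: the file's RUN line (outside this hunk, so not shown) is expected to pipe the module through torch-mlir-opt's Torch-ONNX-to-Torch conversion and verify the CHECK lines. With transA = 0 and transB = 1 on rank-4 operands, only the RHS is transposed, on dims 2 and 3, before the matmul, which is exactly what the CHECK lines assert.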
