Commit 80a3dfd
[ONNX] Add support for Onnx.QLinearMul op (#4159)
This commit adds the lowering for the [Onnx.QLinearMul](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqlinearmul) op.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent f659a0b commit 80a3dfd
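
For context, Onnx.QLinearMul multiplies two quantized tensors, and this lowering realizes it with the standard dequantize-multiply-requantize decomposition. Below is a minimal reference sketch of the per-tensor semantics, assuming ui8 storage and f32 intermediates; the helper name qlinearMulRef and the exact rounding/saturation behavior are illustrative assumptions, not taken from the commit.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics: C = quantize(dequantize(A) * dequantize(B), cScale, cZp).
std::vector<uint8_t> qlinearMulRef(const std::vector<uint8_t> &a, float aScale,
                                   uint8_t aZp, const std::vector<uint8_t> &b,
                                   float bScale, uint8_t bZp, float cScale,
                                   uint8_t cZp) {
  std::vector<uint8_t> c(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    // Dequantize each element: (q - zero_point) * scale.
    float af = (static_cast<int>(a[i]) - aZp) * aScale;
    float bf = (static_cast<int>(b[i]) - bZp) * bScale;
    // Multiply in f32, then requantize and saturate to the ui8 range.
    // Rounding mode here is illustrative; ONNX Runtime may differ.
    int q = static_cast<int>(std::lround(af * bf / cScale)) + cZp;
    c[i] = static_cast<uint8_t>(std::clamp(q, 0, 255));
  }
  return c;
}

In the diffs below, these same three steps appear as torch.aten.dequantize.self on each input, torch.aten.mul.Tensor in f32, and torch.aten.quantize_per_tensor followed by torch.aten.int_repr on the result.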

File tree

4 files changed: +125 -0 lines changed


include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h

Lines changed: 5 additions & 0 deletions
@@ -125,6 +125,11 @@ LogicalResult extractPerTensorQuantizationArguments(
     ConversionPatternRewriter &rewriter, Location loc, Value inScale,
     Value inZeroPoint, Value &outScale, Value &outZeroPoint);
 
+/// This utility takes as input a quantized tensor and dequantizes it.
+LogicalResult createDequantizeTensor(ConversionPatternRewriter &rewriter,
+                                     Location loc, Value input, Value scale,
+                                     Value zeroPoint, Value &output);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 69 additions & 0 deletions
@@ -1077,4 +1077,73 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
           binder.op, resultType, transposedLhs, transposedRhs);
       return success();
     });
+  patterns.onOp(
+      "QLinearMul", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        if (binder.tensorOperandsList(operands) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        if (operands.size() != 8)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unimplemented: expected 8 input operands");
+
+        Value a, b, aScale, aZp, bScale, bZp, cScale, cZp;
+
+        if (failed(extractPerTensorQuantizationArguments(
+                rewriter, loc, /*scale=*/operands[1],
+                /*zero_point=*/operands[2], aScale, aZp)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible arguments for per-tensor quantization");
+
+        if (failed(extractPerTensorQuantizationArguments(
+                rewriter, loc, /*scale=*/operands[4],
+                /*zero_point=*/operands[5], bScale, bZp)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible arguments for per-tensor quantization");
+
+        if (failed(extractPerTensorQuantizationArguments(
+                rewriter, loc, /*scale=*/operands[6],
+                /*zero_point=*/operands[7], cScale, cZp)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible arguments for per-tensor quantization");
+
+        if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[0],
+                                          /*scale=*/aScale, /*zero_point=*/aZp,
+                                          /*output=*/a)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Failed to dequantize the input tensor `a` because of "
+                         "missing sizes");
+
+        if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[3],
+                                          /*scale=*/bScale, /*zero_point=*/bZp,
+                                          /*output=*/b)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Failed to dequantize the input tensor `b` because of "
+                         "missing sizes");
+
+        // Computing the Mul result.
+        auto cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF32Type());
+        Value c =
+            rewriter.create<Torch::AtenMulTensorOp>(binder.getLoc(), cTy, a, b);
+
+        // Quantizing the result of Mul operation.
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(cTy.getDtype()))));
+        c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          c);
+        return success();
+      });
 }

lib/Conversion/TorchOnnxToTorch/Utils.cpp

Lines changed: 20 additions & 0 deletions
@@ -182,3 +182,23 @@ LogicalResult mlir::torch::onnx_c::extractPerTensorQuantizationArguments(
 
   return success();
 }
+
+LogicalResult mlir::torch::onnx_c::createDequantizeTensor(
+    ConversionPatternRewriter &rewriter, Location loc, Value input, Value scale,
+    Value zeroPoint, Value &output) {
+  auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+  if (!inputTy || !inputTy.hasSizes())
+    return failure();
+
+  Torch::ValueTensorType makeTensorTy = getQTorchTypeFromTorchIntType(inputTy);
+  Value quantizedInput =
+      rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+          loc, makeTensorTy, input, scale, zeroPoint);
+
+  Torch::ValueTensorType resultTy = rewriter.getType<Torch::ValueTensorType>(
+      inputTy.getSizes(), rewriter.getF32Type());
+  output = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, resultTy,
+                                                        quantizedInput);
+
+  return success();
+}

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 31 additions & 0 deletions
@@ -3849,3 +3849,34 @@ func.func @test_qlinearAveragePool(%arg0: !torch.vtensor<[1,128,56,56],ui8>, %ar
   // CHECK: return %[[OUT]]
   return %0 : !torch.vtensor<[1,128,28,28],ui8>
 }
+
+// -----
+
+// CHECK-LABEL: @test_qlinearmul(
+// CHECK-SAME: %[[A:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,4096],ui8>,
+// CHECK-SAME: %[[A_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[A_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[B:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,4096],ui8>,
+// CHECK-SAME: %[[B_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[B_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[C_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[C_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8>
+func.func @test_qlinearmul(%arg0: !torch.vtensor<[1,4096],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,4096],ui8>, %arg4: !torch.vtensor<[],f32>, %arg5: !torch.vtensor<[],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8> attributes {torch.onnx_meta.opset_version = 10 : si64} {
+  %0 = torch.operator "onnx.QLinearMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,4096],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,4096],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8>
+  // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[A_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[A_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[B_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[B_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[C_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[CSCALE:.+]] = torch.aten.item %[[C_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[A_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[A]], %[[ASCALE]], %[[AZP]] : !torch.vtensor<[1,4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
+  // CHECK: %[[A_F32:.+]] = torch.aten.dequantize.self %[[A_QUANT]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],f32>
+  // CHECK-DAG: %[[B_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[B]], %[[BSCALE]], %[[BZP]] : !torch.vtensor<[1,4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
+  // CHECK: %[[B_F32:.+]] = torch.aten.dequantize.self %[[B_QUANT]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[ADD:.+]] = torch.aten.mul.Tensor %[[A_F32]], %[[B_F32]] : !torch.vtensor<[1,4096],f32>, !torch.vtensor<[1,4096],f32> -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[DTY:.+]] = torch.constant.int 13
+  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[ADD]], %[[CSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[1,4096],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
+  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],ui8>
+  // CHECK: return %[[OUT]]
+  return %0 : !torch.vtensor<[1,4096],ui8>
+}
